From 8af44e8f67e3a66730d645232aaa69cdc56c84c8 Mon Sep 17 00:00:00 2001 From: jrhosk Date: Sun, 15 Feb 2026 19:13:49 +0900 Subject: [PATCH 01/24] Add cpu monitoring and memory usage tools. --- src/toolviper/utils/profile.py | 102 +++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 src/toolviper/utils/profile.py diff --git a/src/toolviper/utils/profile.py b/src/toolviper/utils/profile.py new file mode 100644 index 0000000..9ba0410 --- /dev/null +++ b/src/toolviper/utils/profile.py @@ -0,0 +1,102 @@ +import tracemalloc +import uuid +import csv +import functools +import multiprocessing +import time +import psutil + +import toolviper.utils.logger as logger + + +def cpu_usage_(stop_event, filename): + if filename is None: + filename = f"cpu_usage_{uuid.uuid4()}.csv" + + with open(filename, "w") as csvfile: + number_of_cores = psutil.cpu_count(logical=True) + + core_list = [f"c{core}" for core in range(number_of_cores)] + writer = csv.writer(csvfile, delimiter=",", lineterminator="\n") + writer.writerow(core_list) + while not stop_event.is_set(): + usage = psutil.cpu_percent(percpu=True, interval=1) + writer.writerow(usage) + + +def monitor(filename=None): + def function_wrapper(function): + @functools.wraps(function) + def wrapper(*args, **kwargs): + stop_event = multiprocessing.Event() + + monitor_process = multiprocessing.Process( + target=cpu_usage_, args=(stop_event, filename) + ) + monitor_process.start() + + time.sleep(1) + + try: + results = function(*args, **kwargs) + finally: + stop_event.set() + monitor_process.join(timeout=1) + monitor_process.terminate() + + return results + + return wrapper + + return function_wrapper + + +# Not for production. Yet. +def memory(): + def decorator(function): + @functools.wraps(function) + def wrapper(*args, **kwargs): + import csv + + logger.debug(f"start memory profiling on function {function.__name__}") + + tracemalloc.start() + result = function(*args, **kwargs) + snapshot = tracemalloc.take_snapshot() + snapshot = snapshot.filter_traces( + ( + tracemalloc.Filter(False, ""), + tracemalloc.Filter(False, ""), + ) + ) + stats = snapshot.statistics("lineno") + record = [] + for index, stat in enumerate(stats, 1): + frame = stat.traceback[0] + record.append( + { + "index": index, + "filename": frame.filename, + "lineno": frame.lineno, + "memory": int(stat.size / 1024), + } + ) + + tracemalloc.stop() + field_names = ["index", "filename", "lineno", "memory"] + + with open( + f"memory_usage_{function.__name__}.csv", + "w", + newline="", + encoding="utf-8", + ) as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=field_names) + writer.writeheader() + writer.writerows(record) + + return result + + return wrapper + + return decorator From 9410515792f720fc2b167cc174d8c40de66d80e2 Mon Sep 17 00:00:00 2001 From: jrhosk Date: Fri, 20 Feb 2026 17:18:54 +0900 Subject: [PATCH 02/24] Add cpu monitoring and memory usage tools. --- src/toolviper/utils/__init__.py | 1 + src/toolviper/utils/data/cloudflare.py | 8 ++++---- src/toolviper/utils/display.py | 25 ++++++++++++++++++++++++- src/toolviper/utils/profile.py | 7 +++---- 4 files changed, 32 insertions(+), 9 deletions(-) diff --git a/src/toolviper/utils/__init__.py b/src/toolviper/utils/__init__.py index e7f5a5d..b34a5c3 100644 --- a/src/toolviper/utils/__init__.py +++ b/src/toolviper/utils/__init__.py @@ -3,6 +3,7 @@ from .protego import Protego from .logger import info, debug, warning, error, critical, get_logger, setup_logger from .tools import open_json, calculate_checksum, verify, add_entry +from .profile import memory_usage, cpu_usage from .data import download diff --git a/src/toolviper/utils/data/cloudflare.py b/src/toolviper/utils/data/cloudflare.py index dfbe1b6..49c18aa 100644 --- a/src/toolviper/utils/data/cloudflare.py +++ b/src/toolviper/utils/data/cloudflare.py @@ -68,7 +68,7 @@ def download( No return """ - logger.info("Downloading from [cloudflare] ....") + logger.info("Downloading from ....") if not isinstance(file, list): file = [file] @@ -86,7 +86,7 @@ def download( ) pathlib.Path(folder).resolve().mkdir() - logger.debug(f"Initializing [cloudflare] downloader ...") + logger.debug(f"Initializing downloader ...") meta_data_path = pathlib.Path(__file__).parent.joinpath( ".cloudflare/file.download.json" @@ -426,7 +426,7 @@ def _print_file_queue(files: list) -> None: assert type(files) == list - console = Console() + console_ = Console() table = Table(show_header=True, box=box.SIMPLE) table.add_column("Download List", justify="left") @@ -434,7 +434,7 @@ def _print_file_queue(files: list) -> None: for file in files: table.add_row(f"[magenta]{file}[/magenta]") - console.print(table) + console_.print(table) def _make_dir(path, folder): diff --git a/src/toolviper/utils/display.py b/src/toolviper/utils/display.py index f0c9932..0fe53e6 100755 --- a/src/toolviper/utils/display.py +++ b/src/toolviper/utils/display.py @@ -1,5 +1,6 @@ +import rich + def dict_to_html(d, indent=0): - from IPython.display import HTML html = "" for key, value in d.items(): @@ -8,3 +9,25 @@ def dict_to_html(d, indent=0): else: html += f"
{key}: {value}
" return html + +class DisplayDict(dict): + def __init__(self, dictionary): + super().__init__() + self._dict = dictionary + + def __getattr__(self, key): + return self[key] + + def __setattr__(self, key, value): + self[key] = value + + def display(self): + import rich + from toolviper.utils.parameter import is_notebook + + if is_notebook(): + from IPython.display import JSON + print("notebook") + return JSON(self._dict) + + return rich.print_json(data=self._dict) \ No newline at end of file diff --git a/src/toolviper/utils/profile.py b/src/toolviper/utils/profile.py index 9ba0410..60bd959 100644 --- a/src/toolviper/utils/profile.py +++ b/src/toolviper/utils/profile.py @@ -8,7 +8,6 @@ import toolviper.utils.logger as logger - def cpu_usage_(stop_event, filename): if filename is None: filename = f"cpu_usage_{uuid.uuid4()}.csv" @@ -24,7 +23,7 @@ def cpu_usage_(stop_event, filename): writer.writerow(usage) -def monitor(filename=None): +def cpu_usage(filename=None): def function_wrapper(function): @functools.wraps(function) def wrapper(*args, **kwargs): @@ -52,7 +51,7 @@ def wrapper(*args, **kwargs): # Not for production. Yet. -def memory(): +def memory_usage(): def decorator(function): @functools.wraps(function) def wrapper(*args, **kwargs): @@ -86,7 +85,7 @@ def wrapper(*args, **kwargs): field_names = ["index", "filename", "lineno", "memory"] with open( - f"memory_usage_{function.__name__}.csv", + f"memory_usage_{function.__name__}_{uuid.uuid4()}.csv", "w", newline="", encoding="utf-8", From 1e9eb7310d6b2fb3b39083ea43cb8609e4da0a24 Mon Sep 17 00:00:00 2001 From: jrhosk Date: Sat, 21 Feb 2026 00:16:26 +0900 Subject: [PATCH 03/24] Add DisplayDict --- src/toolviper/utils/display.py | 37 +++++++++++++++++++++++++++------- src/toolviper/utils/profile.py | 1 + 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/src/toolviper/utils/display.py b/src/toolviper/utils/display.py index 0fe53e6..6c45df1 100755 --- a/src/toolviper/utils/display.py +++ b/src/toolviper/utils/display.py @@ -1,4 +1,6 @@ -import rich +import re +import operator + def dict_to_html(d, indent=0): @@ -10,16 +12,37 @@ def dict_to_html(d, indent=0): html += f"
{key}: {value}
" return html + class DisplayDict(dict): def __init__(self, dictionary): super().__init__() self._dict = dictionary - def __getattr__(self, key): - return self[key] + @classmethod + def from_dict(cls, dictionary): + if isinstance(dictionary, dict): + return cls(dictionary) + + return None + + def fetch(self, keys): + return list(operator.itemgetter(*keys)(self._dict)) - def __setattr__(self, key, value): - self[key] = value + def get_entries_(self, keys): + return {key: value for key, value in self._dict.items() if key in keys} + + def filter(self, query): + if isinstance(query, list): + _result = self.get_entries_(query) + return DisplayDict.from_dict(_result) + + if isinstance(query, str): + _result = { + key: value for key, value in self._dict.items() if re.search(query, key) + } + return self.from_dict(_result) + + return None def display(self): import rich @@ -27,7 +50,7 @@ def display(self): if is_notebook(): from IPython.display import JSON - print("notebook") + return JSON(self._dict) - return rich.print_json(data=self._dict) \ No newline at end of file + return rich.print_json(data=self._dict) diff --git a/src/toolviper/utils/profile.py b/src/toolviper/utils/profile.py index 60bd959..57a7713 100644 --- a/src/toolviper/utils/profile.py +++ b/src/toolviper/utils/profile.py @@ -8,6 +8,7 @@ import toolviper.utils.logger as logger + def cpu_usage_(stop_event, filename): if filename is None: filename = f"cpu_usage_{uuid.uuid4()}.csv" From 71935bcb126cde007d74102b23c73238ce2b0d07 Mon Sep 17 00:00:00 2001 From: jrhosk Date: Tue, 24 Feb 2026 15:52:51 +0900 Subject: [PATCH 04/24] Add additional feature. --- src/toolviper/utils/display.py | 51 +++++++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 10 deletions(-) diff --git a/src/toolviper/utils/display.py b/src/toolviper/utils/display.py index 6c45df1..9396097 100755 --- a/src/toolviper/utils/display.py +++ b/src/toolviper/utils/display.py @@ -1,15 +1,17 @@ import re import operator - +from IPython.core.display import HTML def dict_to_html(d, indent=0): + print(f"THIS FUNCTION WILL BE DEPRECATED SOON") html = "" for key, value in d.items(): if isinstance(value, dict): html += f"
{key}{dict_to_html(value, indent + 1)}
" else: - html += f"
{key}: {value}
" + html += f"
{key}: {value}
" + return html @@ -18,6 +20,9 @@ def __init__(self, dictionary): super().__init__() self._dict = dictionary + def __repr__(self, *args, **kwargs): + return f"" + @classmethod def from_dict(cls, dictionary): if isinstance(dictionary, dict): @@ -25,32 +30,58 @@ def from_dict(cls, dictionary): return None - def fetch(self, keys): - return list(operator.itemgetter(*keys)(self._dict)) + @property + def data(self): + return self._dict + + def select(self, keys, in_place=False): + _result = list(operator.itemgetter(*keys)(self._dict)) + + if in_place: + self._dict = _result + + return DisplayDict.from_dict({key: value for key, value in zip(keys, _result)}) def get_entries_(self, keys): return {key: value for key, value in self._dict.items() if key in keys} - def filter(self, query): + def filter(self, query, in_place=False): + _result = None + if isinstance(query, list): _result = self.get_entries_(query) - return DisplayDict.from_dict(_result) if isinstance(query, str): _result = { key: value for key, value in self._dict.items() if re.search(query, key) } - return self.from_dict(_result) - return None + if in_place: + self._dict = _result + return None + + return DisplayDict.from_dict(_result) + + - def display(self): + def display(self, interactive=True): import rich from toolviper.utils.parameter import is_notebook - if is_notebook(): + if is_notebook() and interactive: from IPython.display import JSON return JSON(self._dict) return rich.print_json(data=self._dict) + + def html(self, indent=0): + html = "" + for key, value in self._dict.items(): + if isinstance(value, dict): + html += f"
{key}{dict_to_html(value, indent + 1)}
" + + else: + html += f"
{key}: {value}
" + + return HTML(html) From 7c5cc6bd1410e49160cf714f8681e7596c1244c8 Mon Sep 17 00:00:00 2001 From: jrhosk Date: Wed, 25 Feb 2026 14:27:28 +0900 Subject: [PATCH 05/24] Add additional feature + tests? --- src/toolviper/utils/display.py | 9 +- tests/test_logger.py | 206 +++++++++++++++++++++++++++++ tests/test_tools.py | 234 ++++++++++++++++++++++++--------- 3 files changed, 385 insertions(+), 64 deletions(-) create mode 100644 tests/test_logger.py diff --git a/src/toolviper/utils/display.py b/src/toolviper/utils/display.py index 9396097..b4faebf 100755 --- a/src/toolviper/utils/display.py +++ b/src/toolviper/utils/display.py @@ -2,6 +2,7 @@ import operator from IPython.core.display import HTML + def dict_to_html(d, indent=0): print(f"THIS FUNCTION WILL BE DEPRECATED SOON") @@ -15,7 +16,7 @@ def dict_to_html(d, indent=0): return html -class DisplayDict(dict): +class DataDict(dict): def __init__(self, dictionary): super().__init__() self._dict = dictionary @@ -40,7 +41,7 @@ def select(self, keys, in_place=False): if in_place: self._dict = _result - return DisplayDict.from_dict({key: value for key, value in zip(keys, _result)}) + return DataDict.from_dict({key: value for key, value in zip(keys, _result)}) def get_entries_(self, keys): return {key: value for key, value in self._dict.items() if key in keys} @@ -60,9 +61,7 @@ def filter(self, query, in_place=False): self._dict = _result return None - return DisplayDict.from_dict(_result) - - + return DataDict.from_dict(_result) def display(self, interactive=True): import rich diff --git a/tests/test_logger.py b/tests/test_logger.py new file mode 100644 index 0000000..7f1bada --- /dev/null +++ b/tests/test_logger.py @@ -0,0 +1,206 @@ +import pytest +import os +import logging +from unittest.mock import MagicMock, patch +from toolviper.utils.logger import ( + set_verbosity, + verbosity, + info, + debug, + warning, + error, + critical, + exception, + log, + get_logger, + setup_logger, + ColorLoggingFormatter, + LoggingFormatter, + get_worker_logger_name, + setup_worker_logger, +) + + +@pytest.fixture(autouse=True) +def reset_verbosity(): + token = verbosity.set(None) + yield + verbosity.reset(token) + + +@pytest.fixture +def mock_logger(): + with patch("toolviper.utils.logger.get_logger") as mock: + logger = MagicMock(spec=logging.Logger) + mock.return_value = logger + yield logger + + +def test_set_verbosity(): + set_verbosity(True) + assert verbosity.get() is True + set_verbosity(False) + assert verbosity.get() is False + set_verbosity(None) + assert verbosity.get() is None + + +def test_info_logging(mock_logger): + info("test message") + mock_logger.info.assert_called_with("test message") + + +def test_info_verbose_logging(mock_logger): + with patch( + "toolviper.utils.logger.add_verbose_info", return_value="verbose test" + ) as mock_add: + info("test message", verbose=True) + mock_add.assert_called() + mock_logger.info.assert_called_with("verbose test") + + +def test_info_verbosity_context(mock_logger): + set_verbosity(True) + with patch("toolviper.utils.logger.add_verbose_info", return_value="verbose test"): + info("test message") + mock_logger.info.assert_called_with("verbose test") + + +def test_debug_logging(mock_logger): + debug("debug message") + mock_logger.debug.assert_called_with("debug message") + + +def test_warning_logging(mock_logger): + warning("warning message") + mock_logger.warning.assert_called_with("warning message") + + +def test_error_logging(mock_logger): + with patch( + "toolviper.utils.logger.add_verbose_info", + side_effect=lambda message, color: message, + ): + error("error message") + mock_logger.error.assert_called_with("error message") + + +def test_critical_logging(mock_logger): + with patch( + "toolviper.utils.logger.add_verbose_info", + side_effect=lambda message, color: message, + ): + critical("critical message") + mock_logger.critical.assert_called_with("critical message") + + +def test_exception_logging(mock_logger): + exception("exception message") + mock_logger.exception.assert_called_with("exception message") + + +def test_log_logging(mock_logger): + mock_logger.level = logging.INFO + log("log message") + mock_logger.log.assert_called_with(logging.INFO, "log message") + + +def test_get_logger_no_env_no_worker(monkeypatch): + monkeypatch.delenv("LOGGER_NAME", raising=False) + with patch("toolviper.utils.logger.get_worker", side_effect=ValueError): + logger = get_logger() + assert logger.name == "viperlog" + # Since it's a new logger, it should have a StreamHandler + assert any(isinstance(h, logging.StreamHandler) for h in logger.handlers) + + +def test_get_logger_existing_logger(monkeypatch): + monkeypatch.delenv("LOGGER_NAME", raising=False) + # Pre-create logger + existing_logger = logging.getLogger("existing_log") + with patch("toolviper.utils.logger.get_worker", side_effect=ValueError): + logger = get_logger("existing_log") + assert logger == existing_logger + + +def test_get_logger_env(): + with ( + patch("os.environ", {"LOGGER_NAME": "env_logger"}), + patch("toolviper.utils.logger.get_worker", side_effect=ValueError), + ): + logger = get_logger() + assert logger.name == "env_logger" + + +def test_get_logger_worker(): + mock_worker = MagicMock() + mock_logger_obj = MagicMock() + mock_worker.plugins = {"worker_logger": MagicMock()} + mock_worker.plugins["worker_logger"].get_logger.return_value = mock_logger_obj + + with patch("toolviper.utils.logger.get_worker", return_value=mock_worker): + logger = get_logger("test_logger") + assert logger == mock_logger_obj + + +def test_setup_logger_basic(tmp_path): + log_file_base = str(tmp_path / "test_log") + logger = setup_logger( + logger_name="setup_test", + log_to_term=True, + log_to_file=True, + log_file=log_file_base, + ) + assert logger.name == "setup_test" + assert len(logger.handlers) == 2 + # Cleanup + for handler in logger.handlers: + handler.close() + + +def test_color_logging_formatter(): + formatter = ColorLoggingFormatter() + record = logging.LogRecord("name", logging.INFO, "path", 10, "msg", None, None) + formatted = formatter.format(record) + assert "INFO" in formatted + assert "msg" in formatted + + +def test_logging_formatter(): + formatter = LoggingFormatter() + record = logging.LogRecord( + "name", logging.ERROR, "path", 20, "error msg", None, None + ) + formatted = formatter.format(record) + assert "ERROR" in formatted + assert "error msg" in formatted + + +def test_get_worker_logger_name(): + mock_worker = MagicMock() + mock_worker.id = "worker-123" + with patch("toolviper.utils.logger.get_worker", return_value=mock_worker): + name = get_worker_logger_name("mylog") + assert name == "mylog_worker-123" + + +def test_setup_worker_logger(tmp_path): + mock_worker = MagicMock() + mock_worker.name = "worker-1" + mock_worker.ip = "127.0.0.1" + log_file_base = str(tmp_path / "worker_log") + + with patch("dask.distributed.print"): + logger = setup_worker_logger( + logger_name="worker_test", + log_to_term=True, + log_to_file=True, + log_file=log_file_base, + log_level="DEBUG", + worker=mock_worker, + ) + assert "worker_test_worker-1" == logger.name + assert logger.level == logging.DEBUG + # Cleanup + for handler in logger.handlers: + handler.close() diff --git a/tests/test_tools.py b/tests/test_tools.py index c3bd8f1..9b047cb 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,85 +1,201 @@ +import pytest +import json import pathlib -import toolviper +import hashlib +import os +from unittest.mock import MagicMock, patch +from toolviper.utils.tools import ( + open_json, + calculate_checksum, + iter_files_, + update_hash, + verify, + process_entry_, + add_entry, + update_version, + ChecksumError, +) -import toolviper.utils.logger as logger +@pytest.fixture +def temp_json_file(tmp_path): + data = {"version": "v1.0.0", "metadata": {}} + file_path = tmp_path / "test.json" + with open(file_path, "w") as f: + json.dump(data, f) + return file_path -class TestToolViperTools: - @classmethod - def setup_class(cls): - """setup any state specific to the execution of the given test class - such as fetching test data""" - pass - @classmethod - def teardown_class(cls): - """teardown any state that was previously setup with a call to setup_class - such as deleting test data""" - # cls.client.shutdown() - pass +@pytest.fixture +def temp_data_file(tmp_path): + file_path = tmp_path / "test_file.txt" + content = b"hello world" + file_path.write_bytes(content) + # sha256 of "hello world" + expected_hash = hashlib.sha256(content).hexdigest() + return file_path, expected_hash - def setup_method(self): - """setup any state specific to all methods of the given class""" - pass - def teardown_method(self): - """teardown any state that was previously setup for all methods of the given class""" - pass +def test_open_json_success(temp_json_file): + data = open_json(str(temp_json_file)) + assert data["version"] == "v1.0.0" - def test_open_json(self): - from toolviper.utils.tools import open_json - try: - open_json("tests/data/test.json") +def test_open_json_not_found(): + with pytest.raises(FileNotFoundError): + open_json("non_existent_file.json") - except FileNotFoundError: - logger.info(f"Function open_json(...) working as expected.") - return None - raise AssertionError +def test_calculate_checksum(temp_data_file): + file_path, expected_hash = temp_data_file + assert calculate_checksum(str(file_path)) == expected_hash - def test_private_iter_files(self): - from toolviper.utils.tools import iter_files_ - try: - for _ in iter_files_("tests/data/"): - pass +def test_iter_files_(tmp_path): + (tmp_path / "file1.txt").write_text("1") + (tmp_path / "file2.txt").write_text("2") + files = list(iter_files_(str(tmp_path))) + assert set(files) == {"file1.txt", "file2.txt"} - except FileNotFoundError: - logger.info(f"Function iter_files_(...) working as expected.") - return None - raise AssertionError +def test_iter_files_not_found(): + with pytest.raises(FileNotFoundError): + list(iter_files_("non_existent_path")) - def test_verify(self): - from toolviper.utils.tools import verify - try: - # Test files that doesn't exist - verify(filename="test.json", folder="tests/data") +def test_update_hash(tmp_path): + # Setup manifest + manifest_path = tmp_path / "manifest.json" + manifest_data = {"metadata": {"test_file": {"hash": "old_hash"}}} + with open(manifest_path, "w") as f: + json.dump(manifest_data, f) - except FileNotFoundError: - logger.info(f"Function verify(...) working as expected.") - return None + # Setup data file + data_file = tmp_path / "test_file" + data_file.write_text("new content") + new_hash = hashlib.sha256(b"new content").hexdigest() - raise AssertionError + update_hash(str(manifest_path), str(tmp_path)) - def test_calculate_checksum(self): - from toolviper.utils.tools import calculate_checksum + updated_manifest = open_json(str(manifest_path)) + assert updated_manifest["metadata"]["test_file"]["hash"] == new_hash - base_address = pathlib.Path(toolviper.__file__).parent - metadata_address = base_address.joinpath( - "utils/data/.cloudflare/file.download.json" - ) - metadata = toolviper.utils.tools.open_json(str(metadata_address)) +def test_verify_success(tmp_path, monkeypatch): + # Setup manifest in a place where verify can find it (mocking toolviper.__file__) + manifest_dir = tmp_path / "utils/data/.cloudflare" + manifest_dir.mkdir(parents=True) + manifest_path = manifest_dir / "file.download.json" + + data_file = tmp_path / "test.zip" + data_file.write_text("zip content") + expected_hash = hashlib.sha256(b"zip content").hexdigest() + + manifest_data = {"metadata": {"test": {"hash": expected_hash}}} + with open(manifest_path, "w") as f: + json.dump(manifest_data, f) + + import toolviper + + monkeypatch.setattr(toolviper, "__file__", str(tmp_path / "__init__.py")) + (tmp_path / "__init__.py").touch() + + # verify(filename, folder) + # verify handles .zip extension by stripping it + verify("test.zip", str(tmp_path)) + + +def test_verify_checksum_error(tmp_path, monkeypatch): + manifest_dir = tmp_path / "utils/data/.cloudflare" + manifest_dir.mkdir(parents=True) + manifest_path = manifest_dir / "file.download.json" + + data_file = tmp_path / "test.zip" + data_file.write_text("wrong content") + + manifest_data = {"metadata": {"test": {"hash": "expected_but_different_hash"}}} + with open(manifest_path, "w") as f: + json.dump(manifest_data, f) + + import toolviper + + monkeypatch.setattr(toolviper, "__file__", str(tmp_path / "__init__.py")) + (tmp_path / "__init__.py").touch() + + with pytest.raises(ChecksumError): + verify("test.zip", str(tmp_path)) + - path = pathlib.Path.cwd().joinpath("data") - path.mkdir(parents=True, exist_ok=True) +def test_update_version(): + with ( + patch("toolviper.utils.tools.open_json") as mock_open, + patch("pathlib.Path.exists", return_value=True), + ): - toolviper.utils.data.download(file="checksum.hash", folder=str(path)) + mock_open.return_value = {"version": "v1.2.3"} - assert ( - toolviper.utils.tools.calculate_checksum(file="data/checksum.hash") - == metadata["metadata"]["checksum.hash"]["hash"] + # current implementation doesn't reset other parts + assert update_version("major") == "v2.2.3" + assert update_version("minor") == "v1.3.3" + assert update_version("patch") == "v1.2.4" + assert update_version("unknown") is None + + +def test_process_entry_(tmp_path): + json_file = {"metadata": {}} + test_file = tmp_path / "test.zip" + test_file.write_text("content") + file_hash = hashlib.sha256(b"content").hexdigest() + + with patch("toolviper.utils.data.get_file_size", return_value={"test": 123}): + process_entry_( + file=str(test_file), + path="verification", + dtype="int", + telescope="VLA", + mode="test", + json_file=json_file, ) + + assert "test" in json_file["metadata"] + assert json_file["metadata"]["test"]["hash"] == file_hash + assert json_file["metadata"]["test"]["size"] == "123" + + +def test_add_entry(tmp_path, monkeypatch): + manifest_dir = tmp_path / "utils/data/.cloudflare" + manifest_dir.mkdir(parents=True) + manifest_path = manifest_dir / "file.download.json" + manifest_data = {"version": "v1.0.0", "metadata": {}} + with open(manifest_path, "w") as f: + json.dump(manifest_data, f) + + import toolviper + + monkeypatch.setattr(toolviper, "__path__", [str(tmp_path)]) + + test_file = tmp_path / "new_file.zip" + test_file.write_text("content") + + entry = { + "file": str(test_file), + "path": "verification", + "dtype": "int", + "telescope": "VLA", + "mode": "test", + } + + with patch("toolviper.utils.data.get_file_size", return_value={"new_file": 456}): + result = add_entry(entries=[entry], manifest=str(manifest_path)) + + assert result["version"] == "v1.0.1" + assert "new_file" in result["metadata"] + assert os.path.exists("file.download.json") + os.remove("file.download.json") + + +def test_checksum_error_str(): + error = ChecksumError("msg", "file.txt", "/folder", 10) + assert "[10]: There was an error verifying the checksum of /folder/file.txt" in str( + error + ) From 4c0d9f975cbd2bbc557ea9729a6780245873af54 Mon Sep 17 00:00:00 2001 From: jrhosk Date: Wed, 25 Feb 2026 14:42:06 +0900 Subject: [PATCH 06/24] Add additional feature + tests? --- tests/test_console.py | 53 +++++++++++++++++++ tests/test_logger.py | 1 - tests/test_parameter.py | 110 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 163 insertions(+), 1 deletion(-) create mode 100644 tests/test_console.py create mode 100644 tests/test_parameter.py diff --git a/tests/test_console.py b/tests/test_console.py new file mode 100644 index 0000000..18dd748 --- /dev/null +++ b/tests/test_console.py @@ -0,0 +1,53 @@ +import pytest +import inspect +from toolviper.utils.console import ColorCodes, Colorize, add_verbose_info + +def test_color_codes(): + codes = ColorCodes() + assert codes.reset == "\033[0m" + assert codes.red == "\033[38;2;220;20;60m" + +def test_colorize_basic(): + c = Colorize() + text = "hello" + assert c.bold(text) == f"\033[1m{text}\033[0m" + assert c.red(text) == f"\033[38;2;220;20;60m{text}\033[0m" + assert c.blue(text) == f"\033[38;2;50;50;205m{text}\033[0m" + +def test_colorize_format_list(): + c = Colorize() + # Testing format with RGB list + formatted = c.format("test", color=[255, 0, 0]) + assert "test" in formatted + assert "38;2;255;0;0" in formatted + +def test_colorize_format_string(): + c = Colorize() + formatted = c.format("test", color="green") + assert "test" in formatted + assert "38;2;46;139;87" in formatted + +def test_get_color_function(): + c = Colorize() + fn = c.get_color_function("red") + assert fn == c.red + + # Default to black if not found + fn_unknown = c.get_color_function("nonexistent") + assert fn_unknown == c.black + +def test_add_verbose_info(): + def dummy_caller(): + # result = add_verbose_info("my message") + # In this context, dummy_caller is the direct caller, so PENULTIMATE_FUNCTION (2) + # might refer to the caller of dummy_caller if add_verbose_info is called from it. + # Actually, add_verbose_info uses PENULTIMATE_FUNCTION = 2. + # stack[0] = add_verbose_info + # stack[1] = dummy_caller + # stack[2] = test_add_verbose_info + return add_verbose_info("my message") + + result = dummy_caller() + # It seems in this pytest execution, it gets 'test_add_verbose_info' as PENULTIMATE_FUNCTION + assert "my message" in result + assert "\033[" in result diff --git a/tests/test_logger.py b/tests/test_logger.py index 7f1bada..05f1410 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -1,5 +1,4 @@ import pytest -import os import logging from unittest.mock import MagicMock, patch from toolviper.utils.logger import ( diff --git a/tests/test_parameter.py b/tests/test_parameter.py new file mode 100644 index 0000000..0bb0041 --- /dev/null +++ b/tests/test_parameter.py @@ -0,0 +1,110 @@ +import pytest +import os +import json +import pathlib +import shutil +from unittest.mock import patch, MagicMock +from toolviper.utils.parameter import validate, get_path, set_config_directory, is_notebook, verify + +def test_is_notebook(): + # Should be False in normal python environment + assert is_notebook() is False + +def test_get_path_standard(monkeypatch): + def dummy_func(): + pass + + # Mock inspect.getmodule and inspect.getfile + mock_module = MagicMock() + mock_module.__name__ = "toolviper.utils.dummy" + + with patch("inspect.getmodule", return_value=mock_module), \ + patch("inspect.getfile", return_value="/abs/path/src/toolviper/utils/dummy.py"): + base, mod = get_path(dummy_func) + assert "src/toolviper" in base + assert mod == "/abs/path/src/toolviper/utils/dummy" + +def test_set_config_directory(tmp_path): + config_dir = tmp_path / "my_config" + config_dir.mkdir() + + with patch("toolviper.utils.logger.info"): + set_config_directory(str(config_dir)) + assert os.environ["PARAMETER_CONFIG_PATH"] == str(config_dir) + +def test_validate_decorator_success(tmp_path, monkeypatch): + # Setup a mock config file + config_dir = tmp_path / "config" + config_dir.mkdir() + param_file = config_dir / "test_mod.param.json" + schema = { + "my_func": { + "arg1": {"type": "int", "required": True}, + "arg2": {"type": "str"} + } + } + with open(param_file, "w") as f: + json.dump(schema, f) + + def my_func(arg1, arg2="default"): + return f"{arg1}-{arg2}" + + # Manually wrap with validate and trick it + my_func.__module__ = "toolviper.utils.test_mod" + my_func.__name__ = "my_func" + + wrapped = validate(config_dir=str(config_dir))(my_func) + + # We also need to mock get_path to avoid it searching in /tmp or something + with patch("toolviper.utils.parameter.get_path", return_value=(str(tmp_path), str(tmp_path / "test_mod"))): + assert wrapped(10, arg2="hello") == "10-hello" + +def test_validate_decorator_failure(tmp_path): + config_dir = tmp_path / "config" + config_dir.mkdir() + param_file = config_dir / "test_mod.param.json" + schema = { + "fail_func": { + "arg1": {"type": "int"} + } + } + with open(param_file, "w") as f: + json.dump(schema, f) + + def fail_func(arg1): + return arg1 + + fail_func.__module__ = "toolviper.utils.test_mod" + fail_func.__name__ = "fail_func" + + wrapped = validate(config_dir=str(config_dir))(fail_func) + + with patch("toolviper.utils.parameter.get_path", return_value=(str(tmp_path), str(tmp_path / "test_mod"))): + # Should raise AssertionError from verify's assert validator.validate(args) + with pytest.raises(AssertionError): + wrapped("not an int") + +def test_verify_missing_config(): + def no_config_func(): + pass + no_config_func.__module__ = "ghost_module" + + with patch("toolviper.utils.parameter.get_path", return_value=("/tmp", "/tmp/ghost_module")), \ + patch("toolviper.utils.logger.error"): + with pytest.raises(FileNotFoundError): + verify(no_config_func, {}, {"function": "no_config_func", "module": "ghost_module"}) + +def test_verify_function_not_in_schema(tmp_path): + config_dir = tmp_path / "config" + config_dir.mkdir() + param_file = config_dir / "known_mod.param.json" + with open(param_file, "w") as f: + json.dump({"other_func": {}}, f) + + def unknown_func(): + pass + unknown_func.__module__ = "known_mod" + + with patch("toolviper.utils.parameter.get_path", return_value=(str(tmp_path), str(tmp_path / "known_mod"))): + with pytest.raises(KeyError): + verify(unknown_func, {}, {"function": "unknown_func", "module": "known_mod"}, config_dir=str(config_dir)) From 568c9383196f4a92ec877ab98e6c34d5a473e1b5 Mon Sep 17 00:00:00 2001 From: jrhosk Date: Wed, 25 Feb 2026 16:44:40 +0900 Subject: [PATCH 07/24] [black] Add additional feature + tests? --- tests/test_console.py | 12 +++-- tests/test_parameter.py | 99 +++++++++++++++++++++++++++-------------- 2 files changed, 75 insertions(+), 36 deletions(-) diff --git a/tests/test_console.py b/tests/test_console.py index 18dd748..51eb0aa 100644 --- a/tests/test_console.py +++ b/tests/test_console.py @@ -2,11 +2,13 @@ import inspect from toolviper.utils.console import ColorCodes, Colorize, add_verbose_info + def test_color_codes(): codes = ColorCodes() assert codes.reset == "\033[0m" assert codes.red == "\033[38;2;220;20;60m" + def test_colorize_basic(): c = Colorize() text = "hello" @@ -14,6 +16,7 @@ def test_colorize_basic(): assert c.red(text) == f"\033[38;2;220;20;60m{text}\033[0m" assert c.blue(text) == f"\033[38;2;50;50;205m{text}\033[0m" + def test_colorize_format_list(): c = Colorize() # Testing format with RGB list @@ -21,32 +24,35 @@ def test_colorize_format_list(): assert "test" in formatted assert "38;2;255;0;0" in formatted + def test_colorize_format_string(): c = Colorize() formatted = c.format("test", color="green") assert "test" in formatted assert "38;2;46;139;87" in formatted + def test_get_color_function(): c = Colorize() fn = c.get_color_function("red") assert fn == c.red - + # Default to black if not found fn_unknown = c.get_color_function("nonexistent") assert fn_unknown == c.black + def test_add_verbose_info(): def dummy_caller(): # result = add_verbose_info("my message") - # In this context, dummy_caller is the direct caller, so PENULTIMATE_FUNCTION (2) + # In this context, dummy_caller is the direct caller, so PENULTIMATE_FUNCTION (2) # might refer to the caller of dummy_caller if add_verbose_info is called from it. # Actually, add_verbose_info uses PENULTIMATE_FUNCTION = 2. # stack[0] = add_verbose_info # stack[1] = dummy_caller # stack[2] = test_add_verbose_info return add_verbose_info("my message") - + result = dummy_caller() # It seems in this pytest execution, it gets 'test_add_verbose_info' as PENULTIMATE_FUNCTION assert "my message" in result diff --git a/tests/test_parameter.py b/tests/test_parameter.py index 0bb0041..bfcc4c5 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -4,95 +4,119 @@ import pathlib import shutil from unittest.mock import patch, MagicMock -from toolviper.utils.parameter import validate, get_path, set_config_directory, is_notebook, verify +from toolviper.utils.parameter import ( + validate, + get_path, + set_config_directory, + is_notebook, + verify, +) + def test_is_notebook(): # Should be False in normal python environment assert is_notebook() is False + def test_get_path_standard(monkeypatch): def dummy_func(): pass - + # Mock inspect.getmodule and inspect.getfile mock_module = MagicMock() mock_module.__name__ = "toolviper.utils.dummy" - - with patch("inspect.getmodule", return_value=mock_module), \ - patch("inspect.getfile", return_value="/abs/path/src/toolviper/utils/dummy.py"): + + with ( + patch("inspect.getmodule", return_value=mock_module), + patch("inspect.getfile", return_value="/abs/path/src/toolviper/utils/dummy.py"), + ): base, mod = get_path(dummy_func) assert "src/toolviper" in base assert mod == "/abs/path/src/toolviper/utils/dummy" + def test_set_config_directory(tmp_path): config_dir = tmp_path / "my_config" config_dir.mkdir() - + with patch("toolviper.utils.logger.info"): set_config_directory(str(config_dir)) assert os.environ["PARAMETER_CONFIG_PATH"] == str(config_dir) + def test_validate_decorator_success(tmp_path, monkeypatch): # Setup a mock config file config_dir = tmp_path / "config" config_dir.mkdir() param_file = config_dir / "test_mod.param.json" schema = { - "my_func": { - "arg1": {"type": "int", "required": True}, - "arg2": {"type": "str"} - } + "my_func": {"arg1": {"type": "int", "required": True}, "arg2": {"type": "str"}} } with open(param_file, "w") as f: json.dump(schema, f) - + def my_func(arg1, arg2="default"): return f"{arg1}-{arg2}" - + # Manually wrap with validate and trick it my_func.__module__ = "toolviper.utils.test_mod" my_func.__name__ = "my_func" - + wrapped = validate(config_dir=str(config_dir))(my_func) - + # We also need to mock get_path to avoid it searching in /tmp or something - with patch("toolviper.utils.parameter.get_path", return_value=(str(tmp_path), str(tmp_path / "test_mod"))): + with patch( + "toolviper.utils.parameter.get_path", + return_value=(str(tmp_path), str(tmp_path / "test_mod")), + ): assert wrapped(10, arg2="hello") == "10-hello" + def test_validate_decorator_failure(tmp_path): config_dir = tmp_path / "config" config_dir.mkdir() param_file = config_dir / "test_mod.param.json" - schema = { - "fail_func": { - "arg1": {"type": "int"} - } - } + schema = {"fail_func": {"arg1": {"type": "int"}}} with open(param_file, "w") as f: json.dump(schema, f) - + def fail_func(arg1): return arg1 - + fail_func.__module__ = "toolviper.utils.test_mod" fail_func.__name__ = "fail_func" - + wrapped = validate(config_dir=str(config_dir))(fail_func) - - with patch("toolviper.utils.parameter.get_path", return_value=(str(tmp_path), str(tmp_path / "test_mod"))): + + with patch( + "toolviper.utils.parameter.get_path", + return_value=(str(tmp_path), str(tmp_path / "test_mod")), + ): # Should raise AssertionError from verify's assert validator.validate(args) with pytest.raises(AssertionError): wrapped("not an int") + def test_verify_missing_config(): def no_config_func(): pass + no_config_func.__module__ = "ghost_module" - - with patch("toolviper.utils.parameter.get_path", return_value=("/tmp", "/tmp/ghost_module")), \ - patch("toolviper.utils.logger.error"): + + with ( + patch( + "toolviper.utils.parameter.get_path", + return_value=("/tmp", "/tmp/ghost_module"), + ), + patch("toolviper.utils.logger.error"), + ): with pytest.raises(FileNotFoundError): - verify(no_config_func, {}, {"function": "no_config_func", "module": "ghost_module"}) + verify( + no_config_func, + {}, + {"function": "no_config_func", "module": "ghost_module"}, + ) + def test_verify_function_not_in_schema(tmp_path): config_dir = tmp_path / "config" @@ -100,11 +124,20 @@ def test_verify_function_not_in_schema(tmp_path): param_file = config_dir / "known_mod.param.json" with open(param_file, "w") as f: json.dump({"other_func": {}}, f) - + def unknown_func(): pass + unknown_func.__module__ = "known_mod" - - with patch("toolviper.utils.parameter.get_path", return_value=(str(tmp_path), str(tmp_path / "known_mod"))): + + with patch( + "toolviper.utils.parameter.get_path", + return_value=(str(tmp_path), str(tmp_path / "known_mod")), + ): with pytest.raises(KeyError): - verify(unknown_func, {}, {"function": "unknown_func", "module": "known_mod"}, config_dir=str(config_dir)) + verify( + unknown_func, + {}, + {"function": "unknown_func", "module": "known_mod"}, + config_dir=str(config_dir), + ) From 716e77b66449cdc1e5446f149995186b5ed25e6a Mon Sep 17 00:00:00 2001 From: jrhosk Date: Wed, 25 Feb 2026 19:02:29 +0900 Subject: [PATCH 08/24] [black] Add even more additional feature + tests? --- tests/test_client.py | 23 +++++++++++ tests/test_dask_plugins.py | 81 ++++++++++++++++++++++++++++++++++++++ tests/test_menrva.py | 57 +++++++++++++++++++++++++++ tests/test_parameter.py | 2 - 4 files changed, 161 insertions(+), 2 deletions(-) create mode 100644 tests/test_dask_plugins.py diff --git a/tests/test_client.py b/tests/test_client.py index 2f5fd88..9de19e4 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2,6 +2,7 @@ import re import pathlib import distributed +from unittest.mock import patch, MagicMock from toolviper.dask.client import local_client @@ -172,3 +173,25 @@ def test__set_up_dask(self): _set_up_dask(local_directory=pathlib.Path(".").cwd()) assert dask.config.config["distributed"]["scheduler"]["allowed-failures"] == 10 + + +def test_print_libraries_availability(): + from toolviper.dask.client import print_libraries_availability + import toolviper.utils.logger as logger + + with patch.object(logger, 'debug') as mock_debug: + print_libraries_availability({"CUDA": True, "MPI": False}) + mock_debug.assert_called_once() + args, kwargs = mock_debug.call_args + assert "CUDA" in args[0] + assert "MPI" not in args[0] + +def test_get_client_none(): + from toolviper.dask.client import get_client + with patch("distributed.Client.current", side_effect=ValueError): + assert get_client() is None + +def test_get_cluster_none(): + from toolviper.dask.client import get_cluster + with patch("toolviper.dask.client.get_client", return_value=None): + assert get_cluster() is None diff --git a/tests/test_dask_plugins.py b/tests/test_dask_plugins.py new file mode 100644 index 0000000..cbbfcf0 --- /dev/null +++ b/tests/test_dask_plugins.py @@ -0,0 +1,81 @@ +import pytest +from unittest.mock import MagicMock, patch +from toolviper.dask.plugins.worker import DaskWorker +from toolviper.dask.plugins.scheduler import Scheduler, unravel_deps, get_node_depths + +def test_dask_worker_init(): + log_params = { + "log_level": "DEBUG", + "log_to_term": False, + "log_to_file": True, + "log_file": "test.log" + } + plugin = DaskWorker(local_cache=True, log_params=log_params) + assert plugin.local_cache is True + assert plugin.log_level == "DEBUG" + assert plugin.log_to_term is False + assert plugin.log_to_file is True + assert plugin.log_file == "test.log" + +def test_dask_worker_setup(): + plugin = DaskWorker(log_params={"log_level": "INFO"}) + mock_worker = MagicMock() + mock_worker.id = "worker-1" + mock_worker.address = "tcp://127.0.0.1:1234" + mock_worker.state.available_resources = {} + + with patch("toolviper.utils.logger.setup_worker_logger") as mock_setup_logger: + mock_logger = MagicMock() + mock_setup_logger.return_value = mock_logger + + plugin.setup(mock_worker) + + mock_setup_logger.assert_called_once() + assert plugin.worker == mock_worker + # Check if resource for IP was added + assert "127.0.0.1" in mock_worker.state.available_resources + +def test_scheduler_init(): + scheduler = Scheduler(autorestrictor=True, local_cache=False) + assert scheduler.autorestrictor is True + assert scheduler.local_cache is False + +def test_unravel_deps(): + hlg_deps = { + 'task1': {'task2', 'task3'}, + 'task2': {'task4'}, + 'task3': set(), + 'task4': set() + } + unravelled = unravel_deps(hlg_deps, 'task1') + assert unravelled == {'task2', 'task3', 'task4'} + +def test_get_node_depths(): + dependencies = { + 'A': set(), + 'B': {'A'}, + 'C': {'B'}, + 'D': {'A'} + } + root_nodes = {'A'} + # metrics[node][-1] is the "depth" of the node from terminal nodes (as calculated by graph_metrics) + # get_node_depths calculates: max(metrics[r][-1] - metrics[k][-1] for r in roots) + # For a simple chain A -> B -> C: + # C is terminal, depth 0 in metrics. + # B depends on A, so B's depth in metrics is 1. + # A is root, depth 2 in metrics. + metrics = { + 'A': [0, 2], # depth 2 + 'B': [0, 1], # depth 1 + 'C': [0, 0], # depth 0 + 'D': [0, 1] # depth 1 + } + + node_depths = get_node_depths(dependencies, root_nodes, metrics) + assert node_depths['A'] == 0 + # For B: roots is {'A'}. node_depths['B'] = max(metrics['A'][1] - metrics['B'][1]) = 2 - 1 = 1 + assert node_depths['B'] == 1 + # For C: roots is {'A'}. node_depths['C'] = max(metrics['A'][1] - metrics['C'][1]) = 2 - 0 = 2 + assert node_depths['C'] == 2 + # For D: roots is {'A'}. node_depths['D'] = max(metrics['A'][1] - metrics['D'][1]) = 2 - 1 = 1 + assert node_depths['D'] == 1 diff --git a/tests/test_menrva.py b/tests/test_menrva.py index a6f89dd..0726eb2 100644 --- a/tests/test_menrva.py +++ b/tests/test_menrva.py @@ -2,6 +2,7 @@ import re import pathlib import distributed +from unittest.mock import patch, MagicMock from toolviper.dask import menrva from toolviper.dask.client import local_client @@ -78,3 +79,59 @@ def test_thread_info(self): ) client.shutdown() + + +def test_port_is_free(): + from toolviper.dask.menrva import port_is_free + import socket + + # Test with a definitely free port (hopefully) + # We can use port 0 to let the OS pick a free port, but port_is_free binds it and closes it. + # Let's try to bind a port ourselves and then check if it's free. + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("127.0.0.1", 0)) + port = s.getsockname()[1] + + # Since 's' is holding the port, port_is_free should return False + assert port_is_free(port) is False + + s.close() + # Now it should be free + assert port_is_free(port) is True + +def test_close_port(): + from toolviper.dask.menrva import close_port, port_is_free + import socket + import time + + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("127.0.0.1", 0)) + port = s.getsockname()[1] + s.listen(1) + + assert port_is_free(port) is False + + # close_port tries to kill the process holding the port. + # Since it's our own process, this might be dangerous if not careful, + # but here it's just a socket in the same process. + # Actually, close_port uses psutil to find processes with that port and SIGKILLs them. + # We should probably mock psutil for this test to avoid killing ourselves. + + with patch("psutil.process_iter") as mock_iter: + mock_proc = MagicMock() + mock_conn = MagicMock() + mock_conn.laddr.port = port + mock_proc.connections.return_value = [mock_conn] + mock_iter.return_value = [mock_proc] + + close_port(port) + + mock_proc.send_signal.assert_called_once() + +def test_menrva_client_call(): + from toolviper.dask.menrva import MenrvaClient + def my_func(a, b=1): + return a + b + + assert MenrvaClient.call(my_func, 2, b=3) == 5 + assert MenrvaClient.call(my_func, 2) == 3 diff --git a/tests/test_parameter.py b/tests/test_parameter.py index bfcc4c5..52a43a9 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -1,8 +1,6 @@ import pytest import os import json -import pathlib -import shutil from unittest.mock import patch, MagicMock from toolviper.utils.parameter import ( validate, From 67abb7095e1000f9bc0cea36d2a0513857f67252 Mon Sep 17 00:00:00 2001 From: jrhosk Date: Wed, 25 Feb 2026 19:03:05 +0900 Subject: [PATCH 09/24] [black] Add even more additional feature + tests? --- tests/test_client.py | 10 +++++--- tests/test_dask_plugins.py | 52 +++++++++++++++++++------------------- tests/test_menrva.py | 19 ++++++++------ 3 files changed, 44 insertions(+), 37 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index 9de19e4..79a7555 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -173,25 +173,29 @@ def test__set_up_dask(self): _set_up_dask(local_directory=pathlib.Path(".").cwd()) assert dask.config.config["distributed"]["scheduler"]["allowed-failures"] == 10 - + def test_print_libraries_availability(): from toolviper.dask.client import print_libraries_availability import toolviper.utils.logger as logger - - with patch.object(logger, 'debug') as mock_debug: + + with patch.object(logger, "debug") as mock_debug: print_libraries_availability({"CUDA": True, "MPI": False}) mock_debug.assert_called_once() args, kwargs = mock_debug.call_args assert "CUDA" in args[0] assert "MPI" not in args[0] + def test_get_client_none(): from toolviper.dask.client import get_client + with patch("distributed.Client.current", side_effect=ValueError): assert get_client() is None + def test_get_cluster_none(): from toolviper.dask.client import get_cluster + with patch("toolviper.dask.client.get_client", return_value=None): assert get_cluster() is None diff --git a/tests/test_dask_plugins.py b/tests/test_dask_plugins.py index cbbfcf0..0410743 100644 --- a/tests/test_dask_plugins.py +++ b/tests/test_dask_plugins.py @@ -3,12 +3,13 @@ from toolviper.dask.plugins.worker import DaskWorker from toolviper.dask.plugins.scheduler import Scheduler, unravel_deps, get_node_depths + def test_dask_worker_init(): log_params = { "log_level": "DEBUG", "log_to_term": False, "log_to_file": True, - "log_file": "test.log" + "log_file": "test.log", } plugin = DaskWorker(local_cache=True, log_params=log_params) assert plugin.local_cache is True @@ -17,47 +18,46 @@ def test_dask_worker_init(): assert plugin.log_to_file is True assert plugin.log_file == "test.log" + def test_dask_worker_setup(): plugin = DaskWorker(log_params={"log_level": "INFO"}) mock_worker = MagicMock() mock_worker.id = "worker-1" mock_worker.address = "tcp://127.0.0.1:1234" mock_worker.state.available_resources = {} - + with patch("toolviper.utils.logger.setup_worker_logger") as mock_setup_logger: mock_logger = MagicMock() mock_setup_logger.return_value = mock_logger - + plugin.setup(mock_worker) - + mock_setup_logger.assert_called_once() assert plugin.worker == mock_worker # Check if resource for IP was added assert "127.0.0.1" in mock_worker.state.available_resources + def test_scheduler_init(): scheduler = Scheduler(autorestrictor=True, local_cache=False) assert scheduler.autorestrictor is True assert scheduler.local_cache is False + def test_unravel_deps(): hlg_deps = { - 'task1': {'task2', 'task3'}, - 'task2': {'task4'}, - 'task3': set(), - 'task4': set() + "task1": {"task2", "task3"}, + "task2": {"task4"}, + "task3": set(), + "task4": set(), } - unravelled = unravel_deps(hlg_deps, 'task1') - assert unravelled == {'task2', 'task3', 'task4'} + unravelled = unravel_deps(hlg_deps, "task1") + assert unravelled == {"task2", "task3", "task4"} + def test_get_node_depths(): - dependencies = { - 'A': set(), - 'B': {'A'}, - 'C': {'B'}, - 'D': {'A'} - } - root_nodes = {'A'} + dependencies = {"A": set(), "B": {"A"}, "C": {"B"}, "D": {"A"}} + root_nodes = {"A"} # metrics[node][-1] is the "depth" of the node from terminal nodes (as calculated by graph_metrics) # get_node_depths calculates: max(metrics[r][-1] - metrics[k][-1] for r in roots) # For a simple chain A -> B -> C: @@ -65,17 +65,17 @@ def test_get_node_depths(): # B depends on A, so B's depth in metrics is 1. # A is root, depth 2 in metrics. metrics = { - 'A': [0, 2], # depth 2 - 'B': [0, 1], # depth 1 - 'C': [0, 0], # depth 0 - 'D': [0, 1] # depth 1 + "A": [0, 2], # depth 2 + "B": [0, 1], # depth 1 + "C": [0, 0], # depth 0 + "D": [0, 1], # depth 1 } - + node_depths = get_node_depths(dependencies, root_nodes, metrics) - assert node_depths['A'] == 0 + assert node_depths["A"] == 0 # For B: roots is {'A'}. node_depths['B'] = max(metrics['A'][1] - metrics['B'][1]) = 2 - 1 = 1 - assert node_depths['B'] == 1 + assert node_depths["B"] == 1 # For C: roots is {'A'}. node_depths['C'] = max(metrics['A'][1] - metrics['C'][1]) = 2 - 0 = 2 - assert node_depths['C'] == 2 + assert node_depths["C"] == 2 # For D: roots is {'A'}. node_depths['D'] = max(metrics['A'][1] - metrics['D'][1]) = 2 - 1 = 1 - assert node_depths['D'] == 1 + assert node_depths["D"] == 1 diff --git a/tests/test_menrva.py b/tests/test_menrva.py index 0726eb2..edd23db 100644 --- a/tests/test_menrva.py +++ b/tests/test_menrva.py @@ -91,14 +91,15 @@ def test_port_is_free(): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.bind(("127.0.0.1", 0)) port = s.getsockname()[1] - + # Since 's' is holding the port, port_is_free should return False assert port_is_free(port) is False - + s.close() # Now it should be free assert port_is_free(port) is True + def test_close_port(): from toolviper.dask.menrva import close_port, port_is_free import socket @@ -108,30 +109,32 @@ def test_close_port(): s.bind(("127.0.0.1", 0)) port = s.getsockname()[1] s.listen(1) - + assert port_is_free(port) is False - + # close_port tries to kill the process holding the port. # Since it's our own process, this might be dangerous if not careful, # but here it's just a socket in the same process. # Actually, close_port uses psutil to find processes with that port and SIGKILLs them. # We should probably mock psutil for this test to avoid killing ourselves. - + with patch("psutil.process_iter") as mock_iter: mock_proc = MagicMock() mock_conn = MagicMock() mock_conn.laddr.port = port mock_proc.connections.return_value = [mock_conn] mock_iter.return_value = [mock_proc] - + close_port(port) - + mock_proc.send_signal.assert_called_once() + def test_menrva_client_call(): from toolviper.dask.menrva import MenrvaClient + def my_func(a, b=1): return a + b - + assert MenrvaClient.call(my_func, 2, b=3) == 5 assert MenrvaClient.call(my_func, 2) == 3 From 277b918a6be1b5982f838612763387fd196da045 Mon Sep 17 00:00:00 2001 From: jrhosk Date: Sat, 28 Feb 2026 00:54:19 +0900 Subject: [PATCH 10/24] [black] Add even more additional feature + tests? --- src/toolviper/dask/client.py | 457 +++++++-------------- src/toolviper/utils/data/cloudflare.py | 505 ++++++++++++----------- src/toolviper/utils/logger.py | 529 +++++++++++++++---------- src/toolviper/utils/tools.py | 2 +- tests/test_cloudflare.py | 109 +++++ 5 files changed, 842 insertions(+), 760 deletions(-) create mode 100644 tests/test_cloudflare.py diff --git a/src/toolviper/dask/client.py b/src/toolviper/dask/client.py index 9a8d7b4..e967f90 100644 --- a/src/toolviper/dask/client.py +++ b/src/toolviper/dask/client.py @@ -6,12 +6,11 @@ import dask_jobqueue import distributed import psutil -import inspect import functools from importlib import import_module from importlib.util import find_spec -from typing import Dict, Union +from typing import Dict, Union, Any, Optional import toolviper.dask.menrva import toolviper.utils.console as console @@ -20,43 +19,64 @@ colorize = console.Colorize() +DEFAULT_CLIENT_LOG_PARAMS = { + "logger_name": "client", + "log_to_term": True, + "log_level": "INFO", + "log_to_file": False, + "log_file": "client.log", +} + +DEFAULT_WORKER_LOG_PARAMS = { + "logger_name": "worker", + "log_to_term": True, + "log_level": "INFO", + "log_to_file": False, + "log_file": "client_worker.log", +} + + +def _get_log_params( + log_params: Optional[Dict[str, Any]], defaults: Dict[str, Any] +) -> Dict[str, Any]: + if log_params is None: + log_params = {} + return {**defaults, **log_params} + def load_libraries(name: str, libs: Union[str, list[str]]) -> dict[str, bool]: """Load libraries if they were installed and can be loaded. Parameters ---------- - name : library group name - A library group name based on a function of a distributed environment will be imported. + name : str + A library group name based on a function of a distributed environment. libs : Union[str, list[str]] - a library or a list of libraries to import + A library or a list of libraries to import. Returns ------- - an item of dict has the name and the flag whether all libraries were loaded successfully. + dict[str, bool] + A dictionary mapping the group name to a boolean indicating if all libraries were loaded successfully. """ def _load_library(_lib): if find_spec(_lib) is not None: import_module(_lib) - return [True, f" {colorize.blue(_lib)} is available"] - else: - return [False, f" {colorize.blue(_lib)} is unavailable"] - - if isinstance(libs, list): - _tmp = list(map(_load_library, libs)) - _avail = [all([x[0] for x in _tmp]), [x[1] for x in _tmp]] - elif isinstance(libs, str): - _tmp = _load_library(libs) - _avail = [_tmp[0], [_tmp[1]]] - else: - _avail = [False, " illegal module specification"] + return True, f" {colorize.blue(_lib)} is available" + return False, f" {colorize.blue(_lib)} is unavailable" + + if isinstance(libs, str): + libs = [libs] - _result = "Success" if _avail[0] else "Fail" - logger.info(f"Loading module: {name} -- {_result}") - [logger.info(x) for x in _avail[1]] + results = [_load_library(lib) for lib in libs] + all_available = all(res[0] for res in results) - return {name: _avail[0]} + logger.info(f"Loading module: {name} -- {'Success' if all_available else 'Fail'}") + for _, message in results: + logger.info(message) + + return {name: all_available} def print_libraries_availability(spec: dict[str, bool]): @@ -117,109 +137,53 @@ def get_cluster() -> Union[None, distributed.LocalCluster]: @parameter.validate() def local_client( - cores: int = None, - memory_limit: str = None, + cores: Optional[int] = None, + memory_limit: Optional[str] = None, autorestrictor: bool = False, - dask_local_dir: str = None, - local_dir: str = None, + dask_local_dir: Optional[str] = None, + local_dir: Optional[str] = None, wait_for_workers: bool = True, - log_params: Union[None, Dict] = None, - worker_log_params: Union[None, Dict] = None, + log_params: Optional[Dict[str, Any]] = None, + worker_log_params: Optional[Dict[str, Any]] = None, dashboard_address: str = ":8787", serial_execution: bool = False, -) -> Union[distributed.Client, None]: - """ Creates a local client, scheduler and workers using Dask Distributed LocalCluster (https://docs.dask.org/en/stable/deploying-python.html#reference) - with Dask configuration tuned for VIPER and the option to use autorestrictor plugin and local cache. +) -> Optional[distributed.Client]: + """Create a local client, scheduler and workers using Dask Distributed LocalCluster. + + With Dask configuration tuned for VIPER and the option to use autorestrictor plugin and local cache. + See https://docs.dask.org/en/stable/deploying-python.html#reference for more details. Parameters ---------- - cores : int - Number of cores in Dask cluster, defaults to None - memory_limit : str - Amount of memory per core. It is suggested to use '8GB', defaults to None - autorestrictor : bool - Boolean determining usage of autorestrictor plugin, defaults to False - dask_local_dir : str - Where Dask should store temporary files, defaults to None. If None Dask will use \ - `./dask-worker-space`, defaults to None - local_dir : str - Defines client local directory, defaults to None - - wait_for_workers : bool - Boolean determining usage of wait_for_workers option in dask, defaults to False - log_params : dict - The logger for the main process (code that does not run in parallel), defaults to {} - worker_log_params : dict - worker_log_params: Keys as same as log_params, default values given in `Additional \ - Information`_. - - dashboard_address: str - Address on which to listen for the Bokeh diagnostics server like ‘localhost:8787’ or ‘0.0.0.0:8787’. Defaults to ‘:8787’. - Set to None to disable the dashboard. Use ‘:0’ for a random port. See https://docs.dask.org/en/stable/deploying-python.html#reference for more information. - - serial_execution : bool - This is an option that forces dask to run in serial mode while also setting up the logger to work. This is - really only appropriate for debugging. - - .. _Description: - - ** _log_params ** - - The log_params (worker_log_params) dictionary stores initialization information for the logger and associated - workers. the following are the acceptable key: value pairs and their usage information. - - log_params["logger_name"] : str - Defines the logger name to use - log_params["log_to_term"] : bool - Should messages log to the terminal output. - log_params["log_level"] : str - Defines logging level, valid options: - - DEBUG - - INFO - - WARNING - - ERROR - - CRITICAL - - Only messages flagged as at the given level or below are logged. - - log_params["log_to_file"] : str - Should messages log to file. - - log_params["log_file"] : str - Name of log file to create. If none is given, the file name 'logger' will be used. + cores : int, optional + Number of cores in Dask cluster. Defaults to number of physical cores. + memory_limit : str, optional + Amount of memory per core. Suggested: '8GB'. Defaults to available memory divided by cores. + autorestrictor : bool, optional + Whether to use the autorestrictor plugin. Defaults to False. + dask_local_dir : str, optional + Temporary files directory for Dask. Defaults to None. + local_dir : str, optional + Client local directory. Defaults to None. + wait_for_workers : bool, optional + Whether to wait for workers to start. Defaults to True. + log_params : dict, optional + Logger configuration for the main process. + worker_log_params : dict, optional + Logger configuration for workers. + dashboard_address : str, optional + Address for the Bokeh diagnostics server (e.g., 'localhost:8787'). Defaults to ':8787'. + serial_execution : bool, optional + If True, runs Dask in serial mode (synchronous) for debugging. Defaults to False. Returns ------- - Dask Distributed Client + distributed.Client or None + Dask Distributed Client, or None if serial_execution is True. """ - if log_params is None: - log_params = {} - - log_params = { - **{ - "logger_name": "client", - "log_to_term": True, - "log_level": "INFO", - "log_to_file": False, - "log_file": "client.log", - }, - **log_params, - } - - if worker_log_params is None: - worker_log_params = {} - - worker_log_params = { - **{ - "logger_name": "worker", - "log_to_term": True, - "log_level": "INFO", - "log_to_file": False, - "log_file": "client_worker.log", - }, - **worker_log_params, - } + log_params = _get_log_params(log_params, DEFAULT_CLIENT_LOG_PARAMS) + worker_log_params = _get_log_params(worker_log_params, DEFAULT_WORKER_LOG_PARAMS) # If the user wants to change the global logger name from the # default value of toolviper @@ -329,82 +293,34 @@ def local_client( return client +@parameter.validate() def distributed_client( - cluster: None, - dask_local_dir: str = None, - log_params: Union[None, Dict] = None, - worker_log_params: Union[None, Dict] = None, -) -> Union[distributed.Client, None]: - """ Setup dask cluster and logger. + cluster: Any, + dask_local_dir: Optional[str] = None, + log_params: Optional[Dict[str, Any]] = None, + worker_log_params: Optional[Dict[str, Any]] = None, +) -> distributed.Client: + """Setup dask cluster and logger. Parameters ---------- - cluster - log_params : dict - The logger for the main process (code that does not run in parallel), defaults to {} - worker_log_params : dict - worker_log_params: Keys as same as log_params, default values given in `Additional \ - Information`_. - - .. _Description: - - ** _log_params ** - - The log_params (worker_log_params) dictionary stores initialization information for the logger and associated - workers. the following are the acceptable key: value pairs and their usage information. - - log_params["logger_name"] : str - Defines the logger name to use - log_params["log_to_term"] : bool - Should messages log to the terminal output. - log_params["log_level"] : str - Defines logging level, valid options: - - DEBUG - - INFO - - WARNING - - ERROR - - CRITICAL - - Only messages flagged as at the given level or below are logged. - - log_params["log_to_file"] : str - Should messages log to file. - - log_params["log_filee"] : str - Name of log file to create. If none is given, the file name 'logger' will be used. + cluster : Any + An existing dask cluster instance. + dask_local_dir : str, optional + Where Dask should store temporary files. + log_params : dict, optional + The logger for the main process. + worker_log_params : dict, optional + The logger for the workers. Returns ------- + distributed.Client Dask Distributed Client """ - if log_params is None: - log_params = {} - - log_params = { - **{ - "logger_name": "client", - "log_to_term": True, - "log_level": "INFO", - "log_to_file": False, - "log_file": "client.log", - }, - **log_params, - } - - if worker_log_params is None: - worker_log_params = {} - - worker_log_params = { - **{ - "logger_name": "worker", - "log_to_term": True, - "log_level": "INFO", - "log_to_file": False, - "log_file": "client_worker.log", - }, - **worker_log_params, - } + log_params = _get_log_params(log_params, DEFAULT_CLIENT_LOG_PARAMS) + worker_log_params = _get_log_params(worker_log_params, DEFAULT_WORKER_LOG_PARAMS) # If the user wants to change the global logger name from the # default value of toolviper @@ -420,11 +336,6 @@ def distributed_client( _set_up_dask(dask_local_dir) - """ - load libraries related functions of a distributed environment - 'available_specs' contains the function name and a flag that the function was loaded successfully - """ - logger.debug(colorize.green("Checking functions availability:")) available_specs = { **load_libraries("slurm", "dask_jobqueue"), @@ -434,16 +345,13 @@ def distributed_client( print_libraries_availability(available_specs) - # This will work as long as the scheduler path isn't in some outside directory. Being that it is a plugin specific - # to this module, I think keeping it static in the module directory it good. - plugin_path = str(pathlib.Path(__file__).parent.resolve().joinpath("plugins/")) - client = toolviper.dask.menrva.MenrvaClient(cluster) client.get_versions(check=True) logger.info("Created client " + str(client)) return client +@parameter.validate() def slurm_cluster_client( workers_per_node: int, cores_per_node: int, @@ -456,111 +364,61 @@ def slurm_cluster_client( dask_log_dir: str, exclude_nodes: str = "", dashboard_port: int = 8787, - local_dir: str = None, + local_dir: Optional[str] = None, autorestrictor: bool = False, wait_for_workers: bool = True, - log_params: Union[None, Dict] = None, - worker_log_params: Union[None, Dict] = None, -): - """Creates a Dask slurm_cluster_client on a multinode cluster. - - interface eth0, ib0 + log_params: Optional[Dict[str, Any]] = None, + worker_log_params: Optional[Dict[str, Any]] = None, +) -> distributed.Client: + """Create a SLURM cluster and return a client. Parameters ---------- workers_per_node : int - Number of workers per node ... - + Number of workers per node. cores_per_node : int - Number of cores per node ... - + Number of cores per node. memory_per_node : str - Memory allocation per node ... - + Memory per node (e.g., '64GB'). number_of_nodes : int - Number of nodes ... - + Number of nodes to request. queue : str - Destination queue for each worker job. Passed to #SBATCH -p option - + SLURM queue name. interface : str - Network interface like ‘eth0’ or ‘ib0’. This will be used both for the Dask scheduler and the Dask workers - interface. If you need a different interface for the Dask scheduler you can pass it through the - scheduler_options argument: interface=your_worker_interface, - scheduler_options={'interface': your_scheduler_interface}. - + Network interface to use (e.g., 'ib0'). python_env_dir : str - Python executable used to launch Dask workers. Defaults to the Python that is submitting these jobs. - + Path to the python executable in the environment. dask_local_dir : str - Where Dask should store temporary files, defaults to None. If None Dask will use \ - `./dask-worker-space`, defaults to None - - local_dir : str - Defines client local directory, defaults to None - + Local directory for dask workers. dask_log_dir : str - Destination directory for dask log files. - - exclude_nodes : str - Nodes to exclude. - - dashboard_port : int - Port to use for dashboard connection. - - autorestrictor : bool - Boolean determining usage of autorestrictor plugin, defaults to False - - wait_for_workers : bool - Boolean determining usage of wait_for_workers option in dask, defaults to False - - log_params : dict - Dictionary containing parameters to using for logging. - - worker_log_params : dict - Dictionary containing parameters to using for worker logging. - - .. _Description: - - ** _log_params ** - - The log_params (worker_log_params) dictionary stores initialization information for the logger and associated - workers. the following are the acceptable key: value pairs and their usage information. - - log_params["logger_name"] : str - Defines the logger name to use - log_params["log_to_term"] : bool - Should messages log to the terminal output. - log_params["log_level"] : str - Defines logging level, valid options: - - DEBUG - - INFO - - WARNING - - ERROR - - CRITICAL - - Only messages flagged as at the given level or below are logged. - - log_params["log_to_file"] : str - Should messages log to file. - - log_params["log_filee"] : str - Name of log file to create. If none is given, the file name 'logger' will be used. + Directory for dask logs. + exclude_nodes : str, optional + Comma-separated list of nodes to exclude. + dashboard_port : int, optional + Port for the dask dashboard. + local_dir : str, optional + Client local directory. + autorestrictor : bool, optional + Whether to use the autorestrictor plugin. + wait_for_workers : bool, optional + Whether to wait for workers to start. + log_params : dict, optional + Logger parameters for the client. + worker_log_params : dict, optional + Logger parameters for the workers. Returns ------- - distributed.Client + distributed.Client + The dask client connected to the SLURM cluster. """ # https://github.com/dask/dask/issues/5577 # from distributed import Client - if log_params is None: - log_params = {} - - if worker_log_params is None: - worker_log_params = {} + log_params = _get_log_params(log_params, DEFAULT_CLIENT_LOG_PARAMS) + worker_log_params = _get_log_params(worker_log_params, DEFAULT_WORKER_LOG_PARAMS) if local_dir: os.environ["VIPER_LOCAL_DIR"] = local_dir @@ -568,9 +426,6 @@ def slurm_cluster_client( else: local_cache = False - # Viper logger for code that is not part of the Dask graph. The worker logger is setup in the _worker plugin. - # from viper._utils._logger import setup_logger - logger.setup_logger(**log_params) _set_up_dask(dask_local_dir) @@ -606,22 +461,6 @@ def slurm_cluster_client( } ) - # This method of assigning a worker plugin does not seem to work when using dask_jobqueue. Consequently, using - # client.register_plugin so that the method of assigning a worker plugin is the same for local_client and - # slurm_cluster_client. - # - # if local_cache or worker_log_params: - # dask.config.set({"distributed.worker.preload": os.path.join(plugin_path,"_utils/_worker.py")}) - # dask.config.set({ - # "distributed.worker.preload-argv": [ - # "--local_cache",local_cache, - # "--log_to_term",worker_log_params["log_to_term"], - # "--log_to_file",worker_log_params["log_to_file"], - # "--log_file",worker_log_params["log_file"], - # "--log_level",worker_log_params["log_level"]] - # }) - # - cluster = dask_jobqueue.SLURMCluster( processes=workers_per_node, cores=cores_per_node, @@ -634,16 +473,13 @@ def slurm_cluster_client( local_directory=dask_local_dir, log_directory=dask_log_dir, job_extra_directives=["--exclude=" + exclude_nodes], - # job_extra_directives=["--exclude=nmpost087,nmpost089,nmpost088"], scheduler_options={"dashboard_address": ":" + str(dashboard_port)}, - ) # interface="ib0" + ) client = toolviper.dask.menrva.MenrvaClient(cluster) - cluster.scale(workers_per_node * number_of_nodes) # When constructing a graph that has local cache enabled all workers need to be up and running. - if local_cache or wait_for_workers: client.wait_for_workers(n_workers=workers_per_node * number_of_nodes) @@ -662,34 +498,39 @@ def slurm_cluster_client( def auto_client(): + """ + A decorator that automatically manages a Dask client for the decorated function. + + If a client already exists, it uses the existing one. + Otherwise, it creates a new local_client and shuts it down after the function completes. + """ + def function_wrapper(function): @functools.wraps(function) def wrapper(*args, **kwargs): - persistent_client = False - - if not get_client() is None: - client = get_client() - persistent_client = True - else: + client = get_client() + persistent_client = client is not None + if not persistent_client: # Get client inputs if they exist - arguments = inspect.getcallargs(function, *args, **kwargs) - if "client" in kwargs.keys(): - client = local_client(**kwargs["client"]) - + if "client" in kwargs: + client_kwargs = kwargs["client"] + if isinstance(client_kwargs, dict): + client = local_client(**client_kwargs) + else: + client = local_client() else: client = local_client() try: - print(f"Dask dashboard started at: {client.dashboard_link}") + if client: + logger.info(f"Dask dashboard started at: {client.dashboard_link}") # Run the decorated function - result = function(*args, **kwargs) - - return result + return function(*args, **kwargs) finally: # Ensure the client is closed even if the function raises an exception - if not persistent_client: + if not persistent_client and client: client.shutdown() return wrapper diff --git a/src/toolviper/utils/data/cloudflare.py b/src/toolviper/utils/data/cloudflare.py index 49c18aa..2b9b5c3 100644 --- a/src/toolviper/utils/data/cloudflare.py +++ b/src/toolviper/utils/data/cloudflare.py @@ -4,280 +4,326 @@ import shutil import zipfile from threading import Thread -from typing import Any, Optional, Union +from typing import Any, Dict, List, Optional, Union import requests +import pandas as pd from rich.progress import Progress, TaskID import toolviper import toolviper.utils.console as console import toolviper.utils.logger as logger from toolviper.utils import parameter - -from collections import defaultdict from toolviper.utils.parameter import is_notebook -import pandas as pd +from collections import defaultdict colorize = console.Colorize() +# Constants PROGRESS_MAX_CHARACTERS = 28 MINIMUM_CHUNK_SIZE = 1024 +BASE_URL = "https://downloadnrao.org" +METADATA_REL_PATH = ".cloudflare/file.download.json" +USER_AGENT = "Wget/1.16 (linux-gnu)" + + +def _get_metadata_path() -> pathlib.Path: + """Get the absolute path to the local metadata file.""" + return pathlib.Path(__file__).parent.resolve().joinpath(METADATA_REL_PATH) def version() -> None: - # Load the file dropbox file meta data. - meta_data_path = pathlib.Path(__file__).parent.joinpath( - ".cloudflare/file.download.json" - ) + """ + Print the version of the cloudflare manifest. + """ + meta_data_path = _get_metadata_path() if not meta_data_path.parent.exists(): - logger.debug("metadata path doesn't exist... creating") - meta_data_path.parent.mkdir(parents=True) + logger.debug(f"Metadata path {meta_data_path.parent} doesn't exist... creating") + meta_data_path.parent.mkdir(parents=True, exist_ok=True) - # Verify that the download metadata exists and updates if not. _verify_metadata_file() - with open(meta_data_path) as json_file: - file_meta_data = json.load(json_file) + try: + with open(meta_data_path, "r") as json_file: + file_meta_data = json.load(json_file) + logger.info(f"Manifest version: {file_meta_data.get('version', 'unknown')}") - logger.info(f"{file_meta_data['version']}") + except (FileNotFoundError, json.JSONDecodeError) as e: + logger.error(f"Failed to read metadata file: {e}") @parameter.validate() def download( - file: Union[str, list], + file: Union[str, List[str]], folder: str = ".", overwrite: bool = False, decompress: bool = True, ) -> None: """ - Download tool for data stored externally. + Download tool for data stored externally. + Parameters ---------- - file : str - Filename as stored on an external source. - folder : str - Destination folder. - overwrite : bool - Should file be overwritten. - decompress : bool - Should file be unzipped. - - Returns - ------- - No return + file : str or list of str + Filename(s) as stored on an external source. + folder : str, optional + Destination folder. Defaults to ".". + overwrite : bool, optional + Whether to overwrite existing files. Defaults to False. + decompress : bool, optional + Whether to unzip downloaded files. Defaults to True. """ + logger.info("Initializing download...") - logger.info("Downloading from ....") - - if not isinstance(file, list): + if isinstance(file, str): file = [file] try: _print_file_queue(file) - except Exception as e: - logger.warning(f"There was a problem printing the file list... {e}") - - finally: - if not pathlib.Path(folder).resolve().exists(): - toolviper.utils.logger.info( - f"Creating path:{colorize.blue(str(pathlib.Path(folder).resolve()))}" - ) - pathlib.Path(folder).resolve().mkdir() - - logger.debug(f"Initializing downloader ...") - - meta_data_path = pathlib.Path(__file__).parent.joinpath( - ".cloudflare/file.download.json" - ) - - tasks = [] + logger.warning(f"Problem printing file list: {e}") - # Make a list of files that aren't available from cloudflare yet - missing_files = [] + dest_path = pathlib.Path(folder).resolve() + if not dest_path.exists(): + logger.info(f"Creating path: {colorize.blue(str(dest_path))}") + dest_path.mkdir(parents=True, exist_ok=True) - # Load the file dropbox file meta data. + meta_data_path = _get_metadata_path() if not meta_data_path.exists(): logger.warning( - f"Couldn't find file metadata locally in {colorize.blue(str(meta_data_path))}" + f"Metadata not found locally at {colorize.blue(str(meta_data_path))}" ) + update() - toolviper.utils.data.update() - - with open(meta_data_path) as json_file: - file_meta_data = json.load(json_file) + try: + with open(meta_data_path, "r") as json_file: + file_meta_data = json.load(json_file) + except (FileNotFoundError, json.JSONDecodeError) as e: + logger.error(f"Failed to load metadata: {e}") + return - # Build the task list - for file_ in file: - full_file_path = pathlib.Path(folder).joinpath(file_) + tasks = [] + missing_files = [] - if full_file_path.exists() and not overwrite: - logger.info(f"File exists: {str(full_file_path)}") - continue + def name_format(string): + return ( + f"{string[: (PROGRESS_MAX_CHARACTERS - 4)]} ..." + if len(string) > PROGRESS_MAX_CHARACTERS + else string + ) - if file_ not in file_meta_data["metadata"].keys(): - logger.error(f"Requested file not found: {file_}") - logger.info( - f"For a list of available files try using " - f"{colorize.blue('toolviper.utils.data.list_files()')}." - ) + for f_name in file: + full_file_path = dest_path.joinpath(f_name) - missing_files.append(file_) - continue + if full_file_path.exists() and not overwrite: + logger.info(f"File already exists: {full_file_path}") + continue - name_format = lambda string: ( - f"{string[: (PROGRESS_MAX_CHARACTERS - 4)]} ..." - if len(string) > PROGRESS_MAX_CHARACTERS - else string + if f_name not in file_meta_data.get("metadata", {}): + logger.error(f"Requested file not found in manifest: {f_name}") + logger.info( + f"Use {colorize.blue('toolviper.utils.data.list_files()')} for available files." ) + missing_files.append(f_name) + continue + + meta = file_meta_data["metadata"][f_name] + tasks.append( + { + "description": name_format(f_name), + "metadata": meta, + "folder": str(dest_path), + "visible": True, + "size": int(meta.get("size", 0)), + } + ) - tasks.append( - { - "description": name_format(file_), - "metadata": file_meta_data["metadata"][file_], - "folder": folder, - "visible": True, - "size": int(file_meta_data["metadata"][file_]["size"]), - } - ) + if not tasks: + if missing_files: + logger.error(f"Missing files: {missing_files}") + + return - threads = [] progress = Progress() + threads = [] with progress: - task_ids = [ - progress.add_task(task["description"]) for task in tasks if len(tasks) > 0 - ] - - for i, task in enumerate(tasks): - thread = Thread( - target=worker, args=(progress, task_ids[i], task, decompress) - ) + for task in tasks: + task_id = progress.add_task(task["description"]) + thread = Thread(target=worker, args=(progress, task_id, task, decompress)) thread.start() threads.append(thread) for thread in threads: thread.join() - if len(missing_files) > 0: - logger.error(f"Missing files: {missing_files}") + if missing_files: + logger.error(f"Could not download: {missing_files}") -def worker(progress: Progress, task_id: TaskID, task: dict, decompress=True) -> None: - """Simulate work being done in a thread""" - - filename = task["metadata"]["file"] +def worker( + progress: Progress, task_id: TaskID, task: dict, decompress: bool = True +) -> None: + """ + Worker function to download a file in a thread. - url = f"https://downloadnrao.org/{task['metadata']['path']}/{task['metadata']['file']}" + Parameters + ---------- + progress : Progress + Rich Progress instance. + task_id : TaskID + ID of the task in the progress bar. + task : dict + Task details including metadata and destination folder. + decompress : bool, optional + Whether to decompress the file after download. Defaults to True. + """ + metadata = task["metadata"] + filename = metadata["file"] + path = metadata.get("path", "").strip("/") + url = f"{BASE_URL}/{path}/{filename}" if path else f"{BASE_URL}/{filename}" - r = requests.get(url, stream=True, headers={"user-agent": "Wget/1.16 (linux-gnu)"}) - total = int(r.headers.get("Content-Length", 0)) + try: + response = requests.get( + url, stream=True, headers={"user-agent": USER_AGENT}, timeout=30 + ) + response.raise_for_status() + except Exception as e: + logger.error(f"Failed to initiate download for {filename}: {e}") + return + total = int(response.headers.get("Content-Length", 0)) if total == 0: - total = task["size"] - - fullname = str(pathlib.Path(task["folder"]).joinpath(filename)) - - size = 0 - - with open(fullname, "wb") as fd: - for chunk in r.iter_content(chunk_size=MINIMUM_CHUNK_SIZE): - if chunk: - size += fd.write(chunk) - progress.update( - task_id, completed=size, total=total, visible=task["visible"] - ) + total = task.get("size", 0) - # Verify checksum on file - # toolviper.utils.verify(filename, task["folder"]) + dest_folder = pathlib.Path(task["folder"]) + fullname = dest_folder.joinpath(filename) - if decompress: - if zipfile.is_zipfile(fullname): - shutil.unpack_archive(filename=fullname, extract_dir=task["folder"]) + try: + size = 0 + with open(fullname, "wb") as fd: + for chunk in response.iter_content(chunk_size=MINIMUM_CHUNK_SIZE): + if chunk: + size += fd.write(chunk) + progress.update( + task_id, completed=size, total=total, visible=task["visible"] + ) + except Exception as e: + logger.error(f"Error writing file {filename}: {e}") + return - # Let's clean up after ourselves + if decompress and zipfile.is_zipfile(fullname): + try: + shutil.unpack_archive(filename=str(fullname), extract_dir=str(dest_folder)) os.remove(fullname) + except Exception as e: + logger.error(f"Failed to decompress {filename}: {e}") class ToolviperFiles: - def __init__(self, manifest, dataframe=None): + """ + Helper class for managing and displaying toolviper data manifests. + """ + def __init__(self, manifest: str, dataframe: Optional[pd.DataFrame] = None) -> None: self.manifest = manifest self.dataframe = dataframe - self.notebook_mode = False + self.notebook_mode = is_notebook() - if is_notebook(): - import itables + if self.notebook_mode: + try: + import itables - self.notebook_mode = True + itables.init_notebook_mode() - itables.init_notebook_mode() + except ImportError: + logger.debug("itables not found, falling back to standard display.") - def __call__(self): + def __call__(self) -> Optional[pd.DataFrame]: if not self.notebook_mode: - return print(self.dataframe) + print(self.dataframe) + return None - else: - return self.dataframe + return self.dataframe - def print(self) -> Union[None, pd.DataFrame]: + def print(self) -> Optional[pd.DataFrame]: + """ + Display the dataframe using appropriate formatting. + """ if not self.notebook_mode: - import tabulate - - print( - tabulate.tabulate( - self.dataframe, showindex=False, headers=self.dataframe.columns + try: + import tabulate + + print( + tabulate.tabulate( + self.dataframe, + showindex=False, + headers=self.dataframe.columns, + ) ) - ) + except ImportError: + print(self.dataframe) + return None return self.dataframe @classmethod - def from_manifest(cls, manifest: str): + def from_manifest(cls, manifest: str) -> "ToolviperFiles": + """ + Create a ToolviperFiles instance from a manifest file. + """ meta_data_path = pathlib.Path(manifest) - # Verify that the download metadata exist and update if not. - # _verify_metadata_file() + try: + with open(meta_data_path, "r") as json_file: + file_meta_data = json.load(json_file) - with open(meta_data_path) as json_file: + except (FileNotFoundError, json.JSONDecodeError) as e: + logger.error(f"Failed to load manifest {manifest}: {e}") - file_meta_data = json.load(json_file) + return cls(manifest=manifest, dataframe=pd.DataFrame()) - files = file_meta_data["metadata"].keys() + metadata_dict = file_meta_data.get("metadata", {}) + data = defaultdict(list) - data = defaultdict(list) - data["file"] = list(files) + for file_name, meta in metadata_dict.items(): + data["file"].append(file_name) - for file_, metadata_ in file_meta_data["metadata"].items(): - for key_, value_ in metadata_.items(): - if key_ == "file": - continue + for key, value in meta.items(): + if key == "file": + continue - # I think we could do this with a JSON ENCODER - # but this is easier since the file is small - # and everything is a string already + if key == "size": + try: + value = int(value) - if value_ == "size": - value_ = int(value_) + except (ValueError, TypeError): + pass - data[key_].append(value_) + data[key].append(value) - return cls(manifest=manifest, dataframe=pd.DataFrame(data)) + return cls(manifest=manifest, dataframe=pd.DataFrame(data)) -def list_files(truncate=None) -> pd.DataFrame: +def list_files(truncate: Optional[int] = None) -> Optional[pd.DataFrame]: + """ + List all files available in the cloudflare manifest. + Parameters + ---------- + truncate : int, optional + Maximum number of rows to display. Defaults to None. + """ pd.set_option("display.max_rows", truncate) pd.set_option("display.colheader_justify", "left") - meta_data_path = pathlib.Path(__file__).parent.joinpath( - ".cloudflare/file.download.json" - ) + meta_data_path = _get_metadata_path() + if not meta_data_path.exists(): + _verify_metadata_file() table = ToolviperFiles.from_manifest(str(meta_data_path)) - return table.print() @@ -322,46 +368,44 @@ def list_files_() -> None: console.print(table) -def get_files() -> list[Any]: +def get_files() -> List[str]: """ - Get all files available in cloudflare manifest. This is retrieved from the local cloudflare - metadata file. - + Get a list of all file names available in the cloudflare manifest. """ - meta_data_path = pathlib.Path(__file__).parent.joinpath( - ".cloudflare/file.download.json" - ) - - # Verify that the download metadata exists and updates if not. + meta_data_path = _get_metadata_path() _verify_metadata_file() - with open(meta_data_path) as json_file: - file_meta_data = json.load(json_file) + try: + with open(meta_data_path, "r") as json_file: + file_meta_data = json.load(json_file) + return list(file_meta_data.get("metadata", {}).keys()) - return list(file_meta_data["metadata"].keys()) + except (FileNotFoundError, json.JSONDecodeError): + return [] @parameter.validate() -def update(path: Union[str, None] = None) -> None: +def update(path: Optional[str] = None) -> None: """ - Update cloudflare manifest. + Update the local cloudflare manifest by downloading the latest version. Parameters ---------- - path : str - In the case that you want an updated copy of the manifest for modification, this is the path to save it to. + path : str, optional + Custom path to save the manifest to. Defaults to the internal .cloudflare directory. """ - if path is None: - meta_data_path = pathlib.Path(__file__).parent.joinpath(".cloudflare") + meta_data_dir = _get_metadata_path().parent + meta_data_path = _get_metadata_path() else: - # I know this is an unnecessary copy but I don't want a big erbose path name in the inpute variables. - meta_data_path = pathlib.Path(path) + meta_data_dir = pathlib.Path(path) + meta_data_path = meta_data_dir.joinpath("file.download.json") - if not meta_data_path.exists(): - _make_dir(str(pathlib.Path(__file__).parent), ".cloudflare") + if not meta_data_dir.exists(): + meta_data_dir.mkdir(parents=True, exist_ok=True) + # Temporary metadata to kickstart the download of the actual manifest file_meta_data = { "file": "file.download.json", "path": "/", @@ -371,94 +415,81 @@ def update(path: Union[str, None] = None) -> None: "mode": "NA", } - tasks = { - "description": "file.download.json", + task = { + "description": "Updating manifest", "metadata": file_meta_data, - "folder": meta_data_path, + "folder": str(meta_data_dir), "visible": False, "size": 12484, } - logger.info("Updating file metadata information ... ") + logger.info("Updating file metadata information...") progress = Progress() - task_id = progress.add_task(tasks["description"]) + task_id = progress.add_task(task["description"]) with progress: - worker(progress, task_id, tasks) + worker(progress, task_id, task, decompress=False) if not meta_data_path.exists(): logger.error("Unable to retrieve download metadata.") - raise FileNotFoundError( - "Download metadata file does not exist at the expected path." - ) + raise FileNotFoundError(f"Download metadata file not found at {meta_data_path}") @parameter.validate() -def get_file_size(path: str) -> Optional[dict]: - """ - Get list file sizes in bytes for a given path. Only works for files; isn't recursive. +def get_file_size(path: str) -> Dict[str, int]: """ - if not pathlib.Path(path).resolve().exists(): - logger.error(f"Path not found...: {path}") - - return None - - file_size_dict = {} + Get file sizes in bytes for all files in a given path. - for item in pathlib.Path(path).resolve().iterdir(): - if pathlib.Path(item).resolve().is_file(): - if item.name.endswith(".zip"): - item_ = item.name.split(".zip")[0] + Parameters + ---------- + path : str + The directory path to scan. - else: - item_ = item.name + Returns + ------- + dict + A dictionary mapping file names to their sizes in bytes. + """ + path_obj = pathlib.Path(path).resolve() + if not path_obj.exists() or not path_obj.is_dir(): + logger.error(f"Path not found or is not a directory: {path}") + return {} - file_size_dict[item_] = os.path.getsize(pathlib.Path(item)) + file_size_dict = {} + for item in path_obj.iterdir(): + if item.is_file(): + name = ( + item.name.split(".zip")[0] if item.name.endswith(".zip") else item.name + ) + file_size_dict[name] = item.stat().st_size return file_size_dict -def _print_file_queue(files: list) -> None: +def _print_file_queue(files: List[str]) -> None: + """ + Print a formatted list of files to be downloaded. + """ from rich import box from rich.console import Console from rich.table import Table - assert type(files) == list - console_ = Console() table = Table(show_header=True, box=box.SIMPLE) - table.add_column("Download List", justify="left") - for file in files: - table.add_row(f"[magenta]{file}[/magenta]") + for f_name in files: + table.add_row(f"[magenta]{f_name}[/magenta]") console_.print(table) -def _make_dir(path, folder): - p = pathlib.Path(path).joinpath(folder) - try: - p.mkdir() - logger.info( - f"Creating path:{colorize.blue(str(pathlib.Path(folder).resolve()))}" - ) - - except FileExistsError: - logger.warning(f"File exists: {colorize.blue(str(p.resolve()))}") - - except FileNotFoundError: - logger.warning( - f"One fo the parent directories cannot be found: {colorize.blue(str(p.resolve()))}" - ) - - -def _verify_metadata_file(): - meta_data_path = pathlib.Path(__file__).parent.joinpath( - ".cloudflare/file.download.json" - ) - +def _verify_metadata_file() -> None: + """ + Ensure the metadata file exists, or trigger an update. + """ + meta_data_path = _get_metadata_path() if not meta_data_path.exists(): - logger.warning(f"Couldn't find {colorize.blue(str(meta_data_path))}.") + logger.warning(f"Metadata file {meta_data_path} missing. Updating...") update() diff --git a/src/toolviper/utils/logger.py b/src/toolviper/utils/logger.py index a076d86..993d5d3 100755 --- a/src/toolviper/utils/logger.py +++ b/src/toolviper/utils/logger.py @@ -12,252 +12,321 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import sys -import dask -import logging - from datetime import datetime -from toolviper.utils.console import Colorize -from toolviper.utils.console import add_verbose_info - -from dask.distributed import get_worker +from typing import Any, Dict, Optional, Union +import dask +import dask.distributed from contextvars import ContextVar +from dask.distributed import get_worker -from typing import Union +from toolviper.utils.console import Colorize, add_verbose_info -VERBOSE = True -DEFAULT = False +# Global verbosity flag +verbosity: ContextVar[Optional[bool]] = ContextVar("message_verbosity", default=None) -# global verbosity flag -verbosity: Union[ContextVar[bool], ContextVar[None]] = ContextVar( - "message_verbosity", default=None -) +# Constants for default values +DEFAULT_LOGGER_NAME = "viperlog" +LOGGER_ENV_VAR = "VIPER_LOGGER_NAME" -def set_verbosity(state: Union[None, bool] = None): - print(f"Setting verbosity to {state}") +def set_verbosity(state: Optional[bool] = None) -> None: + """ + Set the global verbosity state. + Parameters + ---------- + state : bool, optional + The verbosity state to set. If None, it uses the default. + """ verbosity.set(state) -def info(message: str, verbose: bool = False): - logger_name = os.getenv("LOGGER_NAME") - - if verbosity.get() is True or False: - verbose = verbosity.get() - - if verbose: - message = add_verbose_info(message=message, color="blue") +def _log_message( + level: str, message: str, verbose: bool = False, color: Optional[str] = None +) -> None: + """ + Helper function to process and log a message. + + Parameters + ---------- + level : str + The logging level (e.g., 'info', 'debug', 'warning'). + message : str + The message to log. + verbose : bool, optional + Whether to include verbose information. Defaults to False. + color : str, optional + The color to use for verbose information. + """ + logger_name = os.getenv(LOGGER_ENV_VAR, DEFAULT_LOGGER_NAME) + + current_verbosity = verbosity.get() + if current_verbosity is not None: + verbose = current_verbosity + + if verbose and color: + message = add_verbose_info(message=message, color=color) logger = get_logger(logger_name=logger_name) - logger.info(message) - - -def log(message: str, verbose: bool = False): - logger_name = os.getenv("LOGGER_NAME") - - if verbosity.get() is True or False: - verbose = verbosity.get() + log_func = getattr(logger, level.lower()) + log_func(message) + + +def info(message: str, verbose: bool = False) -> None: + """ + Log an info level message. + + Parameters + ---------- + message : str + The message to log. + verbose : bool, optional + Whether to include verbose information. Defaults to False. + """ + _log_message("info", message, verbose, color="blue") + + +def log(message: str, verbose: bool = False) -> None: + """ + Log a message at the current logger's level. + + Parameters + ---------- + message : str + The message to log. + verbose : bool, optional + Whether to include verbose information. Defaults to False. + """ + logger_name = os.getenv(LOGGER_ENV_VAR, DEFAULT_LOGGER_NAME) + current_verbosity = verbosity.get() + if current_verbosity is not None: + verbose = current_verbosity if verbose: message = add_verbose_info(message=message, color="blue") logger = get_logger(logger_name=logger_name) - logger.log(logger.level, message) -def exception(message: str, verbose: bool = False): - logger_name = os.getenv("LOGGER_NAME") - - if verbosity.get() is True or False: - verbose = verbosity.get() +def exception(message: str, verbose: bool = False) -> None: + """ + Log an exception level message. + + Parameters + ---------- + message : str + The message to log. + verbose : bool, optional + Whether to include verbose information. Defaults to False. + """ + _log_message("exception", message, verbose, color="blue") + + +def debug(message: str, verbose: bool = False) -> None: + """ + Log a debug level message. + + Parameters + ---------- + message : str + The message to log. + verbose : bool, optional + Whether to include verbose information. Defaults to False. + """ + _log_message("debug", message, verbose, color="green") + + +def warning(message: str, verbose: bool = False) -> None: + """ + Log a warning level message. + + Parameters + ---------- + message : str + The message to log. + verbose : bool, optional + Whether to include verbose information. Defaults to False. + """ + _log_message("warning", message, verbose, color="orange") + + +def error(message: str, verbose: bool = True) -> None: + """ + Log an error level message. + + Parameters + ---------- + message : str + The message to log. + verbose : bool, optional + Whether to include verbose information. Defaults to True. + """ + _log_message("error", message, verbose, color="red") + + +def critical(message: str, verbose: bool = True) -> None: + """ + Log a critical level message. + + Parameters + ---------- + message : str + The message to log. + verbose : bool, optional + Whether to include verbose information. Defaults to True. + """ + _log_message("critical", message, verbose, color="alert") - if verbose: - message = add_verbose_info(message=message, color="blue") - - logger = get_logger(logger_name=logger_name) - - logger.exception(message) - - -def debug(message: str, verbose: bool = False): - logger_name = os.getenv("LOGGER_NAME") - - if verbosity.get() is True or False: - verbose = verbosity.get() - - if verbose: - message = add_verbose_info(message=message, color="green") - - logger = get_logger(logger_name=logger_name) - logger.debug(message) - - -def warning(message: str, verbose: bool = False): - logger_name = os.getenv("LOGGER_NAME") - - if verbosity.get() is True or False: - verbose = verbosity.get() - - if verbose: - message = add_verbose_info(message=message, color="orange") - logger = get_logger(logger_name=logger_name) - logger.warning(message) - - -def error(message: str, verbose: bool = True): - logger_name = os.getenv("LOGGER_NAME") - - if verbosity.get() is True or False: - verbose = verbosity.get() - - if verbose: - message = add_verbose_info(message=message, color="red") - - logger = get_logger(logger_name=logger_name) - logger.error(message) - - -def critical(message: str, verbose: bool = True): - logger_name = os.getenv("LOGGER_NAME") +class ColorLoggingFormatter(logging.Formatter): + """ + A logging formatter that adds colors to the output based on the log level. + """ - if verbosity.get() is True or False: - verbose = verbosity.get() + colorize = Colorize() - if verbose: - message = add_verbose_info(message=message, color="alert") + def __init__(self, fmt: Optional[str] = None, datefmt: Optional[str] = None): + super().__init__(fmt, datefmt) + self.start_msg = f"[{self.colorize.purple('%(asctime)s')}] " + + self.FORMATS = { + logging.DEBUG: self.start_msg + + self.colorize.green("%(levelname)8s") + + self.colorize.grey(" %(name)10s: ") + + " %(message)s", + logging.INFO: self.start_msg + + self.colorize.blue("%(levelname)8s") + + self.colorize.grey(" %(name)10s: ") + + " %(message)s ", + logging.WARNING: self.start_msg + + self.colorize.orange("%(levelname)8s") + + self.colorize.grey(" %(name)10s: ") + + " %(message)s ", + logging.ERROR: self.start_msg + + self.colorize.red("%(levelname)8s") + + self.colorize.grey(" %(name)10s: ") + + " %(message)s", + logging.CRITICAL: self.start_msg + + self.colorize.format( + text="%(levelname)8s", color=[220, 60, 20], highlight=True + ) + + self.colorize.grey(" %(name)10s: ") + + " %(message)s", + } + + def format(self, record: logging.LogRecord) -> str: + log_fmt = self.FORMATS.get(record.levelno, self._fmt) + formatter = logging.Formatter(log_fmt, self.datefmt) - logger = get_logger(logger_name=logger_name) - logger.critical(message) + return formatter.format(record) -class ColorLoggingFormatter(logging.Formatter): - colorize = Colorize() +class LoggingFormatter(logging.Formatter): + """ + A standard logging formatter for file output. + """ + + def __init__(self, fmt: Optional[str] = None, datefmt: Optional[str] = None): + super().__init__(fmt, datefmt) + self.start_msg = "[%(asctime)s] " + self.middle_msg = "%(levelname)8s" + + self.FORMATS = { + level: f"{self.start_msg}{self.middle_msg} %(name)10s: %(message)s" + for level in [ + logging.DEBUG, + logging.INFO, + logging.WARNING, + logging.ERROR, + logging.CRITICAL, + ] + } + + def format(self, record: logging.LogRecord) -> str: + log_fmt = self.FORMATS.get(record.levelno, self._fmt) + formatter = logging.Formatter(log_fmt, self.datefmt) - function = " [{function}] ".format(function=colorize.blue("%(funcName)s")) - verbose = " [{exechain}] ".format( - exechain=colorize.blue("%(filename)s:%(lineno)s : %(module)s.%(funcName)s") - ) - - start_msg = "[{time}] ".format(time=colorize.purple("%(asctime)s")) - middle_msg = "{level}".format(level="%(levelname)8s") - execution_msg = " {name} [ {filename} ]: {exec_info}: ".format( - name="%(name)10s", - filename="%(filename)-20s", - exec_info=colorize.blue("%(callchain)-45s"), - ) - - FORMATS = { - logging.DEBUG: start_msg - + colorize.green(middle_msg) - + colorize.grey(" %(name)10s: ") - + " %(message)s", - logging.INFO: start_msg - + colorize.blue(middle_msg) - + colorize.grey(" %(name)10s: ") - + " %(message)s ", - logging.WARNING: start_msg - + colorize.orange(middle_msg) - + colorize.grey(" %(name)10s: ") - + " %(message)s ", - logging.ERROR: start_msg - + colorize.red(middle_msg) - + colorize.grey(" %(name)10s: ") - + " %(message)s", - logging.CRITICAL: start_msg - + colorize.format(text=middle_msg, color=[220, 60, 20], highlight=True) - + colorize.grey(" %(name)10s: ") - + " %(message)s", - } - - def format(self, record): - log_fmt = self.FORMATS.get(record.levelno) - formatter = logging.Formatter(log_fmt) return formatter.format(record) -class LoggingFormatter(logging.Formatter): - function = " [{function}] ".format(function="%(funcName)s") - verbose = " [{exechain}] ".format( - exechain="%(filename)s:%(lineno)s : %(module)s.%(funcName)s" - ) - - start_msg = "[{time}] ".format(time="%(asctime)s") - middle_msg = "{level}".format(level="%(levelname)8s") - execution_msg = " {name} [ {filename} ]: {exec_info}: ".format( - name="%(name)10s", filename="%(filename)-20s", exec_info="%(callchain)-45s" - ) - - FORMATS = { - logging.DEBUG: start_msg + middle_msg + " %(name)10s: " + " %(message)s", - logging.INFO: start_msg + middle_msg + " %(name)10s: " + " %(message)s ", - logging.WARNING: start_msg + middle_msg + " %(name)10s: " + " %(message)s ", - logging.ERROR: start_msg + middle_msg + " %(name)10s: " + " %(message)s", - logging.CRITICAL: start_msg + middle_msg + " %(name)10s: " + " %(message)s", - } - - def format(self, record): - log_fmt = self.FORMATS.get(record.levelno) - formatter = logging.Formatter(log_fmt) - return formatter.format(record) +def get_logger(logger_name: Optional[str] = None) -> logging.Logger: + """ + Get a logger instance by name, with fallback to environment or defaults. + Parameters + ---------- + logger_name : str, optional + The name of the logger to retrieve. -def get_logger(logger_name: Union[str, None] = None): + Returns + ------- + logging.Logger + The logger instance. + """ if logger_name is None: - if os.getenv("LOGGER_NAME"): - # Return default logger from env if none is specified. - logger_name = os.getenv("LOGGER_NAME") - else: - logger_name = "viperlog" + logger_name = os.getenv(LOGGER_ENV_VAR, DEFAULT_LOGGER_NAME) try: worker = get_worker() + # If we're on a worker, try to get the worker-specific logger from the plugin + if hasattr(worker, "plugins") and "worker_logger" in worker.plugins: + return worker.plugins["worker_logger"].get_logger() - except ValueError: - # Scheduler processes - logger_dict = logging.Logger.manager.loggerDict - if logger_name in logger_dict: - logger = logging.getLogger(logger_name) - else: - # If main logger is not started using client function it defaults to printing to term. - logger = logging.getLogger(logger_name) - stream_handler = logging.StreamHandler(sys.stdout) - stream_handler.setFormatter(ColorLoggingFormatter()) - logger.addHandler(stream_handler) - logger.setLevel(logging.getLevelName("INFO")) - - return logger - - try: - logger = worker.plugins["worker_logger"].get_logger() + except (ValueError, AttributeError, KeyError): + # Not on a worker, or worker logger plugin not available + pass - return logger + logger = logging.getLogger(logger_name) - except Exception as e: - print("Could not load worker logger: {}".format(e)) - print(worker.plugins.keys()) + # If the logger has no handlers, it hasn't been set up yet. + if not logger.handlers: + # Default to a simple stream handler if not explicitly set up + stream_handler = logging.StreamHandler(sys.stdout) + stream_handler.setFormatter(ColorLoggingFormatter()) + logger.addHandler(stream_handler) + logger.setLevel(logging.INFO) - return logging.getLogger() + return logger def setup_logger( - logger_name: Union[str, None] = None, + logger_name: Optional[str] = None, log_to_term: bool = False, log_to_file: bool = True, log_file: str = "logger", log_level: str = "INFO", -): - """To set up as many loggers as you want""" +) -> logging.Logger: + """ + Configure and return a logger. + + Parameters + ---------- + logger_name : str, optional + The name of the logger to set up. + log_to_term : bool, optional + Whether to log to the terminal. Defaults to False. + log_to_file : bool, optional + Whether to log to a file. Defaults to True. + log_file : str, optional + The base name of the log file. + log_level : str, optional + The logging level (e.g., 'DEBUG', 'INFO'). Defaults to 'INFO'. + + Returns + ------- + logging.Logger + The configured logger. + """ if logger_name is None: - logger_name = "viperlog" + logger_name = DEFAULT_LOGGER_NAME logger = logging.getLogger(logger_name) - logger.setLevel(logging.getLevelName(log_level)) - + logger.setLevel(getattr(logging, log_level.upper(), logging.INFO)) logger.handlers.clear() if log_to_term: @@ -266,19 +335,38 @@ def setup_logger( logger.addHandler(stream_handler) if log_to_file: - log_file = log_file + datetime.today().strftime("%Y%m%d_%H%M%S") + ".log" - log_handler = logging.FileHandler(log_file) + timestamp = datetime.today().strftime("%Y%m%d_%H%M%S") + full_log_file = f"{log_file}{timestamp}.log" + log_handler = logging.FileHandler(full_log_file) log_handler.setFormatter(LoggingFormatter()) logger.addHandler(log_handler) return logger -def get_worker_logger_name(logger_name: Union[str, None] = None): +def get_worker_logger_name(logger_name: Optional[str] = None) -> str: + """ + Generate a unique logger name for a Dask worker. + + Parameters + ---------- + logger_name : str, optional + The base logger name. + + Returns + ------- + str + The worker-specific logger name. + """ if logger_name is None: - logger_name = "viperlog" + logger_name = DEFAULT_LOGGER_NAME + + try: + worker_id = get_worker().id + return f"{logger_name}_{worker_id}" - return "_".join((logger_name, str(get_worker().id))) + except (ValueError, AttributeError): + return logger_name def setup_worker_logger( @@ -287,12 +375,36 @@ def setup_worker_logger( log_to_file: bool, log_file: str, log_level: str, - worker: dask.distributed.worker.Worker, -): - parallel_logger_name = "_".join((logger_name, str(worker.name))) + worker: "dask.distributed.worker.Worker", +) -> logging.Logger: + """ + Configure and return a logger for a Dask worker. + + Parameters + ---------- + logger_name : str + The base name of the logger. + log_to_term : bool + Whether to log to the terminal. + log_to_file : bool + Whether to log to a file. + log_file : str + The base name of the log file. + log_level : str + The logging level. + worker : dask.distributed.worker.Worker + The Dask worker instance. + + Returns + ------- + logging.Logger + The configured worker logger. + """ + parallel_logger_name = f"{logger_name}_{worker.name}" logger = logging.getLogger(parallel_logger_name) - logger.setLevel(logging.getLevelName(log_level)) + logger.setLevel(getattr(logging, log_level.upper(), logging.INFO)) + logger.handlers.clear() if log_to_term: stream_handler = logging.StreamHandler(sys.stdout) @@ -300,20 +412,9 @@ def setup_worker_logger( logger.addHandler(stream_handler) if log_to_file: - logger.info(f"log_to_file: {log_file}") - dask.distributed.print(f"log_to_file: {log_to_file}") - - log_file = ( - log_file - + "_" - + str(worker.name) - + "_" - + datetime.today().strftime("%Y%m%d_%H%M%S") - + "_" - + str(worker.ip) - + ".log" - ) - log_handler = logging.FileHandler(log_file) + timestamp = datetime.today().strftime("%Y%m%d_%H%M%S") + full_log_file = f"{log_file}_{worker.name}_{timestamp}_{worker.ip}.log" + log_handler = logging.FileHandler(full_log_file) log_handler.setFormatter(LoggingFormatter()) logger.addHandler(log_handler) diff --git a/src/toolviper/utils/tools.py b/src/toolviper/utils/tools.py index 37462ab..b087258 100644 --- a/src/toolviper/utils/tools.py +++ b/src/toolviper/utils/tools.py @@ -147,7 +147,7 @@ def add_entry( ---------- entries : dict, list - Dictionary or list of metadata info that are needed to build the new entry. + Dictionary, or list of metadata info that are needed to build the new entry. manifest : str Points to the manifest you want to modify. diff --git a/tests/test_cloudflare.py b/tests/test_cloudflare.py new file mode 100644 index 0000000..8181450 --- /dev/null +++ b/tests/test_cloudflare.py @@ -0,0 +1,109 @@ +import os +import pathlib +import json +import pytest +import responses +from toolviper.utils.data import cloudflare +import pandas as pd + + +@pytest.fixture +def mock_metadata(tmp_path): + metadata = { + "version": "1.0.0", + "metadata": { + "test_file.zip": { + "file": "test_file.zip", + "path": "test", + "dtype": "ZIP", + "telescope": "ALMA", + "size": "100", + "mode": "test", + } + }, + } + meta_dir = tmp_path / ".cloudflare" + meta_dir.mkdir() + meta_file = meta_dir / "file.download.json" + with open(meta_file, "w") as f: + json.dump(metadata, f) + return meta_file + + +def test_version(mock_metadata, monkeypatch, caplog): + # Mock __file__ in cloudflare to point to our temp directory + monkeypatch.setattr( + cloudflare, "__file__", str(mock_metadata.parent.parent / "cloudflare.py") + ) + + with caplog.at_level("INFO"): + cloudflare.version() + assert "1.0.0" in caplog.text + + +@responses.activate +def test_download(mock_metadata, monkeypatch, tmp_path): + monkeypatch.setattr( + cloudflare, "__file__", str(mock_metadata.parent.parent / "cloudflare.py") + ) + + url = "https://downloadnrao.org/test/test_file.zip" + responses.add( + responses.GET, + url, + body=b"test data", + status=200, + headers={"Content-Length": "9"}, + ) + + dest_folder = tmp_path / "dest" + cloudflare.download("test_file.zip", folder=str(dest_folder), decompress=False) + + assert (dest_folder / "test_file.zip").exists() + with open(dest_folder / "test_file.zip", "rb") as f: + assert f.read() == b"test data" + + +def test_get_files(mock_metadata, monkeypatch): + monkeypatch.setattr( + cloudflare, "__file__", str(mock_metadata.parent.parent / "cloudflare.py") + ) + files = cloudflare.get_files() + assert "test_file.zip" in files + + +def test_get_file_size(tmp_path): + test_file = tmp_path / "test.txt" + test_file.write_text("hello") + + sizes = cloudflare.get_file_size(str(tmp_path)) + assert sizes["test.txt"] == 5 + + +@responses.activate +def test_update(mock_metadata, monkeypatch, tmp_path): + monkeypatch.setattr( + cloudflare, "__file__", str(mock_metadata.parent.parent / "cloudflare.py") + ) + + url = "https://downloadnrao.org/file.download.json" + new_metadata = {"version": "1.1.0", "metadata": {}} + responses.add(responses.GET, url, json=new_metadata, status=200) + + update_path = tmp_path / "update_dir" + update_path.mkdir() + cloudflare.update(path=str(update_path)) + + assert (update_path / "file.download.json").exists() + + +def test_list_files(mock_metadata, monkeypatch): + monkeypatch.setattr( + cloudflare, "__file__", str(mock_metadata.parent.parent / "cloudflare.py") + ) + # list_files returns pd.DataFrame or None (if it prints) + # By default it might try to use itables or tabulate + df = cloudflare.list_files() + if df is not None: + assert isinstance(df, pd.DataFrame) + assert "test_file.zip" in df["file"].values From 4dbf9738a5a18245490ff42b27c2f66857284ec1 Mon Sep 17 00:00:00 2001 From: jrhosk Date: Sat, 28 Feb 2026 00:57:08 +0900 Subject: [PATCH 11/24] [black] Add even more additional feature + tests? --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 2362d37..619b2ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ 'pandas', 'itables', 'requests', + 'responses', 'tabulate', 'tqdm', ] From 3ecbbcd4df7f6bed8b410b3cd27abbcec55da21f Mon Sep 17 00:00:00 2001 From: jrhosk Date: Sat, 28 Feb 2026 01:05:15 +0900 Subject: [PATCH 12/24] [black] Add even more additional feature + tests? --- tests/test_logger.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_logger.py b/tests/test_logger.py index 05f1410..46567d9 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -105,7 +105,7 @@ def test_log_logging(mock_logger): def test_get_logger_no_env_no_worker(monkeypatch): - monkeypatch.delenv("LOGGER_NAME", raising=False) + monkeypatch.delenv("VIPER_LOGGER_NAME", raising=False) with patch("toolviper.utils.logger.get_worker", side_effect=ValueError): logger = get_logger() assert logger.name == "viperlog" @@ -114,7 +114,7 @@ def test_get_logger_no_env_no_worker(monkeypatch): def test_get_logger_existing_logger(monkeypatch): - monkeypatch.delenv("LOGGER_NAME", raising=False) + monkeypatch.delenv("VIPER_LOGGER_NAME", raising=False) # Pre-create logger existing_logger = logging.getLogger("existing_log") with patch("toolviper.utils.logger.get_worker", side_effect=ValueError): @@ -124,7 +124,7 @@ def test_get_logger_existing_logger(monkeypatch): def test_get_logger_env(): with ( - patch("os.environ", {"LOGGER_NAME": "env_logger"}), + patch("os.environ", {"VIPER_LOGGER_NAME": "env_logger"}), patch("toolviper.utils.logger.get_worker", side_effect=ValueError), ): logger = get_logger() From f43039b8600b7f700a9be2cf24310de1ace0ee09 Mon Sep 17 00:00:00 2001 From: jrhosk Date: Sat, 28 Feb 2026 20:17:32 +0900 Subject: [PATCH 13/24] [black] Add even more additional feature + tests? --- ...graphviper-logger-formatting-example.ipynb | 31 +++++++++---------- src/toolviper/utils/logger.py | 4 +++ 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/docs/graphviper-logger-formatting-example.ipynb b/docs/graphviper-logger-formatting-example.ipynb index 3b00c85..95c15e2 100644 --- a/docs/graphviper-logger-formatting-example.ipynb +++ b/docs/graphviper-logger-formatting-example.ipynb @@ -313,10 +313,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "[\u001b[38;2;128;05;128m2025-04-10 12:39:45,504\u001b[0m] \u001b[38;2;50;50;205m INFO\u001b[0m\u001b[38;2;112;128;144m viperlog: \u001b[0m Here is an info message. \n", - "[\u001b[38;2;128;05;128m2025-04-10 12:39:45,505\u001b[0m] \u001b[38;2;255;160;0m WARNING\u001b[0m\u001b[38;2;112;128;144m viperlog: \u001b[0m Here is an warning message. \n", - "[\u001b[38;2;128;05;128m2025-04-10 12:39:45,508\u001b[0m] \u001b[38;2;220;20;60m ERROR\u001b[0m\u001b[38;2;112;128;144m viperlog: \u001b[0m [\u001b[38;2;220;20;60m\u001b[0m]: Here is an error message.\n", - "[\u001b[38;2;128;05;128m2025-04-10 12:39:45,512\u001b[0m] \u001b[7;38;2;220;60;20mCRITICAL\u001b[0m\u001b[38;2;112;128;144m viperlog: \u001b[0m [\u001b[7;38;2;220;60;20m\u001b[0m]: Here is an critical message.\n" + "[\u001b[38;2;128;05;128m2026-02-28 20:15:56,646\u001b[0m] \u001b[38;2;50;50;205m INFO\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m Here is an info message. \n", + "[\u001b[38;2;128;05;128m2026-02-28 20:15:56,646\u001b[0m] \u001b[38;2;255;160;0m WARNING\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m Here is an warning message. \n", + "[\u001b[38;2;128;05;128m2026-02-28 20:15:56,650\u001b[0m] \u001b[38;2;220;20;60m ERROR\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m [\u001b[38;2;220;20;60merror\u001b[0m]: Here is an error message.\n", + "[\u001b[38;2;128;05;128m2026-02-28 20:15:56,654\u001b[0m] \u001b[7;38;2;220;60;20mCRITICAL\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m [\u001b[7;38;2;220;60;20mcritical\u001b[0m]: Here is an critical message.\n" ] } ], @@ -338,10 +338,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "[\u001b[38;2;128;05;128m2025-04-10 12:39:45,519\u001b[0m] \u001b[38;2;50;50;205m INFO\u001b[0m\u001b[38;2;112;128;144m viperlog: \u001b[0m [\u001b[38;2;50;50;205mverbose_log\u001b[0m]: Here's a info message. \n", - "[\u001b[38;2;128;05;128m2025-04-10 12:39:45,522\u001b[0m] \u001b[38;2;255;160;0m WARNING\u001b[0m\u001b[38;2;112;128;144m viperlog: \u001b[0m [\u001b[38;2;255;160;0mverbose_log\u001b[0m]: Here's a warning message. \n", - "[\u001b[38;2;128;05;128m2025-04-10 12:39:45,528\u001b[0m] \u001b[38;2;220;20;60m ERROR\u001b[0m\u001b[38;2;112;128;144m viperlog: \u001b[0m [\u001b[38;2;220;20;60mverbose_log\u001b[0m]: Here's a error message.\n", - "[\u001b[38;2;128;05;128m2025-04-10 12:39:45,532\u001b[0m] \u001b[7;38;2;220;60;20mCRITICAL\u001b[0m\u001b[38;2;112;128;144m viperlog: \u001b[0m [\u001b[7;38;2;220;60;20mverbose_log\u001b[0m]: Here's a critical message.\n" + "[\u001b[38;2;128;05;128m2026-02-28 20:15:56,659\u001b[0m] \u001b[38;2;50;50;205m INFO\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m [\u001b[38;2;50;50;205minfo\u001b[0m]: Here's a info message. \n", + "[\u001b[38;2;128;05;128m2026-02-28 20:15:56,662\u001b[0m] \u001b[38;2;255;160;0m WARNING\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m [\u001b[38;2;255;160;0mwarning\u001b[0m]: Here's a warning message. \n", + "[\u001b[38;2;128;05;128m2026-02-28 20:15:56,668\u001b[0m] \u001b[38;2;220;20;60m ERROR\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m [\u001b[38;2;220;20;60merror\u001b[0m]: Here's a error message.\n", + "[\u001b[38;2;128;05;128m2026-02-28 20:15:56,671\u001b[0m] \u001b[7;38;2;220;60;20mCRITICAL\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m [\u001b[7;38;2;220;60;20mcritical\u001b[0m]: Here's a critical message.\n" ] } ], @@ -368,7 +368,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 17, "id": "c9aa9c09-239b-4403-8d83-fb90c93c9782", "metadata": {}, "outputs": [ @@ -376,16 +376,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "Setting verbosity to True\n", - "[\u001b[38;2;128;05;128m2025-04-10 12:39:45,539\u001b[0m] \u001b[38;2;50;50;205m INFO\u001b[0m\u001b[38;2;112;128;144m viperlog: \u001b[0m [\u001b[38;2;50;50;205mverbose_log\u001b[0m]: Here's a info message. \n", - "[\u001b[38;2;128;05;128m2025-04-10 12:39:45,542\u001b[0m] \u001b[38;2;255;160;0m WARNING\u001b[0m\u001b[38;2;112;128;144m viperlog: \u001b[0m [\u001b[38;2;255;160;0mverbose_log\u001b[0m]: Here's a warning message. \n", - "[\u001b[38;2;128;05;128m2025-04-10 12:39:45,548\u001b[0m] \u001b[38;2;220;20;60m ERROR\u001b[0m\u001b[38;2;112;128;144m viperlog: \u001b[0m [\u001b[38;2;220;20;60mverbose_log\u001b[0m]: Here's a error message.\n", - "[\u001b[38;2;128;05;128m2025-04-10 12:39:45,552\u001b[0m] \u001b[7;38;2;220;60;20mCRITICAL\u001b[0m\u001b[38;2;112;128;144m viperlog: \u001b[0m [\u001b[7;38;2;220;60;20mverbose_log\u001b[0m]: Here's a critical message.\n" + "[\u001b[38;2;128;05;128m2026-02-28 20:16:10,073\u001b[0m] \u001b[38;2;50;50;205m INFO\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m Here's a info message. \n", + "[\u001b[38;2;128;05;128m2026-02-28 20:16:10,074\u001b[0m] \u001b[38;2;255;160;0m WARNING\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m Here's a warning message. \n", + "[\u001b[38;2;128;05;128m2026-02-28 20:16:10,074\u001b[0m] \u001b[38;2;220;20;60m ERROR\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m Here's a error message.\n", + "[\u001b[38;2;128;05;128m2026-02-28 20:16:10,075\u001b[0m] \u001b[7;38;2;220;60;20mCRITICAL\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m Here's a critical message.\n" ] } ], "source": [ - "logger.set_verbosity(state=logger.VERBOSE)\n", + "logger.set_verbosity(state=logger.DEFAULT)\n", "verbose_log()" ] }, @@ -414,7 +413,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.2" + "version": "3.12.12" } }, "nbformat": 4, diff --git a/src/toolviper/utils/logger.py b/src/toolviper/utils/logger.py index 993d5d3..12f39d3 100755 --- a/src/toolviper/utils/logger.py +++ b/src/toolviper/utils/logger.py @@ -32,6 +32,9 @@ DEFAULT_LOGGER_NAME = "viperlog" LOGGER_ENV_VAR = "VIPER_LOGGER_NAME" +VERBOSE = True +DEFAULT = False + def set_verbosity(state: Optional[bool] = None) -> None: """ @@ -267,6 +270,7 @@ def get_logger(logger_name: Optional[str] = None) -> logging.Logger: ------- logging.Logger The logger instance. + """ if logger_name is None: logger_name = os.getenv(LOGGER_ENV_VAR, DEFAULT_LOGGER_NAME) From a9c3b67f255d962f0b46af5fd3b80582e7854538 Mon Sep 17 00:00:00 2001 From: jrhosk Date: Sat, 28 Feb 2026 23:03:50 +0900 Subject: [PATCH 14/24] [black] Add even more additional feature + tests? --- ...graphviper-logger-formatting-example.ipynb | 161 +++--------------- src/toolviper/utils/data/cloudflare.py | 2 + tests/test_download.py | 12 +- 3 files changed, 32 insertions(+), 143 deletions(-) diff --git a/docs/graphviper-logger-formatting-example.ipynb b/docs/graphviper-logger-formatting-example.ipynb index 95c15e2..f9269aa 100644 --- a/docs/graphviper-logger-formatting-example.ipynb +++ b/docs/graphviper-logger-formatting-example.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "25c74380", "metadata": {}, "outputs": [], @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "49e07fb6", "metadata": {}, "outputs": [], @@ -55,36 +55,20 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "85a7848e-acce-4a34-8ace-80045c22ddbb", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Let's format some \u001b[1mtext\u001b[0m!\n" - ] - } - ], + "outputs": [], "source": [ "print(\"Let's format some {text}!\".format(text=colorize.bold(\"text\")))" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "77c1abed-d2b5-4f93-930a-6abe02b83547", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Let's format some \u001b[4mtext\u001b[0m!\n" - ] - } - ], + "outputs": [], "source": [ "print(\"Let's format some {text}!\".format(text=colorize.underline(\"text\")))" ] @@ -111,36 +95,20 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "6b579048", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Let's format some \u001b[38;2;220;20;60mtext\u001b[0m!\n" - ] - } - ], + "outputs": [], "source": [ "print(\"Let's format some {text}!\".format(text=colorize.red(\"text\")))" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "5f6d0b00-580e-42ac-a4da-5458b0854e8a", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Let's format some \u001b[7;38;2;220;60;20mtext\u001b[0m!\n" - ] - } - ], + "outputs": [], "source": [ "print(\"Let's format some {text}!\".format(text=colorize.alert(\"text\")))" ] @@ -157,18 +125,10 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "54ea8c43", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Let's format some \u001b[1;4;38;2;128;05;128mtext\u001b[0m!\n" - ] - } - ], + "outputs": [], "source": [ "print(\n", " \"Let's format some {text}!\".format(\n", @@ -179,18 +139,10 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "fe0edaba", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Let's format some critical \u001b[7;38;2;220;20;60mtext\u001b[0m!\n" - ] - } - ], + "outputs": [], "source": [ "print(\n", " \"Let's format some critical {text}!\".format(\n", @@ -201,18 +153,10 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "a822b7c2", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Let's format some critical \u001b[0;38;2;33;150;243mtext\u001b[0m!\n" - ] - } - ], + "outputs": [], "source": [ "print(\n", " \"Let's format some critical {text}!\".format(\n", @@ -233,7 +177,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "5b8611da", "metadata": {}, "outputs": [], @@ -247,18 +191,10 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "f1ae62b5", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[\u001b[38;2;50;50;205msome_function_to_log\u001b[0m]: Here is some special text\n" - ] - } - ], + "outputs": [], "source": [ "some_function_to_log()" ] @@ -285,7 +221,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "11e59b05", "metadata": {}, "outputs": [], @@ -295,7 +231,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "id": "4f24c870", "metadata": {}, "outputs": [], @@ -305,21 +241,10 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "id": "9b236679", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[\u001b[38;2;128;05;128m2026-02-28 20:15:56,646\u001b[0m] \u001b[38;2;50;50;205m INFO\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m Here is an info message. \n", - "[\u001b[38;2;128;05;128m2026-02-28 20:15:56,646\u001b[0m] \u001b[38;2;255;160;0m WARNING\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m Here is an warning message. \n", - "[\u001b[38;2;128;05;128m2026-02-28 20:15:56,650\u001b[0m] \u001b[38;2;220;20;60m ERROR\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m [\u001b[38;2;220;20;60merror\u001b[0m]: Here is an error message.\n", - "[\u001b[38;2;128;05;128m2026-02-28 20:15:56,654\u001b[0m] \u001b[7;38;2;220;60;20mCRITICAL\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m [\u001b[7;38;2;220;60;20mcritical\u001b[0m]: Here is an critical message.\n" - ] - } - ], + "outputs": [], "source": [ "logger.info(\"Here is an info message.\")\n", "logger.warning(\"Here is an warning message.\")\n", @@ -330,21 +255,10 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "id": "2d2b8d48", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[\u001b[38;2;128;05;128m2026-02-28 20:15:56,659\u001b[0m] \u001b[38;2;50;50;205m INFO\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m [\u001b[38;2;50;50;205minfo\u001b[0m]: Here's a info message. \n", - "[\u001b[38;2;128;05;128m2026-02-28 20:15:56,662\u001b[0m] \u001b[38;2;255;160;0m WARNING\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m [\u001b[38;2;255;160;0mwarning\u001b[0m]: Here's a warning message. \n", - "[\u001b[38;2;128;05;128m2026-02-28 20:15:56,668\u001b[0m] \u001b[38;2;220;20;60m ERROR\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m [\u001b[38;2;220;20;60merror\u001b[0m]: Here's a error message.\n", - "[\u001b[38;2;128;05;128m2026-02-28 20:15:56,671\u001b[0m] \u001b[7;38;2;220;60;20mCRITICAL\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m [\u001b[7;38;2;220;60;20mcritical\u001b[0m]: Here's a critical message.\n" - ] - } - ], + "outputs": [], "source": [ "def verbose_log():\n", " logger.info(\"Here's a info message.\", verbose=True)\n", @@ -368,33 +282,14 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "id": "c9aa9c09-239b-4403-8d83-fb90c93c9782", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[\u001b[38;2;128;05;128m2026-02-28 20:16:10,073\u001b[0m] \u001b[38;2;50;50;205m INFO\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m Here's a info message. \n", - "[\u001b[38;2;128;05;128m2026-02-28 20:16:10,074\u001b[0m] \u001b[38;2;255;160;0m WARNING\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m Here's a warning message. \n", - "[\u001b[38;2;128;05;128m2026-02-28 20:16:10,074\u001b[0m] \u001b[38;2;220;20;60m ERROR\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m Here's a error message.\n", - "[\u001b[38;2;128;05;128m2026-02-28 20:16:10,075\u001b[0m] \u001b[7;38;2;220;60;20mCRITICAL\u001b[0m\u001b[38;2;112;128;144m toolviper: \u001b[0m Here's a critical message.\n" - ] - } - ], + "outputs": [], "source": [ - "logger.set_verbosity(state=logger.DEFAULT)\n", + "logger.set_verbosity(state=logger.VERBOSE)\n", "verbose_log()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a14d9eda-29a3-4470-801f-dc36958ed00d", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/src/toolviper/utils/data/cloudflare.py b/src/toolviper/utils/data/cloudflare.py index 2b9b5c3..7ac0f88 100644 --- a/src/toolviper/utils/data/cloudflare.py +++ b/src/toolviper/utils/data/cloudflare.py @@ -475,6 +475,8 @@ def _print_file_queue(files: List[str]) -> None: from rich.console import Console from rich.table import Table + assert isinstance(files, list), logger.error("files must be a list") + console_ = Console() table = Table(show_header=True, box=box.SIMPLE) table.add_column("Download List", justify="left") diff --git a/tests/test_download.py b/tests/test_download.py index e4815b0..353f44c 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -148,13 +148,5 @@ def test_private_print_file_queue(self): logger.info("Failure test passed!") return None - # If error isn't as expected, fail the test. - raise AssertionError() - - def test_private_make_dir(self): - from toolviper.utils.data.cloudflare import _make_dir - - _make_dir(path=str(pathlib.Path.cwd()), folder="data") - - if not pathlib.Path.cwd().joinpath("data").exists(): - raise FileNotFoundError("data") + # If the error isn't as expected, fail the test. + raise AssertionError() \ No newline at end of file From c105b904703829e2210050110f8b27a7eb19b2ac Mon Sep 17 00:00:00 2001 From: jrhosk Date: Sat, 28 Feb 2026 23:10:42 +0900 Subject: [PATCH 15/24] [black] Add even more additional feature + tests? --- tests/test_download.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_download.py b/tests/test_download.py index 353f44c..438ac3c 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -149,4 +149,4 @@ def test_private_print_file_queue(self): return None # If the error isn't as expected, fail the test. - raise AssertionError() \ No newline at end of file + raise AssertionError() From 2cdeaa0ae12255e0e0994b10a551a8c1b3fef6c9 Mon Sep 17 00:00:00 2001 From: jrhosk Date: Mon, 2 Mar 2026 09:27:14 +0900 Subject: [PATCH 16/24] Fix small things in the notebooks. --- docs/client_tutorial.ipynb | 4 +-- docs/download_example.ipynb | 48 +++++++++++++------------- src/toolviper/utils/data/cloudflare.py | 5 +-- 3 files changed, 29 insertions(+), 28 deletions(-) diff --git a/docs/client_tutorial.ipynb b/docs/client_tutorial.ipynb index c1b179d..e528e0b 100644 --- a/docs/client_tutorial.ipynb +++ b/docs/client_tutorial.ipynb @@ -40,7 +40,7 @@ }, "outputs": [], "source": [ - "toolviper.utils.data.download(file=\"AA2-Mid-sim_00000.ms\")" + "toolviper.utils.data.download(file=\"AA2-Mid-sim_00000.ms\", folder=\"data\")" ] }, { @@ -197,7 +197,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.13" + "version": "3.12.12" } }, "nbformat": 4, diff --git a/docs/download_example.ipynb b/docs/download_example.ipynb index 895f169..3d6b48b 100644 --- a/docs/download_example.ipynb +++ b/docs/download_example.ipynb @@ -10,15 +10,6 @@ "import toolviper" ] }, - { - "cell_type": "markdown", - "id": "01bc8397-af47-48df-8391-733bba286588", - "metadata": {}, - "source": [ - "## Getting download metadata version\n", - "#### This will retireve the current file download metadata version; if the file is not found it will attempt to update to the most recent version." - ] - }, { "cell_type": "code", "execution_count": null, @@ -31,50 +22,59 @@ }, { "cell_type": "markdown", - "id": "fe18a0d9-bd21-4ce5-91fd-6348e5f1369e", + "id": "0351dda7-7559-449d-9ebb-c853beb0861e", "metadata": {}, "source": [ - "### Manually update metdata info." + "## Getting available downloadable file in a python list.\n", + "#### This will return an unordered list of the available file on the remote dropbox in a python list. This can be used as an input to thie download function as well." ] }, { "cell_type": "code", "execution_count": null, - "id": "deeee662-94d9-44a0-95e7-f062f55223ca", + "id": "2912a5ae-7a4b-42f3-a633-cbe5ecaffc74", "metadata": {}, "outputs": [], "source": [ - "toolviper.utils.data.update()" + "files = toolviper.utils.data.get_files()\n", + "files" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28c8b343-3a05-4217-a581-e51c7db95d02", + "metadata": {}, + "outputs": [], + "source": [ + "toolviper.utils.data.download(file=files[3:6], folder=\"data\")" ] }, { "cell_type": "markdown", - "id": "0351dda7-7559-449d-9ebb-c853beb0861e", + "id": "01bc8397-af47-48df-8391-733bba286588", "metadata": {}, "source": [ - "## Getting available downloadable file in a python list.\n", - "#### This will return an unordered list of the available file on the remote dropbox in a python list. This can be used as an input to thie download function as well." + "## Getting download metadata version\n", + "#### This will retireve the current file download metadata version; if the file is not found it will attempt to update to the most recent version." ] }, { - "cell_type": "code", - "execution_count": null, - "id": "2912a5ae-7a4b-42f3-a633-cbe5ecaffc74", + "cell_type": "markdown", + "id": "fe18a0d9-bd21-4ce5-91fd-6348e5f1369e", "metadata": {}, - "outputs": [], "source": [ - "files = toolviper.utils.data.get_files()\n", - "files" + "### Manually update metdata info." ] }, { "cell_type": "code", "execution_count": null, - "id": "28c8b343-3a05-4217-a581-e51c7db95d02", + "id": "d521c952-7f30-4454-8548-de5ac5b98d82", "metadata": {}, "outputs": [], "source": [ - "toolviper.utils.data.download(file=files[6:8], folder=\"data\")" + "toolviper.utils.data.update()" ] }, { diff --git a/src/toolviper/utils/data/cloudflare.py b/src/toolviper/utils/data/cloudflare.py index 7ac0f88..451fc11 100644 --- a/src/toolviper/utils/data/cloudflare.py +++ b/src/toolviper/utils/data/cloudflare.py @@ -8,11 +8,12 @@ import requests import pandas as pd + from rich.progress import Progress, TaskID -import toolviper import toolviper.utils.console as console import toolviper.utils.logger as logger + from toolviper.utils import parameter from toolviper.utils.parameter import is_notebook from collections import defaultdict @@ -489,7 +490,7 @@ def _print_file_queue(files: List[str]) -> None: def _verify_metadata_file() -> None: """ - Ensure the metadata file exists, or trigger an update. + Ensure the metadata file exists or trigger an update. """ meta_data_path = _get_metadata_path() if not meta_data_path.exists(): From a35d940349bc2ec2a6f10b4fb7c1266ec03149a0 Mon Sep 17 00:00:00 2001 From: jrhosk Date: Mon, 9 Mar 2026 10:25:43 +0900 Subject: [PATCH 17/24] Various fixes along with adding prototyping for graph creation. --- src/toolviper/dask/client.py | 8 ++- src/toolviper/dask/menrva.py | 14 ++--- src/toolviper/dask/plugins/worker.py | 2 +- src/toolviper/utils/__init__.py | 1 + src/toolviper/utils/data/cloudflare.py | 58 ++++++++++++++------ src/toolviper/utils/display.py | 57 ++++++++++++-------- src/toolviper/utils/sd/__init__.py | 5 ++ src/toolviper/utils/sd/graph.py | 46 ++++++++++++++++ src/toolviper/utils/sd/prototype.py | 73 ++++++++++++++++++++++++++ src/toolviper/utils/tools.py | 6 +-- 10 files changed, 221 insertions(+), 49 deletions(-) create mode 100644 src/toolviper/utils/sd/__init__.py create mode 100644 src/toolviper/utils/sd/graph.py create mode 100644 src/toolviper/utils/sd/prototype.py diff --git a/src/toolviper/dask/client.py b/src/toolviper/dask/client.py index e967f90..c71fb4f 100644 --- a/src/toolviper/dask/client.py +++ b/src/toolviper/dask/client.py @@ -13,9 +13,11 @@ from typing import Dict, Union, Any, Optional import toolviper.dask.menrva + import toolviper.utils.console as console import toolviper.utils.logger as logger import toolviper.utils.parameter as parameter +import toolviper.utils.display as display colorize = console.Colorize() @@ -274,11 +276,13 @@ def local_client( client = toolviper.dask.menrva.MenrvaClient(cluster) client.get_versions(check=True) - # When constructing a graph that has local cache enabled, all workers need to be up and running. + # When constructing a graph that has local-cache enabled, all workers need to be up and running. if local_cache or wait_for_workers: client.wait_for_workers(n_workers=cores) - logger.debug(f"These are the worker log parameters:\n {worker_log_params}") + # logger.debug(f"These are the worker log parameters:\n") + # logger.debug(f"{display.DataDict.from_dict(worker_log_params).display(interactive=False)}") + if local_cache or worker_log_params: client.load_plugin( directory=plugin_path, diff --git a/src/toolviper/dask/menrva.py b/src/toolviper/dask/menrva.py index 8bcf4ac..53d9b03 100644 --- a/src/toolviper/dask/menrva.py +++ b/src/toolviper/dask/menrva.py @@ -118,17 +118,17 @@ def call(func: Callable, *args: Tuple[Any], **kwargs: Dict[str, Any]): @staticmethod def instantiate_module( plugin: str, plugin_file: str, *args: Tuple[Any], **kwargs: Dict[str, Any] - ) -> WorkerPlugin: + ) -> WorkerPlugin | None: """ Args: plugin (str): Name of plugin module. - plugin_file (str): Name of module file. ** This should be moved into the module itself not passed ** - *args (tuple(Any)): This is any *arg that needs to be passed to the plugin module. + plugin_file (str): Name of a module file. ** This should be moved into the module itself, not passed ** + *args (tuple (Any)): This is any *arg that needs to be passed to the plugin module. **kwargs (dict[str, Any]): This is any **kwarg default values that need to be passed to the plugin module. Returns: - Instance of plugin class. + Instance of plugin-class. """ spec = importlib.util.spec_from_file_location(plugin, plugin_file) module = importlib.util.module_from_spec(spec) @@ -138,6 +138,8 @@ def instantiate_module( logger.debug("Loading plugin module: {}".format(plugin_instance)) return MenrvaClient.call(plugin_instance, *args, **kwargs) + return None + def load_plugin( self, directory: str, @@ -154,13 +156,13 @@ def load_plugin( *args, **kwargs, ) - logger.debug(f"{plugin}") + if sys.version_info.major == 3: if sys.version_info.minor > 8: self.register_plugin(plugin_instance, name=name) else: - self.register_worker_plugin(plugin_instance, name=name) + self.register_plugin(plugin_instance, name=name) else: logger.warning("Python version may not be supported.") else: diff --git a/src/toolviper/dask/plugins/worker.py b/src/toolviper/dask/plugins/worker.py index 12c9e1b..9e5681d 100644 --- a/src/toolviper/dask/plugins/worker.py +++ b/src/toolviper/dask/plugins/worker.py @@ -99,7 +99,7 @@ async def dask_setup( await worker.client.register_plugin(plugin, name="worker_logger") else: - await worker.client.register_worker_plugin(plugin, name="worker_logger") + await worker.client.register_plugin(plugin, name="worker_logger") else: logger.warning("Python version may not be supported.") diff --git a/src/toolviper/utils/__init__.py b/src/toolviper/utils/__init__.py index b34a5c3..9b160b8 100644 --- a/src/toolviper/utils/__init__.py +++ b/src/toolviper/utils/__init__.py @@ -4,6 +4,7 @@ from .logger import info, debug, warning, error, critical, get_logger, setup_logger from .tools import open_json, calculate_checksum, verify, add_entry from .profile import memory_usage, cpu_usage +from .sd import prototype from .data import download diff --git a/src/toolviper/utils/data/cloudflare.py b/src/toolviper/utils/data/cloudflare.py index 451fc11..bccad66 100644 --- a/src/toolviper/utils/data/cloudflare.py +++ b/src/toolviper/utils/data/cloudflare.py @@ -9,6 +9,7 @@ import requests import pandas as pd +from rich.console import Console from rich.progress import Progress, TaskID import toolviper.utils.console as console @@ -80,10 +81,10 @@ def download( if isinstance(file, str): file = [file] - try: - _print_file_queue(file) - except Exception as e: - logger.warning(f"Problem printing file list: {e}") + # try: + # _print_file_queue(file) + # except Exception as e: + # logger.warning(f"Problem printing file list: {e}") dest_path = pathlib.Path(folder).resolve() if not dest_path.exists(): @@ -123,6 +124,9 @@ def name_format(string): if f_name not in file_meta_data.get("metadata", {}): logger.error(f"Requested file not found in manifest: {f_name}") + logger.error( + f"Use {colorize.blue('toolviper.utils.data.update()')} for the most recent version of the manifest." + ) logger.info( f"Use {colorize.blue('toolviper.utils.data.list_files()')} for available files." ) @@ -137,6 +141,7 @@ def name_format(string): "folder": str(dest_path), "visible": True, "size": int(meta.get("size", 0)), + "jupyter": is_notebook(), } ) @@ -147,24 +152,32 @@ def name_format(string): return progress = Progress() + if is_notebook(): + _ = Console(force_terminal=True, force_jupyter=False) + _console = Console(force_jupyter=is_notebook()) + + progress = Progress(console=_console) + threads = [] with progress: for task in tasks: task_id = progress.add_task(task["description"]) - thread = Thread(target=worker, args=(progress, task_id, task, decompress)) + thread = Thread(target=worker, args=(task_id, task, progress, decompress)) thread.start() threads.append(thread) for thread in threads: thread.join() + progress.refresh() + if missing_files: logger.error(f"Could not download: {missing_files}") def worker( - progress: Progress, task_id: TaskID, task: dict, decompress: bool = True + task_id: TaskID, task: dict, progress: Progress = None, decompress: bool = True ) -> None: """ Worker function to download a file in a thread. @@ -190,6 +203,7 @@ def worker( url, stream=True, headers={"user-agent": USER_AGENT}, timeout=30 ) response.raise_for_status() + except Exception as e: logger.error(f"Failed to initiate download for {filename}: {e}") return @@ -207,9 +221,14 @@ def worker( for chunk in response.iter_content(chunk_size=MINIMUM_CHUNK_SIZE): if chunk: size += fd.write(chunk) - progress.update( - task_id, completed=size, total=total, visible=task["visible"] - ) + if progress is not None: + progress.update( + task_id, + completed=size, + total=total, + visible=task["visible"], + ) + except Exception as e: logger.error(f"Error writing file {filename}: {e}") return @@ -412,7 +431,7 @@ def update(path: Optional[str] = None) -> None: "path": "/", "dtype": "JSON", "telescope": "NA", - "size": "12484", + "size": "23879", "mode": "NA", } @@ -420,17 +439,24 @@ def update(path: Optional[str] = None) -> None: "description": "Updating manifest", "metadata": file_meta_data, "folder": str(meta_data_dir), - "visible": False, - "size": 12484, + "visible": True, + "size": 23879, } logger.info("Updating file metadata information...") - progress = Progress() - task_id = progress.add_task(task["description"]) + task_id = 0 + # with progress: + _console = Console(force_jupyter=is_notebook()) + tasks = [f"\nManifest update "] - with progress: - worker(progress, task_id, task, decompress=False) + with _console.status( + "[bold green]Working on download manifest update ..." + ) as status: + while tasks: + worker(task_id, task, progress=None, decompress=False) + + task = tasks.pop(0) if not meta_data_path.exists(): logger.error("Unable to retrieve download metadata.") diff --git a/src/toolviper/utils/display.py b/src/toolviper/utils/display.py index b4faebf..621b89a 100755 --- a/src/toolviper/utils/display.py +++ b/src/toolviper/utils/display.py @@ -1,19 +1,9 @@ import re import operator -from IPython.core.display import HTML - +import sys +from typing import Dict -def dict_to_html(d, indent=0): - print(f"THIS FUNCTION WILL BE DEPRECATED SOON") - - html = "" - for key, value in d.items(): - if isinstance(value, dict): - html += f"
{key}{dict_to_html(value, indent + 1)}
" - else: - html += f"
{key}: {value}
" - - return html +from IPython.core.display import HTML class DataDict(dict): @@ -74,13 +64,38 @@ def display(self, interactive=True): return rich.print_json(data=self._dict) - def html(self, indent=0): - html = "" - for key, value in self._dict.items(): - if isinstance(value, dict): - html += f"
{key}{dict_to_html(value, indent + 1)}
" + @staticmethod + def html(dictionary: Dict, indent: int = 0): + _html = _write_html(dictionary, indent) - else: - html += f"
{key}: {value}
" + return HTML(_html) - return HTML(html) + +def _write_html(d, indent=0): + _html = "" + + for key, value in d.items(): + if isinstance(value, dict): + _html += f"
{key}{_write_html(value, indent + 1)}
" + + else: + _html += f"
{key}: {value}
" + + return _html + + +def dict_to_html(d, indent=0): + + html = "" + for key, value in d.items(): + if isinstance(value, dict): + html += f"
{key}{dict_to_html(value, indent + 1)}
" + else: + html += f"
{key}: {value}
" + + if indent == 0: + print( + f"THIS FUNCTION WILL BE DEPRECATED SOON, switch to: toolviper.utils.display.DataDict.html(d)" + ) + + return html diff --git a/src/toolviper/utils/sd/__init__.py b/src/toolviper/utils/sd/__init__.py new file mode 100644 index 0000000..0622cdf --- /dev/null +++ b/src/toolviper/utils/sd/__init__.py @@ -0,0 +1,5 @@ +from .prototype import * +from .graph import * + +__submodules__ = ["prototype"] +__all__ = __submodules__ + [s for s in dir() if not s.startswith("_")] diff --git a/src/toolviper/utils/sd/graph.py b/src/toolviper/utils/sd/graph.py new file mode 100644 index 0000000..f1b7865 --- /dev/null +++ b/src/toolviper/utils/sd/graph.py @@ -0,0 +1,46 @@ +import dask +import collections +import toolviper + +import toolviper.utils.logger as logger + + +class Graph: + """A class representing a directed graph for dependency management.""" + + def __init__(self): + self._graph = None + self._results = collections.defaultdict(list) + + def source(self, job, axes, connect=False, node=None): + function_name = job["function"].__name__ + previous = None + + if connect: + previous = self._graph + + if node is not None: + try: + previous = self._results[node] + + except KeyError: + logger.error(f"Node {node} not found in results.") + + self._graph = toolviper.utils.sd.distribute( + job=job, axes=axes, function=job["function"], previous=previous + ) + + self._results[function_name].append(self._graph) + + def sink(self, function, edges=None): + self._graph = dask.delayed(function)(self._graph) + + def visualize(self): + return dask.visualize(self._graph) + + def compute(self): + return dask.compute(self._graph) + + @property + def nodes(self): + return list(self._results.keys()) diff --git a/src/toolviper/utils/sd/prototype.py b/src/toolviper/utils/sd/prototype.py new file mode 100644 index 0000000..46c598b --- /dev/null +++ b/src/toolviper/utils/sd/prototype.py @@ -0,0 +1,73 @@ +import dask +import typing +import itertools + +import numpy as np +import xarray as xr + + +# Build a simple Dask dataset based on a given set of axes +def simulate(field, spw, polarization, antenna, row): + data_shape = { + "field": [f"field_{i}" for i in range(field)], + "spw": [f"spw_{i}" for i in range(spw)], + "polarization": polarization, + "antenna": [f"antenna_{i}" for i in range(antenna)], + "row": [i for i in range(row)], + } + + dataset = xr.Dataset( + coords=data_shape, + data_vars=dict( + DATA=( + list(data_shape.keys()), + np.zeros((field, spw, len(polarization), antenna, row)), + ) + ), + ) + + return dataset + + +def distribute( + job: typing.Dict, axes: typing.List[str], function: typing.Callable, previous=None +) -> typing.List[dask.delayed]: + """ + Distribute a function across a dataset along specified axes. + + This function creates a list of delayed dask tasks, where each task + represents a call to the specified function with the dataset or previous + result, and the values of the distribution axes. + + Parameters + ---------- + dataset : xr.Dataset + The input dataset to be distributed. + axes : typing.List[str] + The axes to distribute along. + function : typing.Callable + The function to be applied in a delayed manner. + previous : typing.Any, optional + A previous result to be passed to the function. Defaults to None. + + Returns + ------- + typing.List[dask.delayed] + A list of dask delayed objects. + """ + # Get the coordinate values for each axis + axis_values = [job["dataset"].coords[axis].values for axis in axes] + + if isinstance(previous, list): + axis_values.append(previous) + + # Create a delayed version of the function + delayed_func = dask.delayed(function) + + # Use itertools.product to generate all combinations of axis values + # and create a delayed task for each combination. + # The axis values are passed as positional arguments after 'previous'. + return [ + delayed_func(*values) if previous is not None else delayed_func(*values) + for values in itertools.product(*axis_values) + ] diff --git a/src/toolviper/utils/tools.py b/src/toolviper/utils/tools.py index b087258..5dd6dca 100644 --- a/src/toolviper/utils/tools.py +++ b/src/toolviper/utils/tools.py @@ -186,9 +186,9 @@ def add_entry( for entry in entries: process_entry_(**entry, json_file=json_file) - except KeyError as key_error: - logger.error(f"entry not found in metadata ... skipping: {key_error}") - return None + # except KeyError as key_error: + # logger.error(f"entry not found in metadata ... skipping: {key_error}") + # return None except TypeError: logger.error( From bb472d564ce73f7dfeb4ff822289e93276ca53d9 Mon Sep 17 00:00:00 2001 From: jrhosk Date: Mon, 9 Mar 2026 10:28:27 +0900 Subject: [PATCH 18/24] Add notebooks. --- docs/add_entry.ipynb | 142 ++++++++++- docs/download_example.ipynb | 5 +- docs/file-manifest-update.ipynb | 10 + docs/hsd_imaging_skeleton.ipynb | 235 ++++++++++++++++++ ...toolviper-logger-formatting-example.ipynb} | 0 5 files changed, 376 insertions(+), 16 deletions(-) create mode 100644 docs/hsd_imaging_skeleton.ipynb rename docs/{graphviper-logger-formatting-example.ipynb => toolviper-logger-formatting-example.ipynb} (100%) diff --git a/docs/add_entry.ipynb b/docs/add_entry.ipynb index 99e0814..24b23e2 100644 --- a/docs/add_entry.ipynb +++ b/docs/add_entry.ipynb @@ -23,6 +23,16 @@ "toolviper.utils.data.update(path=str(path))" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "0dda04ad-1ecc-4925-92bb-9a608f81d7e1", + "metadata": {}, + "outputs": [], + "source": [ + "toolviper.utils.data.get_file_size(str(path))" + ] + }, { "cell_type": "code", "execution_count": null, @@ -33,11 +43,11 @@ "entries = []\n", "\n", "entry = {\n", - " \"file\": str(path.joinpath(\"ngc5921-lsrk-cube.psf.zip\")),\n", + " \"file\": str(path.joinpath(\"upload/casa_no_sky_to_xds_true.zarr.zip\")),\n", " \"path\": \"radps/image\",\n", " \"dtype\": \"CASA image\",\n", " \"telescope\": \"VLA\",\n", - " \"mode\": \"Interferometric\"\n", + " \"mode\": \"Simulated\"\n", "}\n", "\n", "entries.append(entry)" @@ -46,36 +56,142 @@ { "cell_type": "code", "execution_count": null, - "id": "0870dba2-cf08-4130-aea2-8d64c60cff07", + "id": "571ff4a0-252a-4086-a203-a9d199e20aea", "metadata": {}, "outputs": [], "source": [ - "_ = toolviper.utils.tools.add_entry(\n", - " entries=entries,\n", - " manifest=str(path.joinpath(\"file.download.json\")),\n", - " versioning=\"patch\"\n", - ")" + "entry = {\n", + " \"file\": str(path.joinpath(\"upload/casa_to_xds_true.zarr.zip\")),\n", + " \"path\": \"radps/image\",\n", + " \"dtype\": \"CASA image\",\n", + " \"telescope\": \"VLA\",\n", + " \"mode\": \"Simulated\"\n", + "}\n", + "\n", + "entries.append(entry)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6a2a9e7-b2bd-41ea-8508-7e9a332042eb", + "metadata": {}, + "outputs": [], + "source": [ + "entry = {\n", + " \"file\": str(path.joinpath(\"upload/empty_sky_image_true.zarr.zip\")),\n", + " \"path\": \"radps/image\",\n", + " \"dtype\": \"CASA image\",\n", + " \"telescope\": \"VLA\",\n", + " \"mode\": \"Simulated\"\n", + "}\n", + "\n", + "entries.append(entry)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "faa111cf-3406-49b9-8980-40fa484ab5d5", + "metadata": {}, + "outputs": [], + "source": [ + "entry = {\n", + " \"file\": str(path.joinpath(\"upload/empty_sky_image_no_sky_coords_true.zarr.zip\")),\n", + " \"path\": \"radps/image\",\n", + " \"dtype\": \"CASA image\",\n", + " \"telescope\": \"VLA\",\n", + " \"mode\": \"Simulated\"\n", + "}\n", + "\n", + "entries.append(entry)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae6583fb-e078-4758-a32a-08bbb099c874", + "metadata": {}, + "outputs": [], + "source": [ + "entry = {\n", + " \"file\": str(path.joinpath(\"upload/empty_aperture_image_true.zarr.zip\")),\n", + " \"path\": \"radps/image\",\n", + " \"dtype\": \"CASA image\",\n", + " \"telescope\": \"VLA\",\n", + " \"mode\": \"Simulated\"\n", + "}\n", + "\n", + "entries.append(entry)" ] }, { "cell_type": "code", "execution_count": null, - "id": "a113bb2b-fed4-4ede-8d8b-353958f59602", + "id": "cfb61527-d651-4423-9334-5b8847010374", "metadata": {}, "outputs": [], "source": [ - "#toolviper.utils.data.update()\n", + "entry = {\n", + " \"file\": str(path.joinpath(\"upload/empty_lmuv_image_true.zarr.zip\")),\n", + " \"path\": \"radps/image\",\n", + " \"dtype\": \"CASA image\",\n", + " \"telescope\": \"VLA\",\n", + " \"mode\": \"Simulated\"\n", + "}\n", "\n", - "#toolviper.utils.data.download(file=\"ngc5921-lsrk-cube.psf\", folder=\"test\")" + "entries.append(entry)" ] }, { "cell_type": "code", "execution_count": null, - "id": "7df91b02-3e25-4989-a76e-20ca4b7a9cb2", + "id": "adf0ba7d-81e0-4847-96c5-61cbb1f92e52", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "entry = {\n", + " \"file\": str(path.joinpath(\"upload/empty_lmuv_image_no_sky_coords_true.zarr.zip\")),\n", + " \"path\": \"radps/image\",\n", + " \"dtype\": \"CASA image\",\n", + " \"telescope\": \"VLA\",\n", + " \"mode\": \"Simulated\"\n", + "}\n", + "\n", + "entries.append(entry)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b6d75323-8c4b-43ab-ac47-091e8af0eeee", + "metadata": {}, + "outputs": [], + "source": [ + "entry = {\n", + " \"file\": str(path.joinpath(\"upload/casa_uv_true.zarr.zip\")),\n", + " \"path\": \"radps/image\",\n", + " \"dtype\": \"CASA image\",\n", + " \"telescope\": \"VLA\",\n", + " \"mode\": \"Simulated\"\n", + "}\n", + "\n", + "entries.append(entry)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0870dba2-cf08-4130-aea2-8d64c60cff07", + "metadata": {}, + "outputs": [], + "source": [ + "_ = toolviper.utils.tools.add_entry(\n", + " entries=entries,\n", + " manifest=str(path.joinpath(\"file.download.json\")),\n", + " versioning=\"patch\"\n", + ")" + ] } ], "metadata": { diff --git a/docs/download_example.ipynb b/docs/download_example.ipynb index 3d6b48b..c08ca49 100644 --- a/docs/download_example.ipynb +++ b/docs/download_example.ipynb @@ -36,8 +36,7 @@ "metadata": {}, "outputs": [], "source": [ - "files = toolviper.utils.data.get_files()\n", - "files" + "files = toolviper.utils.data.get_files()" ] }, { @@ -47,7 +46,7 @@ "metadata": {}, "outputs": [], "source": [ - "toolviper.utils.data.download(file=files[3:6], folder=\"data\")" + "toolviper.utils.data.download(file=files[3:8], folder=\"data\")" ] }, { diff --git a/docs/file-manifest-update.ipynb b/docs/file-manifest-update.ipynb index 8b40c6d..9f52031 100644 --- a/docs/file-manifest-update.ipynb +++ b/docs/file-manifest-update.ipynb @@ -41,6 +41,16 @@ "toolviper.utils.data.update(path=str(path))" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "8589d72e-c594-400a-8660-f54336b9c4d6", + "metadata": {}, + "outputs": [], + "source": [ + "_json = toolviper.utils.tools.open_json(\"file.download.json\")" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/docs/hsd_imaging_skeleton.ipynb b/docs/hsd_imaging_skeleton.ipynb new file mode 100644 index 0000000..6335f7a --- /dev/null +++ b/docs/hsd_imaging_skeleton.ipynb @@ -0,0 +1,235 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "01a56c62-2896-4857-8af6-e70539e8b2eb", + "metadata": {}, + "outputs": [], + "source": [ + "import dask\n", + "import time\n", + "import toolviper\n", + "import random\n", + "\n", + "import toolviper.utils.logger as logger\n", + "import toolviper.utils.display as display\n", + "\n", + "import toolviper.dask.client as client\n", + "\n", + "from collections import defaultdict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b14a64bf-3470-45da-b3f2-f44f59eb26f1", + "metadata": {}, + "outputs": [], + "source": [ + "client = client.local_client(\n", + " cores=10,\n", + " log_params={\n", + " \"log_to_file\":False,\n", + " \"log_to_term\":True,\n", + " \"log_level\":\"DEBUG\" \n", + " },\n", + " worker_log_params={\n", + " \"log_to_file\":False,\n", + " \"log_to_term\":True,\n", + " \"log_level\":\"DEBUG\" \n", + " }\n", + ")\n", + "\n", + "# Spawn dashboard window in a seperate tab,\n", + "# comment out if you don't want this to spawn.\n", + "# webbrowser.open(url=client.dashboard_link)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9273c405-6f67-46a9-9851-23b7778ce877", + "metadata": {}, + "outputs": [], + "source": [ + "data = toolviper.utils.sd.prototype.simulate(field=1, spw=2, polarization=[\"XX\", \"YY\"], antenna=1, row=2)\n", + "data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "219ccbce-d0a2-4b9a-b2fa-cfc64cc32545", + "metadata": {}, + "outputs": [], + "source": [ + "def gather(result):\n", + " if result is None:\n", + " return\n", + " \n", + " return result\n", + "\n", + "# Simple function to generate a time delay and simulate \n", + "# data processing\n", + "def generate_delay(n=1, m=2):\n", + " time.sleep(random.uniform(n, m))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "71e0d665-a13c-4cd8-ab0e-609c72a7bbe9", + "metadata": {}, + "outputs": [], + "source": [ + "def imaging_parameter_setup(*arg, **kwargs):\n", + " generate_delay(n=1, m=3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c67c4659-1eeb-4df6-9960-99bce924ed63", + "metadata": {}, + "outputs": [], + "source": [ + "class Graph:\n", + " def __init__(self):\n", + " self._graph = None\n", + " self._results = defaultdict(list)\n", + " \n", + " def source(self, function, axes, connect=False):\n", + " function_name = function.__name__\n", + " previous = None\n", + " \n", + " if connect:\n", + " previous = self._graph\n", + " \n", + " self._graph = toolviper.utils.sd.distribute(\n", + " dataset=data,\n", + " axes=axes,\n", + " function=function,\n", + " previous=previous\n", + " )\n", + " \n", + " self._results[function_name].append(self._graph)\n", + "\n", + " def sink(self, function, edges=None):\n", + " self._graph = dask.delayed(function)(self._graph)\n", + "\n", + " def visualize(self):\n", + " return dask.visualize(self._graph)\n", + "\n", + " def compute(self):\n", + " return dask.compute(self._graph)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fa444240-f5cf-4e43-974b-81b7fe0b8e50", + "metadata": {}, + "outputs": [], + "source": [ + "def check_values(*args, **kwargs):\n", + " print(f\"check::args::{args}\\n\")\n", + "\n", + "def set_values(*args, **kwargs):\n", + " print(f\"set::args::{args}\\n\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a1a9cf04-da1f-415e-869c-647f90145e5b", + "metadata": {}, + "outputs": [], + "source": [ + "#dask.visualize(new)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c44c2d8d-fbc6-4a8b-b40b-29e6bdf18eea", + "metadata": {}, + "outputs": [], + "source": [ + "#dask.compute(new)\n", + "\n", + "job = {\n", + " \"dataset\": data,\n", + " \"function\": check_values\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a7ddcdcd-1cb8-4c20-b9c9-1989549caf19", + "metadata": {}, + "outputs": [], + "source": [ + "graph = toolviper.utils.sd.graph.Graph()\n", + "\n", + "graph.source(job=job, axes=[\"field\", \"antenna\"])\n", + "job[\"function\"] = set_values\n", + "\n", + "graph.source(job=job, axes=[\"field\", \"spw\", \"polarization\"], connect=True, node=\"check_values\")\n", + "graph.sink(function=gather)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64dcb301-9551-436c-90ef-9ef5d105906b", + "metadata": {}, + "outputs": [], + "source": [ + "graph.visualize()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb391466-e2ca-43f3-8872-c1896eadd823", + "metadata": {}, + "outputs": [], + "source": [ + "graph.compute()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f92ea73-85a5-44f0-9d4e-b96fbc74faa4", + "metadata": {}, + "outputs": [], + "source": [ + "graph.nodes" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/graphviper-logger-formatting-example.ipynb b/docs/toolviper-logger-formatting-example.ipynb similarity index 100% rename from docs/graphviper-logger-formatting-example.ipynb rename to docs/toolviper-logger-formatting-example.ipynb From bb7f9edd883bc159bee22f01ea59fd71d1e8be8e Mon Sep 17 00:00:00 2001 From: jrhosk Date: Mon, 9 Mar 2026 10:38:57 +0900 Subject: [PATCH 19/24] Add xarray. --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 619b2ce..158e8a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,5 +82,6 @@ all = [ 'sphinx-autosummary-accessors', 'sphinx_rtd_theme', 'twine', - 'pandoc' + 'pandoc', + 'xarray' ] From 1f465d4d5704a9c71f7518f8d5c3cc708a363053 Mon Sep 17 00:00:00 2001 From: jrhosk Date: Mon, 9 Mar 2026 10:46:23 +0900 Subject: [PATCH 20/24] Stupid notebook tests. --- docs/add_entry.ipynb | 126 ------------------------------------------- 1 file changed, 126 deletions(-) diff --git a/docs/add_entry.ipynb b/docs/add_entry.ipynb index 24b23e2..9fcf689 100644 --- a/docs/add_entry.ipynb +++ b/docs/add_entry.ipynb @@ -53,132 +53,6 @@ "entries.append(entry)" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "571ff4a0-252a-4086-a203-a9d199e20aea", - "metadata": {}, - "outputs": [], - "source": [ - "entry = {\n", - " \"file\": str(path.joinpath(\"upload/casa_to_xds_true.zarr.zip\")),\n", - " \"path\": \"radps/image\",\n", - " \"dtype\": \"CASA image\",\n", - " \"telescope\": \"VLA\",\n", - " \"mode\": \"Simulated\"\n", - "}\n", - "\n", - "entries.append(entry)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d6a2a9e7-b2bd-41ea-8508-7e9a332042eb", - "metadata": {}, - "outputs": [], - "source": [ - "entry = {\n", - " \"file\": str(path.joinpath(\"upload/empty_sky_image_true.zarr.zip\")),\n", - " \"path\": \"radps/image\",\n", - " \"dtype\": \"CASA image\",\n", - " \"telescope\": \"VLA\",\n", - " \"mode\": \"Simulated\"\n", - "}\n", - "\n", - "entries.append(entry)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "faa111cf-3406-49b9-8980-40fa484ab5d5", - "metadata": {}, - "outputs": [], - "source": [ - "entry = {\n", - " \"file\": str(path.joinpath(\"upload/empty_sky_image_no_sky_coords_true.zarr.zip\")),\n", - " \"path\": \"radps/image\",\n", - " \"dtype\": \"CASA image\",\n", - " \"telescope\": \"VLA\",\n", - " \"mode\": \"Simulated\"\n", - "}\n", - "\n", - "entries.append(entry)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ae6583fb-e078-4758-a32a-08bbb099c874", - "metadata": {}, - "outputs": [], - "source": [ - "entry = {\n", - " \"file\": str(path.joinpath(\"upload/empty_aperture_image_true.zarr.zip\")),\n", - " \"path\": \"radps/image\",\n", - " \"dtype\": \"CASA image\",\n", - " \"telescope\": \"VLA\",\n", - " \"mode\": \"Simulated\"\n", - "}\n", - "\n", - "entries.append(entry)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cfb61527-d651-4423-9334-5b8847010374", - "metadata": {}, - "outputs": [], - "source": [ - "entry = {\n", - " \"file\": str(path.joinpath(\"upload/empty_lmuv_image_true.zarr.zip\")),\n", - " \"path\": \"radps/image\",\n", - " \"dtype\": \"CASA image\",\n", - " \"telescope\": \"VLA\",\n", - " \"mode\": \"Simulated\"\n", - "}\n", - "\n", - "entries.append(entry)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "adf0ba7d-81e0-4847-96c5-61cbb1f92e52", - "metadata": {}, - "outputs": [], - "source": [ - "entry = {\n", - " \"file\": str(path.joinpath(\"upload/empty_lmuv_image_no_sky_coords_true.zarr.zip\")),\n", - " \"path\": \"radps/image\",\n", - " \"dtype\": \"CASA image\",\n", - " \"telescope\": \"VLA\",\n", - " \"mode\": \"Simulated\"\n", - "}\n", - "\n", - "entries.append(entry)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b6d75323-8c4b-43ab-ac47-091e8af0eeee", - "metadata": {}, - "outputs": [], - "source": [ - "entry = {\n", - " \"file\": str(path.joinpath(\"upload/casa_uv_true.zarr.zip\")),\n", - " \"path\": \"radps/image\",\n", - " \"dtype\": \"CASA image\",\n", - " \"telescope\": \"VLA\",\n", - " \"mode\": \"Simulated\"\n", - "}\n", - "\n", - "entries.append(entry)" - ] - }, { "cell_type": "code", "execution_count": null, From b9f85342b7d5e3f72d9bbc52df9a98a0c6a808c5 Mon Sep 17 00:00:00 2001 From: jrhosk Date: Mon, 9 Mar 2026 11:21:31 +0900 Subject: [PATCH 21/24] Stupid notebook tests, part-II --- docs/add_entry.ipynb => add_entry.ipynb | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename docs/add_entry.ipynb => add_entry.ipynb (100%) diff --git a/docs/add_entry.ipynb b/add_entry.ipynb similarity index 100% rename from docs/add_entry.ipynb rename to add_entry.ipynb From 52de9a8cff9b1e480b451d5a34546a1e9f3d65dc Mon Sep 17 00:00:00 2001 From: jrhosk Date: Tue, 10 Mar 2026 16:19:25 +0900 Subject: [PATCH 22/24] Updates~ --- docs/api.rst | 5 +- docs/{Example => example}/ascii/snek.txt | 0 .../config/viper.param.json | 0 docs/{Example => example}/viper.py | 0 docs/hsd_imaging_skeleton.ipynb | 973 ++++++++++++++++-- docs/index.rst | 8 +- pyproject.toml | 2 + src/toolviper/dask/plugins/scheduler.py | 2 +- src/toolviper/utils/app.py | 40 + src/toolviper/utils/css/button.tcss | 12 + src/toolviper/utils/sd/graph.py | 20 +- src/toolviper/utils/sd/prototype.py | 31 +- 12 files changed, 1007 insertions(+), 86 deletions(-) rename docs/{Example => example}/ascii/snek.txt (100%) rename docs/{Example => example}/config/viper.param.json (100%) rename docs/{Example => example}/viper.py (100%) create mode 100644 src/toolviper/utils/app.py create mode 100644 src/toolviper/utils/css/button.tcss diff --git a/docs/api.rst b/docs/api.rst index 06b8cb1..079c611 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -4,7 +4,4 @@ API .. toctree:: :maxdepth: 2 - _api/autoapi/toolviper/dask/client/index - _api/autoapi/toolviper/graph_tools/coordinate_utils/index - _api/autoapi/toolviper/graph_tools/map/index - _api/autoapi/toolviper/graph_tools/reduce/index \ No newline at end of file + _api/autoapi/toolviper/dask/client/index \ No newline at end of file diff --git a/docs/Example/ascii/snek.txt b/docs/example/ascii/snek.txt similarity index 100% rename from docs/Example/ascii/snek.txt rename to docs/example/ascii/snek.txt diff --git a/docs/Example/config/viper.param.json b/docs/example/config/viper.param.json similarity index 100% rename from docs/Example/config/viper.param.json rename to docs/example/config/viper.param.json diff --git a/docs/Example/viper.py b/docs/example/viper.py similarity index 100% rename from docs/Example/viper.py rename to docs/example/viper.py diff --git a/docs/hsd_imaging_skeleton.ipynb b/docs/hsd_imaging_skeleton.ipynb index 6335f7a..680f28d 100644 --- a/docs/hsd_imaging_skeleton.ipynb +++ b/docs/hsd_imaging_skeleton.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "01a56c62-2896-4857-8af6-e70539e8b2eb", "metadata": {}, "outputs": [], @@ -11,6 +11,7 @@ "import time\n", "import toolviper\n", "import random\n", + "import webbrowser\n", "\n", "import toolviper.utils.logger as logger\n", "import toolviper.utils.display as display\n", @@ -22,10 +23,40 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "b14a64bf-3470-45da-b3f2-f44f59eb26f1", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[\u001b[38;2;128;05;128m2026-03-10 14:28:45,835\u001b[0m] \u001b[38;2;255;160;0m WARNING\u001b[0m\u001b[38;2;112;128;144m client: \u001b[0m It is recommended that the local cache directory be set using the \u001b[38;2;50;50;205mdask_local_dir\u001b[0m parameter. \n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:46,501\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m client: \u001b[0m Loading plugin module: \n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:46,973\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m worker_9: \u001b[0m Logger created on worker Worker-66fa468d-877c-4fb5-b075-43dabff8712a,*,tcp://127.0.0.1:61787\n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:46,995\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m worker_5: \u001b[0m Logger created on worker Worker-838298e3-ea06-416c-b2a5-c16badef1739,*,tcp://127.0.0.1:61773\n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:47,005\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m worker_0: \u001b[0m Logger created on worker Worker-ed10022a-2dbc-4241-a528-5814ec2428a2,*,tcp://127.0.0.1:61770\n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:47,009\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m worker_2: \u001b[0m Logger created on worker Worker-16591390-0985-4a28-8d51-7b98cd74d451,*,tcp://127.0.0.1:61778\n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:47,010\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m worker_1: \u001b[0m Logger created on worker Worker-a02643d0-e433-43b4-8d9b-36c78c514c22,*,tcp://127.0.0.1:61765\n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:47,015\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m worker_3: \u001b[0m Logger created on worker Worker-891fb5de-ca05-4c17-a11a-61be28e34270,*,tcp://127.0.0.1:61768\n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:47,019\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m worker_8: \u001b[0m Logger created on worker Worker-0693380a-b14b-41c4-8db6-50ace56af341,*,tcp://127.0.0.1:61779\n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:47,022\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m worker_6: \u001b[0m Logger created on worker Worker-8db82b11-0870-4414-9cd2-fcf8f4e657d5,*,tcp://127.0.0.1:61786\n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:47,027\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m worker_4: \u001b[0m Logger created on worker Worker-8e911e2d-6ba0-4684-b624-727ad3a771b8,*,tcp://127.0.0.1:61769\n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:47,037\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m worker_7: \u001b[0m Logger created on worker Worker-e72f454a-9bea-4f38-b75b-be4fea40277f,*,tcp://127.0.0.1:61792\n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:47,037\u001b[0m] \u001b[38;2;50;50;205m INFO\u001b[0m\u001b[38;2;112;128;144m client: \u001b[0m Client \n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "client = client.local_client(\n", " cores=10,\n", @@ -43,23 +74,692 @@ "\n", "# Spawn dashboard window in a seperate tab,\n", "# comment out if you don't want this to spawn.\n", - "# webbrowser.open(url=client.dashboard_link)" + "webbrowser.open(url=client.dashboard_link)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "9273c405-6f67-46a9-9851-23b7778ce877", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset> Size: 864B\n",
+       "Dimensions:       (field: 1, spw: 5, polarization: 2, antenna: 4, row: 2)\n",
+       "Coordinates:\n",
+       "  * field         (field) int64 8B 0\n",
+       "  * spw           (spw) int64 40B 0 1 2 3 4\n",
+       "  * polarization  (polarization) <U2 16B 'XX' 'YY'\n",
+       "  * antenna       (antenna) <U9 144B 'antenna_0' 'antenna_1' ... 'antenna_3'\n",
+       "  * row           (row) int64 16B 0 1\n",
+       "Data variables:\n",
+       "    DATA          (field, spw, polarization, antenna, row) float64 640B 0.0 ....
" + ], + "text/plain": [ + " Size: 864B\n", + "Dimensions: (field: 1, spw: 5, polarization: 2, antenna: 4, row: 2)\n", + "Coordinates:\n", + " * field (field) int64 8B 0\n", + " * spw (spw) int64 40B 0 1 2 3 4\n", + " * polarization (polarization) " + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "graph.visualize()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "id": "eb391466-e2ca-43f3-8872-c1896eadd823", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "([None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None],)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "graph.compute()" ] @@ -209,6 +1036,14 @@ "source": [ "graph.nodes" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16d71a8d-f4d9-4c21-8e9e-55a202d85e52", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/docs/index.rst b/docs/index.rst index efa0e35..13cc99a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,11 +1,10 @@ -Graph Visibility and Image Parallel Execution Reduction +Tools for Visibility and Image Parallel Execution Reduction ======================================================= -toolviper is a `Dask `_ based MapReduce package. It allows for mapping a dictionary of `xarray.Datasets `_ to `Dask graph nodes `_ followed by a reduce step. +toolviper is a `Dask `_ based set of tools that can be used either with or independently with the VIPER framework. -**toolviper is in development and breaking API changes will happen.** +toolviper **is in development and breaking API changes will happen.** -The best place to start with toolviper is doing the `graph building tutorial `_ . `GitHub repository link `_ @@ -14,4 +13,3 @@ The best place to start with toolviper is doing the `graph building tutorial ComposeResult: + yield Header() + yield Footer() + yield VerticalScroll( + DirectoryTreeApp(), + ExitButton(), + ) + + def action_toggle_dark(self) -> None: + self.theme = ( + "textual-dark" if self.theme == "textual-light" else "textual-light" + ) + + +def app_test(): + app = UploadApp() + print(app.run()) + + +class DirectoryTreeApp(VerticalGroup): + def compose(self) -> ComposeResult: + yield DirectoryTree("./") + + +class ExitButton(VerticalGroup): + # CSS_PATH = "css/button.tcss" + + def compose(self) -> ComposeResult: + yield Button("Exit", variant="primary") + + def on_button_pressed(self, event: Button.Pressed) -> None: + self.app.exit() diff --git a/src/toolviper/utils/css/button.tcss b/src/toolviper/utils/css/button.tcss new file mode 100644 index 0000000..704f20e --- /dev/null +++ b/src/toolviper/utils/css/button.tcss @@ -0,0 +1,12 @@ +Button { + margin: 1 2; +} + +Horizontal > VerticalScroll { + width: 24; +} + +.header { + margin: 1 0 0 2; + text-style: bold; +} \ No newline at end of file diff --git a/src/toolviper/utils/sd/graph.py b/src/toolviper/utils/sd/graph.py index f1b7865..5bdddd9 100644 --- a/src/toolviper/utils/sd/graph.py +++ b/src/toolviper/utils/sd/graph.py @@ -12,27 +12,39 @@ def __init__(self): self._graph = None self._results = collections.defaultdict(list) - def source(self, job, axes, connect=False, node=None): + def source(self, job, axes, connect=False, type="", node=None): function_name = job["function"].__name__ previous = None + logger.info(f"Adding sink node for function: {function_name}") if connect: previous = self._graph + logger.info(f"Connecting to previous node: {previous}") if node is not None: try: + logger.info(f"Connecting to user-supplied node: {node}") previous = self._results[node] except KeyError: logger.error(f"Node {node} not found in results.") - self._graph = toolviper.utils.sd.distribute( - job=job, axes=axes, function=job["function"], previous=previous - ) + logger.info(f"Distributing function: {function_name} on axes: {axes}") + if type == "tree": + for _previous in previous: + self._graph = toolviper.utils.sd.distribute( + job=job, axes=axes, function=job["function"], previous=_previous + ) + else: + self._graph = toolviper.utils.sd.distribute( + job=job, axes=axes, function=job["function"], previous=previous + ) self._results[function_name].append(self._graph) def sink(self, function, edges=None): + logger.info(f"Adding sink node for function: {function.__name__}") + self._results[function.__name__].append(self._graph) self._graph = dask.delayed(function)(self._graph) def visualize(self): diff --git a/src/toolviper/utils/sd/prototype.py b/src/toolviper/utils/sd/prototype.py index 46c598b..22f5479 100644 --- a/src/toolviper/utils/sd/prototype.py +++ b/src/toolviper/utils/sd/prototype.py @@ -9,8 +9,8 @@ # Build a simple Dask dataset based on a given set of axes def simulate(field, spw, polarization, antenna, row): data_shape = { - "field": [f"field_{i}" for i in range(field)], - "spw": [f"spw_{i}" for i in range(spw)], + "field": [i for i in range(field)], + "spw": [i for i in range(spw)], "polarization": polarization, "antenna": [f"antenna_{i}" for i in range(antenna)], "row": [i for i in range(row)], @@ -57,9 +57,12 @@ def distribute( """ # Get the coordinate values for each axis axis_values = [job["dataset"].coords[axis].values for axis in axes] + # arguments = {axis: job["dataset"].coords[axis].values for axis in axes} + # inputs = [dict(zip(axes, combo)) for combo in itertools.product(*arguments.values())] if isinstance(previous, list): axis_values.append(previous) + # inputs.append(previous) # Create a delayed version of the function delayed_func = dask.delayed(function) @@ -67,7 +70,29 @@ def distribute( # Use itertools.product to generate all combinations of axis values # and create a delayed task for each combination. # The axis values are passed as positional arguments after 'previous'. + return [ delayed_func(*values) if previous is not None else delayed_func(*values) - for values in itertools.product(*axis_values) + for values in itertools.product(*axis_values) # previously axis_values ] + # output = [] + # for values in inputs: + # print(values) + # if isinstance(values, dict) and previous is not None: + # print("===== dict") + # output.append(delayed_func(**values)) + + # elif isinstance(values, list): + # print("===== list") + # output.append(delayed_func(*values)) + + # else: + # print("===== idk") + # output.append(delayed_func(values)) + + # return output + + # return [ + # delayed_func(**values) if previous is not None else delayed_func(**values) + # for values in inputs # previously axis_values + # ] From 27c2136be7ffbc1e20ef6e056d8c78cb6a1432a6 Mon Sep 17 00:00:00 2001 From: jrhosk Date: Tue, 10 Mar 2026 16:29:40 +0900 Subject: [PATCH 23/24] Updates~ II --- docs/file-manifest-update.ipynb | 57 +++++++++++++++++---------------- src/toolviper/utils/__init__.py | 1 + 2 files changed, 30 insertions(+), 28 deletions(-) diff --git a/docs/file-manifest-update.ipynb b/docs/file-manifest-update.ipynb index 9f52031..4b16a4e 100644 --- a/docs/file-manifest-update.ipynb +++ b/docs/file-manifest-update.ipynb @@ -2,21 +2,19 @@ "cells": [ { "cell_type": "code", - "execution_count": null, "id": "33d17704-1b0c-40dc-929c-75338f83c3c2", "metadata": {}, - "outputs": [], "source": [ "import toolviper\n", "import pathlib" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "7b0ca451-df41-4d13-ac4c-3e2b7601eba4", "metadata": {}, - "outputs": [], "source": [ "def make_random_file(file):\n", " import random\n", @@ -27,56 +25,57 @@ " handle.write(random.randbytes(1024))\n", "\n", " subprocess.run([\"zip\", \"-r\", f\"{file}.zip\", file])" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "937a9e64-cc8a-4283-a3d8-1f5118592d52", "metadata": {}, - "outputs": [], "source": [ "path = pathlib.Path().cwd()\n", "\n", "toolviper.utils.data.update(path=str(path))" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "8589d72e-c594-400a-8660-f54336b9c4d6", "metadata": {}, - "outputs": [], "source": [ - "_json = toolviper.utils.tools.open_json(\"file.download.json\")" - ] + "_json = toolviper.utils.tools.open_json(\"file.download.json\")\n", + "toolviper.utils.display.DataDict.html(_json[\"metadata\"])" + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "8e66f912-d356-486b-8992-6e3ce62ad672", "metadata": {}, - "outputs": [], "source": [ "make_random_file(file=\"single-dish.ultra.calibrated.ms\")" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "05c07272-88ee-424a-9219-78e3ea52a0d7", "metadata": {}, - "outputs": [], "source": [ "make_random_file(file=\"alma.mega.uncalibrated.ms\")" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "7fd8b0b1-302b-454d-921b-9929af36e44e", "metadata": {}, - "outputs": [], "source": [ "entries = []\n", "\n", @@ -89,14 +88,14 @@ "}\n", "\n", "entries.append(entry)" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "d2bc0fd4-64c3-42e4-9422-9c6dd69a9c32", "metadata": {}, - "outputs": [], "source": [ "entry = {\n", " \"file\": str(path.joinpath(\"alma.mega.uncalibrated.ms.zip\")),\n", @@ -107,21 +106,23 @@ "}\n", "\n", "entries.append(entry)" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "0e98fb83-eb71-408d-b476-1d821d14b7f0", "metadata": {}, - "outputs": [], "source": [ "_ = toolviper.utils.tools.add_entry(\n", " entries=entries,\n", " manifest=str(path.joinpath(\"file.download.json\")),\n", " versioning=\"patch\"\n", ")" - ] + ], + "outputs": [], + "execution_count": null } ], "metadata": { diff --git a/src/toolviper/utils/__init__.py b/src/toolviper/utils/__init__.py index 9b160b8..3b5bad1 100644 --- a/src/toolviper/utils/__init__.py +++ b/src/toolviper/utils/__init__.py @@ -5,6 +5,7 @@ from .tools import open_json, calculate_checksum, verify, add_entry from .profile import memory_usage, cpu_usage from .sd import prototype +from .display import DataDict from .data import download From e60bd81902d6d3379502988bf0e9ecddf9a54eb9 Mon Sep 17 00:00:00 2001 From: jrhosk Date: Tue, 10 Mar 2026 20:15:58 +0900 Subject: [PATCH 24/24] Updates~ III --- src/toolviper/utils/display.py | 1 - src/toolviper/utils/parameter.py | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/toolviper/utils/display.py b/src/toolviper/utils/display.py index 621b89a..ba50f90 100755 --- a/src/toolviper/utils/display.py +++ b/src/toolviper/utils/display.py @@ -1,6 +1,5 @@ import re import operator -import sys from typing import Dict from IPython.core.display import HTML diff --git a/src/toolviper/utils/parameter.py b/src/toolviper/utils/parameter.py index 3a2899e..a5cc0b4 100644 --- a/src/toolviper/utils/parameter.py +++ b/src/toolviper/utils/parameter.py @@ -1,13 +1,12 @@ import functools import glob -import importlib import inspect import json import os import pathlib import pkgutil from types import ModuleType -from typing import Any, Callable, Dict, List, NoReturn, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Union import toolviper.utils.console as console import toolviper.utils.logger @@ -43,7 +42,7 @@ def wrapper(*args, **kwargs): meta_data["function"] = function.__name__ meta_data["module"] = function.__module__ - # If this is a class method, drop the self entry. + # If this is a class method, drop the self-entry. if "self" in list(arguments.keys()): class_name = args[0].__class__.__name__ meta_data["function"] = ".".join((class_name, function.__name__))