diff --git a/docs/add_entry.ipynb b/add_entry.ipynb similarity index 79% rename from docs/add_entry.ipynb rename to add_entry.ipynb index 99e0814..9fcf689 100644 --- a/docs/add_entry.ipynb +++ b/add_entry.ipynb @@ -23,6 +23,16 @@ "toolviper.utils.data.update(path=str(path))" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "0dda04ad-1ecc-4925-92bb-9a608f81d7e1", + "metadata": {}, + "outputs": [], + "source": [ + "toolviper.utils.data.get_file_size(str(path))" + ] + }, { "cell_type": "code", "execution_count": null, @@ -33,11 +43,11 @@ "entries = []\n", "\n", "entry = {\n", - " \"file\": str(path.joinpath(\"ngc5921-lsrk-cube.psf.zip\")),\n", + " \"file\": str(path.joinpath(\"upload/casa_no_sky_to_xds_true.zarr.zip\")),\n", " \"path\": \"radps/image\",\n", " \"dtype\": \"CASA image\",\n", " \"telescope\": \"VLA\",\n", - " \"mode\": \"Interferometric\"\n", + " \"mode\": \"Simulated\"\n", "}\n", "\n", "entries.append(entry)" @@ -56,26 +66,6 @@ " versioning=\"patch\"\n", ")" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a113bb2b-fed4-4ede-8d8b-353958f59602", - "metadata": {}, - "outputs": [], - "source": [ - "#toolviper.utils.data.update()\n", - "\n", - "#toolviper.utils.data.download(file=\"ngc5921-lsrk-cube.psf\", folder=\"test\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7df91b02-3e25-4989-a76e-20ca4b7a9cb2", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/docs/api.rst b/docs/api.rst index 06b8cb1..079c611 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -4,7 +4,4 @@ API .. toctree:: :maxdepth: 2 - _api/autoapi/toolviper/dask/client/index - _api/autoapi/toolviper/graph_tools/coordinate_utils/index - _api/autoapi/toolviper/graph_tools/map/index - _api/autoapi/toolviper/graph_tools/reduce/index \ No newline at end of file + _api/autoapi/toolviper/dask/client/index \ No newline at end of file diff --git a/docs/client_tutorial.ipynb b/docs/client_tutorial.ipynb index c1b179d..e528e0b 100644 --- a/docs/client_tutorial.ipynb +++ b/docs/client_tutorial.ipynb @@ -40,7 +40,7 @@ }, "outputs": [], "source": [ - "toolviper.utils.data.download(file=\"AA2-Mid-sim_00000.ms\")" + "toolviper.utils.data.download(file=\"AA2-Mid-sim_00000.ms\", folder=\"data\")" ] }, { @@ -197,7 +197,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.13" + "version": "3.12.12" } }, "nbformat": 4, diff --git a/docs/download_example.ipynb b/docs/download_example.ipynb index 895f169..c08ca49 100644 --- a/docs/download_example.ipynb +++ b/docs/download_example.ipynb @@ -10,15 +10,6 @@ "import toolviper" ] }, - { - "cell_type": "markdown", - "id": "01bc8397-af47-48df-8391-733bba286588", - "metadata": {}, - "source": [ - "## Getting download metadata version\n", - "#### This will retireve the current file download metadata version; if the file is not found it will attempt to update to the most recent version." - ] - }, { "cell_type": "code", "execution_count": null, @@ -31,50 +22,58 @@ }, { "cell_type": "markdown", - "id": "fe18a0d9-bd21-4ce5-91fd-6348e5f1369e", + "id": "0351dda7-7559-449d-9ebb-c853beb0861e", "metadata": {}, "source": [ - "### Manually update metdata info." + "## Getting available downloadable file in a python list.\n", + "#### This will return an unordered list of the available file on the remote dropbox in a python list. This can be used as an input to thie download function as well." ] }, { "cell_type": "code", "execution_count": null, - "id": "deeee662-94d9-44a0-95e7-f062f55223ca", + "id": "2912a5ae-7a4b-42f3-a633-cbe5ecaffc74", "metadata": {}, "outputs": [], "source": [ - "toolviper.utils.data.update()" + "files = toolviper.utils.data.get_files()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28c8b343-3a05-4217-a581-e51c7db95d02", + "metadata": {}, + "outputs": [], + "source": [ + "toolviper.utils.data.download(file=files[3:8], folder=\"data\")" ] }, { "cell_type": "markdown", - "id": "0351dda7-7559-449d-9ebb-c853beb0861e", + "id": "01bc8397-af47-48df-8391-733bba286588", "metadata": {}, "source": [ - "## Getting available downloadable file in a python list.\n", - "#### This will return an unordered list of the available file on the remote dropbox in a python list. This can be used as an input to thie download function as well." + "## Getting download metadata version\n", + "#### This will retireve the current file download metadata version; if the file is not found it will attempt to update to the most recent version." ] }, { - "cell_type": "code", - "execution_count": null, - "id": "2912a5ae-7a4b-42f3-a633-cbe5ecaffc74", + "cell_type": "markdown", + "id": "fe18a0d9-bd21-4ce5-91fd-6348e5f1369e", "metadata": {}, - "outputs": [], "source": [ - "files = toolviper.utils.data.get_files()\n", - "files" + "### Manually update metdata info." ] }, { "cell_type": "code", "execution_count": null, - "id": "28c8b343-3a05-4217-a581-e51c7db95d02", + "id": "d521c952-7f30-4454-8548-de5ac5b98d82", "metadata": {}, "outputs": [], "source": [ - "toolviper.utils.data.download(file=files[6:8], folder=\"data\")" + "toolviper.utils.data.update()" ] }, { diff --git a/docs/Example/ascii/snek.txt b/docs/example/ascii/snek.txt similarity index 100% rename from docs/Example/ascii/snek.txt rename to docs/example/ascii/snek.txt diff --git a/docs/Example/config/viper.param.json b/docs/example/config/viper.param.json similarity index 100% rename from docs/Example/config/viper.param.json rename to docs/example/config/viper.param.json diff --git a/docs/Example/viper.py b/docs/example/viper.py similarity index 100% rename from docs/Example/viper.py rename to docs/example/viper.py diff --git a/docs/file-manifest-update.ipynb b/docs/file-manifest-update.ipynb index 8b40c6d..4b16a4e 100644 --- a/docs/file-manifest-update.ipynb +++ b/docs/file-manifest-update.ipynb @@ -2,21 +2,19 @@ "cells": [ { "cell_type": "code", - "execution_count": null, "id": "33d17704-1b0c-40dc-929c-75338f83c3c2", "metadata": {}, - "outputs": [], "source": [ "import toolviper\n", "import pathlib" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "7b0ca451-df41-4d13-ac4c-3e2b7601eba4", "metadata": {}, - "outputs": [], "source": [ "def make_random_file(file):\n", " import random\n", @@ -27,46 +25,57 @@ " handle.write(random.randbytes(1024))\n", "\n", " subprocess.run([\"zip\", \"-r\", f\"{file}.zip\", file])" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "937a9e64-cc8a-4283-a3d8-1f5118592d52", "metadata": {}, - "outputs": [], "source": [ "path = pathlib.Path().cwd()\n", "\n", "toolviper.utils.data.update(path=str(path))" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, - "id": "8e66f912-d356-486b-8992-6e3ce62ad672", + "id": "8589d72e-c594-400a-8660-f54336b9c4d6", "metadata": {}, + "source": [ + "_json = toolviper.utils.tools.open_json(\"file.download.json\")\n", + "toolviper.utils.display.DataDict.html(_json[\"metadata\"])" + ], "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "8e66f912-d356-486b-8992-6e3ce62ad672", + "metadata": {}, "source": [ "make_random_file(file=\"single-dish.ultra.calibrated.ms\")" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "05c07272-88ee-424a-9219-78e3ea52a0d7", "metadata": {}, - "outputs": [], "source": [ "make_random_file(file=\"alma.mega.uncalibrated.ms\")" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "7fd8b0b1-302b-454d-921b-9929af36e44e", "metadata": {}, - "outputs": [], "source": [ "entries = []\n", "\n", @@ -79,14 +88,14 @@ "}\n", "\n", "entries.append(entry)" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "d2bc0fd4-64c3-42e4-9422-9c6dd69a9c32", "metadata": {}, - "outputs": [], "source": [ "entry = {\n", " \"file\": str(path.joinpath(\"alma.mega.uncalibrated.ms.zip\")),\n", @@ -97,21 +106,23 @@ "}\n", "\n", "entries.append(entry)" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "0e98fb83-eb71-408d-b476-1d821d14b7f0", "metadata": {}, - "outputs": [], "source": [ "_ = toolviper.utils.tools.add_entry(\n", " entries=entries,\n", " manifest=str(path.joinpath(\"file.download.json\")),\n", " versioning=\"patch\"\n", ")" - ] + ], + "outputs": [], + "execution_count": null } ], "metadata": { diff --git a/docs/hsd_imaging_skeleton.ipynb b/docs/hsd_imaging_skeleton.ipynb new file mode 100644 index 0000000..680f28d --- /dev/null +++ b/docs/hsd_imaging_skeleton.ipynb @@ -0,0 +1,1070 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "01a56c62-2896-4857-8af6-e70539e8b2eb", + "metadata": {}, + "outputs": [], + "source": [ + "import dask\n", + "import time\n", + "import toolviper\n", + "import random\n", + "import webbrowser\n", + "\n", + "import toolviper.utils.logger as logger\n", + "import toolviper.utils.display as display\n", + "\n", + "import toolviper.dask.client as client\n", + "\n", + "from collections import defaultdict" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "b14a64bf-3470-45da-b3f2-f44f59eb26f1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[\u001b[38;2;128;05;128m2026-03-10 14:28:45,835\u001b[0m] \u001b[38;2;255;160;0m WARNING\u001b[0m\u001b[38;2;112;128;144m client: \u001b[0m It is recommended that the local cache directory be set using the \u001b[38;2;50;50;205mdask_local_dir\u001b[0m parameter. \n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:46,501\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m client: \u001b[0m Loading plugin module: \n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:46,973\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m worker_9: \u001b[0m Logger created on worker Worker-66fa468d-877c-4fb5-b075-43dabff8712a,*,tcp://127.0.0.1:61787\n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:46,995\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m worker_5: \u001b[0m Logger created on worker Worker-838298e3-ea06-416c-b2a5-c16badef1739,*,tcp://127.0.0.1:61773\n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:47,005\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m worker_0: \u001b[0m Logger created on worker Worker-ed10022a-2dbc-4241-a528-5814ec2428a2,*,tcp://127.0.0.1:61770\n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:47,009\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m worker_2: \u001b[0m Logger created on worker Worker-16591390-0985-4a28-8d51-7b98cd74d451,*,tcp://127.0.0.1:61778\n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:47,010\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m worker_1: \u001b[0m Logger created on worker Worker-a02643d0-e433-43b4-8d9b-36c78c514c22,*,tcp://127.0.0.1:61765\n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:47,015\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m worker_3: \u001b[0m Logger created on worker Worker-891fb5de-ca05-4c17-a11a-61be28e34270,*,tcp://127.0.0.1:61768\n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:47,019\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m worker_8: \u001b[0m Logger created on worker Worker-0693380a-b14b-41c4-8db6-50ace56af341,*,tcp://127.0.0.1:61779\n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:47,022\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m worker_6: \u001b[0m Logger created on worker Worker-8db82b11-0870-4414-9cd2-fcf8f4e657d5,*,tcp://127.0.0.1:61786\n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:47,027\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m worker_4: \u001b[0m Logger created on worker Worker-8e911e2d-6ba0-4684-b624-727ad3a771b8,*,tcp://127.0.0.1:61769\n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:47,037\u001b[0m] \u001b[38;2;46;139;87m DEBUG\u001b[0m\u001b[38;2;112;128;144m worker_7: \u001b[0m Logger created on worker Worker-e72f454a-9bea-4f38-b75b-be4fea40277f,*,tcp://127.0.0.1:61792\n", + "[\u001b[38;2;128;05;128m2026-03-10 14:28:47,037\u001b[0m] \u001b[38;2;50;50;205m INFO\u001b[0m\u001b[38;2;112;128;144m client: \u001b[0m Client \n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "client = client.local_client(\n", + " cores=10,\n", + " log_params={\n", + " \"log_to_file\":False,\n", + " \"log_to_term\":True,\n", + " \"log_level\":\"DEBUG\" \n", + " },\n", + " worker_log_params={\n", + " \"log_to_file\":False,\n", + " \"log_to_term\":True,\n", + " \"log_level\":\"DEBUG\" \n", + " }\n", + ")\n", + "\n", + "# Spawn dashboard window in a seperate tab,\n", + "# comment out if you don't want this to spawn.\n", + "webbrowser.open(url=client.dashboard_link)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9273c405-6f67-46a9-9851-23b7778ce877", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset> Size: 864B\n",
+       "Dimensions:       (field: 1, spw: 5, polarization: 2, antenna: 4, row: 2)\n",
+       "Coordinates:\n",
+       "  * field         (field) int64 8B 0\n",
+       "  * spw           (spw) int64 40B 0 1 2 3 4\n",
+       "  * polarization  (polarization) <U2 16B 'XX' 'YY'\n",
+       "  * antenna       (antenna) <U9 144B 'antenna_0' 'antenna_1' ... 'antenna_3'\n",
+       "  * row           (row) int64 16B 0 1\n",
+       "Data variables:\n",
+       "    DATA          (field, spw, polarization, antenna, row) float64 640B 0.0 ....
" + ], + "text/plain": [ + " Size: 864B\n", + "Dimensions: (field: 1, spw: 5, polarization: 2, antenna: 4, row: 2)\n", + "Coordinates:\n", + " * field (field) int64 8B 0\n", + " * spw (spw) int64 40B 0 1 2 3 4\n", + " * polarization (polarization) " + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "graph.visualize()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "eb391466-e2ca-43f3-8872-c1896eadd823", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "([None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " None],)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "graph.compute()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f92ea73-85a5-44f0-9d4e-b96fbc74faa4", + "metadata": {}, + "outputs": [], + "source": [ + "graph.nodes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16d71a8d-f4d9-4c21-8e9e-55a202d85e52", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/index.rst b/docs/index.rst index efa0e35..13cc99a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,11 +1,10 @@ -Graph Visibility and Image Parallel Execution Reduction +Tools for Visibility and Image Parallel Execution Reduction ======================================================= -toolviper is a `Dask `_ based MapReduce package. It allows for mapping a dictionary of `xarray.Datasets `_ to `Dask graph nodes `_ followed by a reduce step. +toolviper is a `Dask `_ based set of tools that can be used either with or independently with the VIPER framework. -**toolviper is in development and breaking API changes will happen.** +toolviper **is in development and breaking API changes will happen.** -The best place to start with toolviper is doing the `graph building tutorial `_ . `GitHub repository link `_ @@ -14,4 +13,3 @@ The best place to start with toolviper is doing the `graph building tutorial \u001b[0m]: Here is an error message.\n", - "[\u001b[38;2;128;05;128m2025-04-10 12:39:45,512\u001b[0m] \u001b[7;38;2;220;60;20mCRITICAL\u001b[0m\u001b[38;2;112;128;144m viperlog: \u001b[0m [\u001b[7;38;2;220;60;20m\u001b[0m]: Here is an critical message.\n" - ] - } - ], + "outputs": [], "source": [ "logger.info(\"Here is an info message.\")\n", "logger.warning(\"Here is an warning message.\")\n", @@ -330,21 +255,10 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "id": "2d2b8d48", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[\u001b[38;2;128;05;128m2025-04-10 12:39:45,519\u001b[0m] \u001b[38;2;50;50;205m INFO\u001b[0m\u001b[38;2;112;128;144m viperlog: \u001b[0m [\u001b[38;2;50;50;205mverbose_log\u001b[0m]: Here's a info message. \n", - "[\u001b[38;2;128;05;128m2025-04-10 12:39:45,522\u001b[0m] \u001b[38;2;255;160;0m WARNING\u001b[0m\u001b[38;2;112;128;144m viperlog: \u001b[0m [\u001b[38;2;255;160;0mverbose_log\u001b[0m]: Here's a warning message. \n", - "[\u001b[38;2;128;05;128m2025-04-10 12:39:45,528\u001b[0m] \u001b[38;2;220;20;60m ERROR\u001b[0m\u001b[38;2;112;128;144m viperlog: \u001b[0m [\u001b[38;2;220;20;60mverbose_log\u001b[0m]: Here's a error message.\n", - "[\u001b[38;2;128;05;128m2025-04-10 12:39:45,532\u001b[0m] \u001b[7;38;2;220;60;20mCRITICAL\u001b[0m\u001b[38;2;112;128;144m viperlog: \u001b[0m [\u001b[7;38;2;220;60;20mverbose_log\u001b[0m]: Here's a critical message.\n" - ] - } - ], + "outputs": [], "source": [ "def verbose_log():\n", " logger.info(\"Here's a info message.\", verbose=True)\n", @@ -368,34 +282,14 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "id": "c9aa9c09-239b-4403-8d83-fb90c93c9782", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Setting verbosity to True\n", - "[\u001b[38;2;128;05;128m2025-04-10 12:39:45,539\u001b[0m] \u001b[38;2;50;50;205m INFO\u001b[0m\u001b[38;2;112;128;144m viperlog: \u001b[0m [\u001b[38;2;50;50;205mverbose_log\u001b[0m]: Here's a info message. \n", - "[\u001b[38;2;128;05;128m2025-04-10 12:39:45,542\u001b[0m] \u001b[38;2;255;160;0m WARNING\u001b[0m\u001b[38;2;112;128;144m viperlog: \u001b[0m [\u001b[38;2;255;160;0mverbose_log\u001b[0m]: Here's a warning message. \n", - "[\u001b[38;2;128;05;128m2025-04-10 12:39:45,548\u001b[0m] \u001b[38;2;220;20;60m ERROR\u001b[0m\u001b[38;2;112;128;144m viperlog: \u001b[0m [\u001b[38;2;220;20;60mverbose_log\u001b[0m]: Here's a error message.\n", - "[\u001b[38;2;128;05;128m2025-04-10 12:39:45,552\u001b[0m] \u001b[7;38;2;220;60;20mCRITICAL\u001b[0m\u001b[38;2;112;128;144m viperlog: \u001b[0m [\u001b[7;38;2;220;60;20mverbose_log\u001b[0m]: Here's a critical message.\n" - ] - } - ], + "outputs": [], "source": [ "logger.set_verbosity(state=logger.VERBOSE)\n", "verbose_log()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a14d9eda-29a3-4470-801f-dc36958ed00d", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -414,7 +308,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.2" + "version": "3.12.12" } }, "nbformat": 4, diff --git a/pyproject.toml b/pyproject.toml index 2362d37..249b0a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ 'pandas', 'itables', 'requests', + 'responses', 'tabulate', 'tqdm', ] @@ -41,6 +42,7 @@ interactive = [ 'ipympl', 'ipython', 'jupyter-client', + 'textual' ] docs = [ @@ -73,6 +75,7 @@ all = [ 'ipykernel', 'ipympl', 'ipython', + 'ipywidgets', 'jupyter-client', 'nbsphinx', 'recommonmark', @@ -81,5 +84,6 @@ all = [ 'sphinx-autosummary-accessors', 'sphinx_rtd_theme', 'twine', - 'pandoc' + 'pandoc', + 'xarray' ] diff --git a/src/toolviper/dask/client.py b/src/toolviper/dask/client.py index 9a8d7b4..c71fb4f 100644 --- a/src/toolviper/dask/client.py +++ b/src/toolviper/dask/client.py @@ -6,57 +6,79 @@ import dask_jobqueue import distributed import psutil -import inspect import functools from importlib import import_module from importlib.util import find_spec -from typing import Dict, Union +from typing import Dict, Union, Any, Optional import toolviper.dask.menrva + import toolviper.utils.console as console import toolviper.utils.logger as logger import toolviper.utils.parameter as parameter +import toolviper.utils.display as display colorize = console.Colorize() +DEFAULT_CLIENT_LOG_PARAMS = { + "logger_name": "client", + "log_to_term": True, + "log_level": "INFO", + "log_to_file": False, + "log_file": "client.log", +} + +DEFAULT_WORKER_LOG_PARAMS = { + "logger_name": "worker", + "log_to_term": True, + "log_level": "INFO", + "log_to_file": False, + "log_file": "client_worker.log", +} + + +def _get_log_params( + log_params: Optional[Dict[str, Any]], defaults: Dict[str, Any] +) -> Dict[str, Any]: + if log_params is None: + log_params = {} + return {**defaults, **log_params} + def load_libraries(name: str, libs: Union[str, list[str]]) -> dict[str, bool]: """Load libraries if they were installed and can be loaded. Parameters ---------- - name : library group name - A library group name based on a function of a distributed environment will be imported. + name : str + A library group name based on a function of a distributed environment. libs : Union[str, list[str]] - a library or a list of libraries to import + A library or a list of libraries to import. Returns ------- - an item of dict has the name and the flag whether all libraries were loaded successfully. + dict[str, bool] + A dictionary mapping the group name to a boolean indicating if all libraries were loaded successfully. """ def _load_library(_lib): if find_spec(_lib) is not None: import_module(_lib) - return [True, f" {colorize.blue(_lib)} is available"] - else: - return [False, f" {colorize.blue(_lib)} is unavailable"] - - if isinstance(libs, list): - _tmp = list(map(_load_library, libs)) - _avail = [all([x[0] for x in _tmp]), [x[1] for x in _tmp]] - elif isinstance(libs, str): - _tmp = _load_library(libs) - _avail = [_tmp[0], [_tmp[1]]] - else: - _avail = [False, " illegal module specification"] + return True, f" {colorize.blue(_lib)} is available" + return False, f" {colorize.blue(_lib)} is unavailable" - _result = "Success" if _avail[0] else "Fail" - logger.info(f"Loading module: {name} -- {_result}") - [logger.info(x) for x in _avail[1]] + if isinstance(libs, str): + libs = [libs] - return {name: _avail[0]} + results = [_load_library(lib) for lib in libs] + all_available = all(res[0] for res in results) + + logger.info(f"Loading module: {name} -- {'Success' if all_available else 'Fail'}") + for _, message in results: + logger.info(message) + + return {name: all_available} def print_libraries_availability(spec: dict[str, bool]): @@ -117,109 +139,53 @@ def get_cluster() -> Union[None, distributed.LocalCluster]: @parameter.validate() def local_client( - cores: int = None, - memory_limit: str = None, + cores: Optional[int] = None, + memory_limit: Optional[str] = None, autorestrictor: bool = False, - dask_local_dir: str = None, - local_dir: str = None, + dask_local_dir: Optional[str] = None, + local_dir: Optional[str] = None, wait_for_workers: bool = True, - log_params: Union[None, Dict] = None, - worker_log_params: Union[None, Dict] = None, + log_params: Optional[Dict[str, Any]] = None, + worker_log_params: Optional[Dict[str, Any]] = None, dashboard_address: str = ":8787", serial_execution: bool = False, -) -> Union[distributed.Client, None]: - """ Creates a local client, scheduler and workers using Dask Distributed LocalCluster (https://docs.dask.org/en/stable/deploying-python.html#reference) - with Dask configuration tuned for VIPER and the option to use autorestrictor plugin and local cache. +) -> Optional[distributed.Client]: + """Create a local client, scheduler and workers using Dask Distributed LocalCluster. + + With Dask configuration tuned for VIPER and the option to use autorestrictor plugin and local cache. + See https://docs.dask.org/en/stable/deploying-python.html#reference for more details. Parameters ---------- - cores : int - Number of cores in Dask cluster, defaults to None - memory_limit : str - Amount of memory per core. It is suggested to use '8GB', defaults to None - autorestrictor : bool - Boolean determining usage of autorestrictor plugin, defaults to False - dask_local_dir : str - Where Dask should store temporary files, defaults to None. If None Dask will use \ - `./dask-worker-space`, defaults to None - local_dir : str - Defines client local directory, defaults to None - - wait_for_workers : bool - Boolean determining usage of wait_for_workers option in dask, defaults to False - log_params : dict - The logger for the main process (code that does not run in parallel), defaults to {} - worker_log_params : dict - worker_log_params: Keys as same as log_params, default values given in `Additional \ - Information`_. - - dashboard_address: str - Address on which to listen for the Bokeh diagnostics server like ‘localhost:8787’ or ‘0.0.0.0:8787’. Defaults to ‘:8787’. - Set to None to disable the dashboard. Use ‘:0’ for a random port. See https://docs.dask.org/en/stable/deploying-python.html#reference for more information. - - serial_execution : bool - This is an option that forces dask to run in serial mode while also setting up the logger to work. This is - really only appropriate for debugging. - - .. _Description: - - ** _log_params ** - - The log_params (worker_log_params) dictionary stores initialization information for the logger and associated - workers. the following are the acceptable key: value pairs and their usage information. - - log_params["logger_name"] : str - Defines the logger name to use - log_params["log_to_term"] : bool - Should messages log to the terminal output. - log_params["log_level"] : str - Defines logging level, valid options: - - DEBUG - - INFO - - WARNING - - ERROR - - CRITICAL - - Only messages flagged as at the given level or below are logged. - - log_params["log_to_file"] : str - Should messages log to file. - - log_params["log_file"] : str - Name of log file to create. If none is given, the file name 'logger' will be used. + cores : int, optional + Number of cores in Dask cluster. Defaults to number of physical cores. + memory_limit : str, optional + Amount of memory per core. Suggested: '8GB'. Defaults to available memory divided by cores. + autorestrictor : bool, optional + Whether to use the autorestrictor plugin. Defaults to False. + dask_local_dir : str, optional + Temporary files directory for Dask. Defaults to None. + local_dir : str, optional + Client local directory. Defaults to None. + wait_for_workers : bool, optional + Whether to wait for workers to start. Defaults to True. + log_params : dict, optional + Logger configuration for the main process. + worker_log_params : dict, optional + Logger configuration for workers. + dashboard_address : str, optional + Address for the Bokeh diagnostics server (e.g., 'localhost:8787'). Defaults to ':8787'. + serial_execution : bool, optional + If True, runs Dask in serial mode (synchronous) for debugging. Defaults to False. Returns ------- - Dask Distributed Client + distributed.Client or None + Dask Distributed Client, or None if serial_execution is True. """ - if log_params is None: - log_params = {} - - log_params = { - **{ - "logger_name": "client", - "log_to_term": True, - "log_level": "INFO", - "log_to_file": False, - "log_file": "client.log", - }, - **log_params, - } - - if worker_log_params is None: - worker_log_params = {} - - worker_log_params = { - **{ - "logger_name": "worker", - "log_to_term": True, - "log_level": "INFO", - "log_to_file": False, - "log_file": "client_worker.log", - }, - **worker_log_params, - } + log_params = _get_log_params(log_params, DEFAULT_CLIENT_LOG_PARAMS) + worker_log_params = _get_log_params(worker_log_params, DEFAULT_WORKER_LOG_PARAMS) # If the user wants to change the global logger name from the # default value of toolviper @@ -310,11 +276,13 @@ def local_client( client = toolviper.dask.menrva.MenrvaClient(cluster) client.get_versions(check=True) - # When constructing a graph that has local cache enabled, all workers need to be up and running. + # When constructing a graph that has local-cache enabled, all workers need to be up and running. if local_cache or wait_for_workers: client.wait_for_workers(n_workers=cores) - logger.debug(f"These are the worker log parameters:\n {worker_log_params}") + # logger.debug(f"These are the worker log parameters:\n") + # logger.debug(f"{display.DataDict.from_dict(worker_log_params).display(interactive=False)}") + if local_cache or worker_log_params: client.load_plugin( directory=plugin_path, @@ -329,82 +297,34 @@ def local_client( return client +@parameter.validate() def distributed_client( - cluster: None, - dask_local_dir: str = None, - log_params: Union[None, Dict] = None, - worker_log_params: Union[None, Dict] = None, -) -> Union[distributed.Client, None]: - """ Setup dask cluster and logger. + cluster: Any, + dask_local_dir: Optional[str] = None, + log_params: Optional[Dict[str, Any]] = None, + worker_log_params: Optional[Dict[str, Any]] = None, +) -> distributed.Client: + """Setup dask cluster and logger. Parameters ---------- - cluster - log_params : dict - The logger for the main process (code that does not run in parallel), defaults to {} - worker_log_params : dict - worker_log_params: Keys as same as log_params, default values given in `Additional \ - Information`_. - - .. _Description: - - ** _log_params ** - - The log_params (worker_log_params) dictionary stores initialization information for the logger and associated - workers. the following are the acceptable key: value pairs and their usage information. - - log_params["logger_name"] : str - Defines the logger name to use - log_params["log_to_term"] : bool - Should messages log to the terminal output. - log_params["log_level"] : str - Defines logging level, valid options: - - DEBUG - - INFO - - WARNING - - ERROR - - CRITICAL - - Only messages flagged as at the given level or below are logged. - - log_params["log_to_file"] : str - Should messages log to file. - - log_params["log_filee"] : str - Name of log file to create. If none is given, the file name 'logger' will be used. + cluster : Any + An existing dask cluster instance. + dask_local_dir : str, optional + Where Dask should store temporary files. + log_params : dict, optional + The logger for the main process. + worker_log_params : dict, optional + The logger for the workers. Returns ------- + distributed.Client Dask Distributed Client """ - if log_params is None: - log_params = {} - - log_params = { - **{ - "logger_name": "client", - "log_to_term": True, - "log_level": "INFO", - "log_to_file": False, - "log_file": "client.log", - }, - **log_params, - } - - if worker_log_params is None: - worker_log_params = {} - - worker_log_params = { - **{ - "logger_name": "worker", - "log_to_term": True, - "log_level": "INFO", - "log_to_file": False, - "log_file": "client_worker.log", - }, - **worker_log_params, - } + log_params = _get_log_params(log_params, DEFAULT_CLIENT_LOG_PARAMS) + worker_log_params = _get_log_params(worker_log_params, DEFAULT_WORKER_LOG_PARAMS) # If the user wants to change the global logger name from the # default value of toolviper @@ -420,11 +340,6 @@ def distributed_client( _set_up_dask(dask_local_dir) - """ - load libraries related functions of a distributed environment - 'available_specs' contains the function name and a flag that the function was loaded successfully - """ - logger.debug(colorize.green("Checking functions availability:")) available_specs = { **load_libraries("slurm", "dask_jobqueue"), @@ -434,16 +349,13 @@ def distributed_client( print_libraries_availability(available_specs) - # This will work as long as the scheduler path isn't in some outside directory. Being that it is a plugin specific - # to this module, I think keeping it static in the module directory it good. - plugin_path = str(pathlib.Path(__file__).parent.resolve().joinpath("plugins/")) - client = toolviper.dask.menrva.MenrvaClient(cluster) client.get_versions(check=True) logger.info("Created client " + str(client)) return client +@parameter.validate() def slurm_cluster_client( workers_per_node: int, cores_per_node: int, @@ -456,111 +368,61 @@ def slurm_cluster_client( dask_log_dir: str, exclude_nodes: str = "", dashboard_port: int = 8787, - local_dir: str = None, + local_dir: Optional[str] = None, autorestrictor: bool = False, wait_for_workers: bool = True, - log_params: Union[None, Dict] = None, - worker_log_params: Union[None, Dict] = None, -): - """Creates a Dask slurm_cluster_client on a multinode cluster. - - interface eth0, ib0 + log_params: Optional[Dict[str, Any]] = None, + worker_log_params: Optional[Dict[str, Any]] = None, +) -> distributed.Client: + """Create a SLURM cluster and return a client. Parameters ---------- workers_per_node : int - Number of workers per node ... - + Number of workers per node. cores_per_node : int - Number of cores per node ... - + Number of cores per node. memory_per_node : str - Memory allocation per node ... - + Memory per node (e.g., '64GB'). number_of_nodes : int - Number of nodes ... - + Number of nodes to request. queue : str - Destination queue for each worker job. Passed to #SBATCH -p option - + SLURM queue name. interface : str - Network interface like ‘eth0’ or ‘ib0’. This will be used both for the Dask scheduler and the Dask workers - interface. If you need a different interface for the Dask scheduler you can pass it through the - scheduler_options argument: interface=your_worker_interface, - scheduler_options={'interface': your_scheduler_interface}. - + Network interface to use (e.g., 'ib0'). python_env_dir : str - Python executable used to launch Dask workers. Defaults to the Python that is submitting these jobs. - + Path to the python executable in the environment. dask_local_dir : str - Where Dask should store temporary files, defaults to None. If None Dask will use \ - `./dask-worker-space`, defaults to None - - local_dir : str - Defines client local directory, defaults to None - + Local directory for dask workers. dask_log_dir : str - Destination directory for dask log files. - - exclude_nodes : str - Nodes to exclude. - - dashboard_port : int - Port to use for dashboard connection. - - autorestrictor : bool - Boolean determining usage of autorestrictor plugin, defaults to False - - wait_for_workers : bool - Boolean determining usage of wait_for_workers option in dask, defaults to False - - log_params : dict - Dictionary containing parameters to using for logging. - - worker_log_params : dict - Dictionary containing parameters to using for worker logging. - - .. _Description: - - ** _log_params ** - - The log_params (worker_log_params) dictionary stores initialization information for the logger and associated - workers. the following are the acceptable key: value pairs and their usage information. - - log_params["logger_name"] : str - Defines the logger name to use - log_params["log_to_term"] : bool - Should messages log to the terminal output. - log_params["log_level"] : str - Defines logging level, valid options: - - DEBUG - - INFO - - WARNING - - ERROR - - CRITICAL - - Only messages flagged as at the given level or below are logged. - - log_params["log_to_file"] : str - Should messages log to file. - - log_params["log_filee"] : str - Name of log file to create. If none is given, the file name 'logger' will be used. + Directory for dask logs. + exclude_nodes : str, optional + Comma-separated list of nodes to exclude. + dashboard_port : int, optional + Port for the dask dashboard. + local_dir : str, optional + Client local directory. + autorestrictor : bool, optional + Whether to use the autorestrictor plugin. + wait_for_workers : bool, optional + Whether to wait for workers to start. + log_params : dict, optional + Logger parameters for the client. + worker_log_params : dict, optional + Logger parameters for the workers. Returns ------- - distributed.Client + distributed.Client + The dask client connected to the SLURM cluster. """ # https://github.com/dask/dask/issues/5577 # from distributed import Client - if log_params is None: - log_params = {} - - if worker_log_params is None: - worker_log_params = {} + log_params = _get_log_params(log_params, DEFAULT_CLIENT_LOG_PARAMS) + worker_log_params = _get_log_params(worker_log_params, DEFAULT_WORKER_LOG_PARAMS) if local_dir: os.environ["VIPER_LOCAL_DIR"] = local_dir @@ -568,9 +430,6 @@ def slurm_cluster_client( else: local_cache = False - # Viper logger for code that is not part of the Dask graph. The worker logger is setup in the _worker plugin. - # from viper._utils._logger import setup_logger - logger.setup_logger(**log_params) _set_up_dask(dask_local_dir) @@ -606,22 +465,6 @@ def slurm_cluster_client( } ) - # This method of assigning a worker plugin does not seem to work when using dask_jobqueue. Consequently, using - # client.register_plugin so that the method of assigning a worker plugin is the same for local_client and - # slurm_cluster_client. - # - # if local_cache or worker_log_params: - # dask.config.set({"distributed.worker.preload": os.path.join(plugin_path,"_utils/_worker.py")}) - # dask.config.set({ - # "distributed.worker.preload-argv": [ - # "--local_cache",local_cache, - # "--log_to_term",worker_log_params["log_to_term"], - # "--log_to_file",worker_log_params["log_to_file"], - # "--log_file",worker_log_params["log_file"], - # "--log_level",worker_log_params["log_level"]] - # }) - # - cluster = dask_jobqueue.SLURMCluster( processes=workers_per_node, cores=cores_per_node, @@ -634,16 +477,13 @@ def slurm_cluster_client( local_directory=dask_local_dir, log_directory=dask_log_dir, job_extra_directives=["--exclude=" + exclude_nodes], - # job_extra_directives=["--exclude=nmpost087,nmpost089,nmpost088"], scheduler_options={"dashboard_address": ":" + str(dashboard_port)}, - ) # interface="ib0" + ) client = toolviper.dask.menrva.MenrvaClient(cluster) - cluster.scale(workers_per_node * number_of_nodes) # When constructing a graph that has local cache enabled all workers need to be up and running. - if local_cache or wait_for_workers: client.wait_for_workers(n_workers=workers_per_node * number_of_nodes) @@ -662,34 +502,39 @@ def slurm_cluster_client( def auto_client(): + """ + A decorator that automatically manages a Dask client for the decorated function. + + If a client already exists, it uses the existing one. + Otherwise, it creates a new local_client and shuts it down after the function completes. + """ + def function_wrapper(function): @functools.wraps(function) def wrapper(*args, **kwargs): - persistent_client = False - - if not get_client() is None: - client = get_client() - persistent_client = True - else: + client = get_client() + persistent_client = client is not None + if not persistent_client: # Get client inputs if they exist - arguments = inspect.getcallargs(function, *args, **kwargs) - if "client" in kwargs.keys(): - client = local_client(**kwargs["client"]) - + if "client" in kwargs: + client_kwargs = kwargs["client"] + if isinstance(client_kwargs, dict): + client = local_client(**client_kwargs) + else: + client = local_client() else: client = local_client() try: - print(f"Dask dashboard started at: {client.dashboard_link}") + if client: + logger.info(f"Dask dashboard started at: {client.dashboard_link}") # Run the decorated function - result = function(*args, **kwargs) - - return result + return function(*args, **kwargs) finally: # Ensure the client is closed even if the function raises an exception - if not persistent_client: + if not persistent_client and client: client.shutdown() return wrapper diff --git a/src/toolviper/dask/menrva.py b/src/toolviper/dask/menrva.py index 8bcf4ac..53d9b03 100644 --- a/src/toolviper/dask/menrva.py +++ b/src/toolviper/dask/menrva.py @@ -118,17 +118,17 @@ def call(func: Callable, *args: Tuple[Any], **kwargs: Dict[str, Any]): @staticmethod def instantiate_module( plugin: str, plugin_file: str, *args: Tuple[Any], **kwargs: Dict[str, Any] - ) -> WorkerPlugin: + ) -> WorkerPlugin | None: """ Args: plugin (str): Name of plugin module. - plugin_file (str): Name of module file. ** This should be moved into the module itself not passed ** - *args (tuple(Any)): This is any *arg that needs to be passed to the plugin module. + plugin_file (str): Name of a module file. ** This should be moved into the module itself, not passed ** + *args (tuple (Any)): This is any *arg that needs to be passed to the plugin module. **kwargs (dict[str, Any]): This is any **kwarg default values that need to be passed to the plugin module. Returns: - Instance of plugin class. + Instance of plugin-class. """ spec = importlib.util.spec_from_file_location(plugin, plugin_file) module = importlib.util.module_from_spec(spec) @@ -138,6 +138,8 @@ def instantiate_module( logger.debug("Loading plugin module: {}".format(plugin_instance)) return MenrvaClient.call(plugin_instance, *args, **kwargs) + return None + def load_plugin( self, directory: str, @@ -154,13 +156,13 @@ def load_plugin( *args, **kwargs, ) - logger.debug(f"{plugin}") + if sys.version_info.major == 3: if sys.version_info.minor > 8: self.register_plugin(plugin_instance, name=name) else: - self.register_worker_plugin(plugin_instance, name=name) + self.register_plugin(plugin_instance, name=name) else: logger.warning("Python version may not be supported.") else: diff --git a/src/toolviper/dask/plugins/scheduler.py b/src/toolviper/dask/plugins/scheduler.py index 7f9c7b6..cc349fa 100644 --- a/src/toolviper/dask/plugins/scheduler.py +++ b/src/toolviper/dask/plugins/scheduler.py @@ -250,7 +250,7 @@ def dask_setup(scheduler, autorestrictor, local_cache): def graph_metrics(dependencies, dependents, total_dependencies): r"""Useful measures of a graph used by ``dask.order.order`` - Example DAG (a1 has no dependencies; b2 and c1 are root nodes): + example DAG (a1 has no dependencies; b2 and c1 are root nodes): c1 | diff --git a/src/toolviper/dask/plugins/worker.py b/src/toolviper/dask/plugins/worker.py index 12c9e1b..9e5681d 100644 --- a/src/toolviper/dask/plugins/worker.py +++ b/src/toolviper/dask/plugins/worker.py @@ -99,7 +99,7 @@ async def dask_setup( await worker.client.register_plugin(plugin, name="worker_logger") else: - await worker.client.register_worker_plugin(plugin, name="worker_logger") + await worker.client.register_plugin(plugin, name="worker_logger") else: logger.warning("Python version may not be supported.") diff --git a/src/toolviper/utils/__init__.py b/src/toolviper/utils/__init__.py index e7f5a5d..3b5bad1 100644 --- a/src/toolviper/utils/__init__.py +++ b/src/toolviper/utils/__init__.py @@ -3,6 +3,9 @@ from .protego import Protego from .logger import info, debug, warning, error, critical, get_logger, setup_logger from .tools import open_json, calculate_checksum, verify, add_entry +from .profile import memory_usage, cpu_usage +from .sd import prototype +from .display import DataDict from .data import download diff --git a/src/toolviper/utils/app.py b/src/toolviper/utils/app.py new file mode 100644 index 0000000..2730924 --- /dev/null +++ b/src/toolviper/utils/app.py @@ -0,0 +1,40 @@ +from textual.app import App, ComposeResult +from textual.containers import Horizontal, VerticalScroll, VerticalGroup +from textual.widgets import Button, Static, Header, Footer, DirectoryTree + + +class UploadApp(App): + BINDINGS = [("d", "toggle_dark", "Toggle Dark Mode")] + + def compose(self) -> ComposeResult: + yield Header() + yield Footer() + yield VerticalScroll( + DirectoryTreeApp(), + ExitButton(), + ) + + def action_toggle_dark(self) -> None: + self.theme = ( + "textual-dark" if self.theme == "textual-light" else "textual-light" + ) + + +def app_test(): + app = UploadApp() + print(app.run()) + + +class DirectoryTreeApp(VerticalGroup): + def compose(self) -> ComposeResult: + yield DirectoryTree("./") + + +class ExitButton(VerticalGroup): + # CSS_PATH = "css/button.tcss" + + def compose(self) -> ComposeResult: + yield Button("Exit", variant="primary") + + def on_button_pressed(self, event: Button.Pressed) -> None: + self.app.exit() diff --git a/src/toolviper/utils/css/button.tcss b/src/toolviper/utils/css/button.tcss new file mode 100644 index 0000000..704f20e --- /dev/null +++ b/src/toolviper/utils/css/button.tcss @@ -0,0 +1,12 @@ +Button { + margin: 1 2; +} + +Horizontal > VerticalScroll { + width: 24; +} + +.header { + margin: 1 0 0 2; + text-style: bold; +} \ No newline at end of file diff --git a/src/toolviper/utils/data/cloudflare.py b/src/toolviper/utils/data/cloudflare.py index dfbe1b6..bccad66 100644 --- a/src/toolviper/utils/data/cloudflare.py +++ b/src/toolviper/utils/data/cloudflare.py @@ -4,280 +4,346 @@ import shutil import zipfile from threading import Thread -from typing import Any, Optional, Union +from typing import Any, Dict, List, Optional, Union import requests +import pandas as pd + +from rich.console import Console from rich.progress import Progress, TaskID -import toolviper import toolviper.utils.console as console import toolviper.utils.logger as logger -from toolviper.utils import parameter -from collections import defaultdict +from toolviper.utils import parameter from toolviper.utils.parameter import is_notebook -import pandas as pd +from collections import defaultdict colorize = console.Colorize() +# Constants PROGRESS_MAX_CHARACTERS = 28 MINIMUM_CHUNK_SIZE = 1024 +BASE_URL = "https://downloadnrao.org" +METADATA_REL_PATH = ".cloudflare/file.download.json" +USER_AGENT = "Wget/1.16 (linux-gnu)" + + +def _get_metadata_path() -> pathlib.Path: + """Get the absolute path to the local metadata file.""" + return pathlib.Path(__file__).parent.resolve().joinpath(METADATA_REL_PATH) def version() -> None: - # Load the file dropbox file meta data. - meta_data_path = pathlib.Path(__file__).parent.joinpath( - ".cloudflare/file.download.json" - ) + """ + Print the version of the cloudflare manifest. + """ + meta_data_path = _get_metadata_path() if not meta_data_path.parent.exists(): - logger.debug("metadata path doesn't exist... creating") - meta_data_path.parent.mkdir(parents=True) + logger.debug(f"Metadata path {meta_data_path.parent} doesn't exist... creating") + meta_data_path.parent.mkdir(parents=True, exist_ok=True) - # Verify that the download metadata exists and updates if not. _verify_metadata_file() - with open(meta_data_path) as json_file: - file_meta_data = json.load(json_file) + try: + with open(meta_data_path, "r") as json_file: + file_meta_data = json.load(json_file) + logger.info(f"Manifest version: {file_meta_data.get('version', 'unknown')}") - logger.info(f"{file_meta_data['version']}") + except (FileNotFoundError, json.JSONDecodeError) as e: + logger.error(f"Failed to read metadata file: {e}") @parameter.validate() def download( - file: Union[str, list], + file: Union[str, List[str]], folder: str = ".", overwrite: bool = False, decompress: bool = True, ) -> None: """ - Download tool for data stored externally. + Download tool for data stored externally. + Parameters ---------- - file : str - Filename as stored on an external source. - folder : str - Destination folder. - overwrite : bool - Should file be overwritten. - decompress : bool - Should file be unzipped. - - Returns - ------- - No return + file : str or list of str + Filename(s) as stored on an external source. + folder : str, optional + Destination folder. Defaults to ".". + overwrite : bool, optional + Whether to overwrite existing files. Defaults to False. + decompress : bool, optional + Whether to unzip downloaded files. Defaults to True. """ + logger.info("Initializing download...") - logger.info("Downloading from [cloudflare] ....") - - if not isinstance(file, list): + if isinstance(file, str): file = [file] - try: - _print_file_queue(file) - - except Exception as e: - logger.warning(f"There was a problem printing the file list... {e}") + # try: + # _print_file_queue(file) + # except Exception as e: + # logger.warning(f"Problem printing file list: {e}") - finally: - if not pathlib.Path(folder).resolve().exists(): - toolviper.utils.logger.info( - f"Creating path:{colorize.blue(str(pathlib.Path(folder).resolve()))}" - ) - pathlib.Path(folder).resolve().mkdir() + dest_path = pathlib.Path(folder).resolve() + if not dest_path.exists(): + logger.info(f"Creating path: {colorize.blue(str(dest_path))}") + dest_path.mkdir(parents=True, exist_ok=True) - logger.debug(f"Initializing [cloudflare] downloader ...") + meta_data_path = _get_metadata_path() + if not meta_data_path.exists(): + logger.warning( + f"Metadata not found locally at {colorize.blue(str(meta_data_path))}" + ) + update() - meta_data_path = pathlib.Path(__file__).parent.joinpath( - ".cloudflare/file.download.json" - ) + try: + with open(meta_data_path, "r") as json_file: + file_meta_data = json.load(json_file) + except (FileNotFoundError, json.JSONDecodeError) as e: + logger.error(f"Failed to load metadata: {e}") + return tasks = [] - - # Make a list of files that aren't available from cloudflare yet missing_files = [] - # Load the file dropbox file meta data. - if not meta_data_path.exists(): - logger.warning( - f"Couldn't find file metadata locally in {colorize.blue(str(meta_data_path))}" + def name_format(string): + return ( + f"{string[: (PROGRESS_MAX_CHARACTERS - 4)]} ..." + if len(string) > PROGRESS_MAX_CHARACTERS + else string ) - toolviper.utils.data.update() - - with open(meta_data_path) as json_file: - file_meta_data = json.load(json_file) + for f_name in file: + full_file_path = dest_path.joinpath(f_name) - # Build the task list - for file_ in file: - full_file_path = pathlib.Path(folder).joinpath(file_) + if full_file_path.exists() and not overwrite: + logger.info(f"File already exists: {full_file_path}") + continue - if full_file_path.exists() and not overwrite: - logger.info(f"File exists: {str(full_file_path)}") - continue + if f_name not in file_meta_data.get("metadata", {}): + logger.error(f"Requested file not found in manifest: {f_name}") + logger.error( + f"Use {colorize.blue('toolviper.utils.data.update()')} for the most recent version of the manifest." + ) + logger.info( + f"Use {colorize.blue('toolviper.utils.data.list_files()')} for available files." + ) + missing_files.append(f_name) + continue + + meta = file_meta_data["metadata"][f_name] + tasks.append( + { + "description": name_format(f_name), + "metadata": meta, + "folder": str(dest_path), + "visible": True, + "size": int(meta.get("size", 0)), + "jupyter": is_notebook(), + } + ) - if file_ not in file_meta_data["metadata"].keys(): - logger.error(f"Requested file not found: {file_}") - logger.info( - f"For a list of available files try using " - f"{colorize.blue('toolviper.utils.data.list_files()')}." - ) + if not tasks: + if missing_files: + logger.error(f"Missing files: {missing_files}") - missing_files.append(file_) - continue + return - name_format = lambda string: ( - f"{string[: (PROGRESS_MAX_CHARACTERS - 4)]} ..." - if len(string) > PROGRESS_MAX_CHARACTERS - else string - ) + progress = Progress() + if is_notebook(): + _ = Console(force_terminal=True, force_jupyter=False) + _console = Console(force_jupyter=is_notebook()) - tasks.append( - { - "description": name_format(file_), - "metadata": file_meta_data["metadata"][file_], - "folder": folder, - "visible": True, - "size": int(file_meta_data["metadata"][file_]["size"]), - } - ) + progress = Progress(console=_console) threads = [] - progress = Progress() with progress: - task_ids = [ - progress.add_task(task["description"]) for task in tasks if len(tasks) > 0 - ] - - for i, task in enumerate(tasks): - thread = Thread( - target=worker, args=(progress, task_ids[i], task, decompress) - ) + for task in tasks: + task_id = progress.add_task(task["description"]) + thread = Thread(target=worker, args=(task_id, task, progress, decompress)) thread.start() threads.append(thread) for thread in threads: thread.join() - if len(missing_files) > 0: - logger.error(f"Missing files: {missing_files}") - + progress.refresh() -def worker(progress: Progress, task_id: TaskID, task: dict, decompress=True) -> None: - """Simulate work being done in a thread""" + if missing_files: + logger.error(f"Could not download: {missing_files}") - filename = task["metadata"]["file"] - url = f"https://downloadnrao.org/{task['metadata']['path']}/{task['metadata']['file']}" +def worker( + task_id: TaskID, task: dict, progress: Progress = None, decompress: bool = True +) -> None: + """ + Worker function to download a file in a thread. - r = requests.get(url, stream=True, headers={"user-agent": "Wget/1.16 (linux-gnu)"}) - total = int(r.headers.get("Content-Length", 0)) + Parameters + ---------- + progress : Progress + Rich Progress instance. + task_id : TaskID + ID of the task in the progress bar. + task : dict + Task details including metadata and destination folder. + decompress : bool, optional + Whether to decompress the file after download. Defaults to True. + """ + metadata = task["metadata"] + filename = metadata["file"] + path = metadata.get("path", "").strip("/") + url = f"{BASE_URL}/{path}/{filename}" if path else f"{BASE_URL}/{filename}" - if total == 0: - total = task["size"] + try: + response = requests.get( + url, stream=True, headers={"user-agent": USER_AGENT}, timeout=30 + ) + response.raise_for_status() - fullname = str(pathlib.Path(task["folder"]).joinpath(filename)) + except Exception as e: + logger.error(f"Failed to initiate download for {filename}: {e}") + return - size = 0 + total = int(response.headers.get("Content-Length", 0)) + if total == 0: + total = task.get("size", 0) - with open(fullname, "wb") as fd: - for chunk in r.iter_content(chunk_size=MINIMUM_CHUNK_SIZE): - if chunk: - size += fd.write(chunk) - progress.update( - task_id, completed=size, total=total, visible=task["visible"] - ) + dest_folder = pathlib.Path(task["folder"]) + fullname = dest_folder.joinpath(filename) - # Verify checksum on file - # toolviper.utils.verify(filename, task["folder"]) + try: + size = 0 + with open(fullname, "wb") as fd: + for chunk in response.iter_content(chunk_size=MINIMUM_CHUNK_SIZE): + if chunk: + size += fd.write(chunk) + if progress is not None: + progress.update( + task_id, + completed=size, + total=total, + visible=task["visible"], + ) - if decompress: - if zipfile.is_zipfile(fullname): - shutil.unpack_archive(filename=fullname, extract_dir=task["folder"]) + except Exception as e: + logger.error(f"Error writing file {filename}: {e}") + return - # Let's clean up after ourselves + if decompress and zipfile.is_zipfile(fullname): + try: + shutil.unpack_archive(filename=str(fullname), extract_dir=str(dest_folder)) os.remove(fullname) + except Exception as e: + logger.error(f"Failed to decompress {filename}: {e}") class ToolviperFiles: - def __init__(self, manifest, dataframe=None): + """ + Helper class for managing and displaying toolviper data manifests. + """ + def __init__(self, manifest: str, dataframe: Optional[pd.DataFrame] = None) -> None: self.manifest = manifest self.dataframe = dataframe - self.notebook_mode = False + self.notebook_mode = is_notebook() - if is_notebook(): - import itables + if self.notebook_mode: + try: + import itables - self.notebook_mode = True + itables.init_notebook_mode() - itables.init_notebook_mode() + except ImportError: + logger.debug("itables not found, falling back to standard display.") - def __call__(self): + def __call__(self) -> Optional[pd.DataFrame]: if not self.notebook_mode: - return print(self.dataframe) + print(self.dataframe) + return None - else: - return self.dataframe + return self.dataframe - def print(self) -> Union[None, pd.DataFrame]: + def print(self) -> Optional[pd.DataFrame]: + """ + Display the dataframe using appropriate formatting. + """ if not self.notebook_mode: - import tabulate - - print( - tabulate.tabulate( - self.dataframe, showindex=False, headers=self.dataframe.columns + try: + import tabulate + + print( + tabulate.tabulate( + self.dataframe, + showindex=False, + headers=self.dataframe.columns, + ) ) - ) + except ImportError: + print(self.dataframe) + return None return self.dataframe @classmethod - def from_manifest(cls, manifest: str): + def from_manifest(cls, manifest: str) -> "ToolviperFiles": + """ + Create a ToolviperFiles instance from a manifest file. + """ meta_data_path = pathlib.Path(manifest) - # Verify that the download metadata exist and update if not. - # _verify_metadata_file() + try: + with open(meta_data_path, "r") as json_file: + file_meta_data = json.load(json_file) - with open(meta_data_path) as json_file: + except (FileNotFoundError, json.JSONDecodeError) as e: + logger.error(f"Failed to load manifest {manifest}: {e}") - file_meta_data = json.load(json_file) + return cls(manifest=manifest, dataframe=pd.DataFrame()) - files = file_meta_data["metadata"].keys() + metadata_dict = file_meta_data.get("metadata", {}) + data = defaultdict(list) - data = defaultdict(list) - data["file"] = list(files) + for file_name, meta in metadata_dict.items(): + data["file"].append(file_name) - for file_, metadata_ in file_meta_data["metadata"].items(): - for key_, value_ in metadata_.items(): - if key_ == "file": - continue + for key, value in meta.items(): + if key == "file": + continue - # I think we could do this with a JSON ENCODER - # but this is easier since the file is small - # and everything is a string already + if key == "size": + try: + value = int(value) - if value_ == "size": - value_ = int(value_) + except (ValueError, TypeError): + pass - data[key_].append(value_) + data[key].append(value) - return cls(manifest=manifest, dataframe=pd.DataFrame(data)) + return cls(manifest=manifest, dataframe=pd.DataFrame(data)) -def list_files(truncate=None) -> pd.DataFrame: +def list_files(truncate: Optional[int] = None) -> Optional[pd.DataFrame]: + """ + List all files available in the cloudflare manifest. + Parameters + ---------- + truncate : int, optional + Maximum number of rows to display. Defaults to None. + """ pd.set_option("display.max_rows", truncate) pd.set_option("display.colheader_justify", "left") - meta_data_path = pathlib.Path(__file__).parent.joinpath( - ".cloudflare/file.download.json" - ) + meta_data_path = _get_metadata_path() + if not meta_data_path.exists(): + _verify_metadata_file() table = ToolviperFiles.from_manifest(str(meta_data_path)) - return table.print() @@ -322,143 +388,137 @@ def list_files_() -> None: console.print(table) -def get_files() -> list[Any]: +def get_files() -> List[str]: """ - Get all files available in cloudflare manifest. This is retrieved from the local cloudflare - metadata file. - + Get a list of all file names available in the cloudflare manifest. """ - meta_data_path = pathlib.Path(__file__).parent.joinpath( - ".cloudflare/file.download.json" - ) - - # Verify that the download metadata exists and updates if not. + meta_data_path = _get_metadata_path() _verify_metadata_file() - with open(meta_data_path) as json_file: - file_meta_data = json.load(json_file) + try: + with open(meta_data_path, "r") as json_file: + file_meta_data = json.load(json_file) + return list(file_meta_data.get("metadata", {}).keys()) - return list(file_meta_data["metadata"].keys()) + except (FileNotFoundError, json.JSONDecodeError): + return [] @parameter.validate() -def update(path: Union[str, None] = None) -> None: +def update(path: Optional[str] = None) -> None: """ - Update cloudflare manifest. + Update the local cloudflare manifest by downloading the latest version. Parameters ---------- - path : str - In the case that you want an updated copy of the manifest for modification, this is the path to save it to. + path : str, optional + Custom path to save the manifest to. Defaults to the internal .cloudflare directory. """ - if path is None: - meta_data_path = pathlib.Path(__file__).parent.joinpath(".cloudflare") + meta_data_dir = _get_metadata_path().parent + meta_data_path = _get_metadata_path() else: - # I know this is an unnecessary copy but I don't want a big erbose path name in the inpute variables. - meta_data_path = pathlib.Path(path) + meta_data_dir = pathlib.Path(path) + meta_data_path = meta_data_dir.joinpath("file.download.json") - if not meta_data_path.exists(): - _make_dir(str(pathlib.Path(__file__).parent), ".cloudflare") + if not meta_data_dir.exists(): + meta_data_dir.mkdir(parents=True, exist_ok=True) + # Temporary metadata to kickstart the download of the actual manifest file_meta_data = { "file": "file.download.json", "path": "/", "dtype": "JSON", "telescope": "NA", - "size": "12484", + "size": "23879", "mode": "NA", } - tasks = { - "description": "file.download.json", + task = { + "description": "Updating manifest", "metadata": file_meta_data, - "folder": meta_data_path, - "visible": False, - "size": 12484, + "folder": str(meta_data_dir), + "visible": True, + "size": 23879, } - logger.info("Updating file metadata information ... ") + logger.info("Updating file metadata information...") - progress = Progress() - task_id = progress.add_task(tasks["description"]) + task_id = 0 + # with progress: + _console = Console(force_jupyter=is_notebook()) + tasks = [f"\nManifest update "] - with progress: - worker(progress, task_id, tasks) + with _console.status( + "[bold green]Working on download manifest update ..." + ) as status: + while tasks: + worker(task_id, task, progress=None, decompress=False) + + task = tasks.pop(0) if not meta_data_path.exists(): logger.error("Unable to retrieve download metadata.") - raise FileNotFoundError( - "Download metadata file does not exist at the expected path." - ) + raise FileNotFoundError(f"Download metadata file not found at {meta_data_path}") @parameter.validate() -def get_file_size(path: str) -> Optional[dict]: - """ - Get list file sizes in bytes for a given path. Only works for files; isn't recursive. +def get_file_size(path: str) -> Dict[str, int]: """ - if not pathlib.Path(path).resolve().exists(): - logger.error(f"Path not found...: {path}") - - return None - - file_size_dict = {} + Get file sizes in bytes for all files in a given path. - for item in pathlib.Path(path).resolve().iterdir(): - if pathlib.Path(item).resolve().is_file(): - if item.name.endswith(".zip"): - item_ = item.name.split(".zip")[0] + Parameters + ---------- + path : str + The directory path to scan. - else: - item_ = item.name + Returns + ------- + dict + A dictionary mapping file names to their sizes in bytes. + """ + path_obj = pathlib.Path(path).resolve() + if not path_obj.exists() or not path_obj.is_dir(): + logger.error(f"Path not found or is not a directory: {path}") + return {} - file_size_dict[item_] = os.path.getsize(pathlib.Path(item)) + file_size_dict = {} + for item in path_obj.iterdir(): + if item.is_file(): + name = ( + item.name.split(".zip")[0] if item.name.endswith(".zip") else item.name + ) + file_size_dict[name] = item.stat().st_size return file_size_dict -def _print_file_queue(files: list) -> None: +def _print_file_queue(files: List[str]) -> None: + """ + Print a formatted list of files to be downloaded. + """ from rich import box from rich.console import Console from rich.table import Table - assert type(files) == list + assert isinstance(files, list), logger.error("files must be a list") - console = Console() + console_ = Console() table = Table(show_header=True, box=box.SIMPLE) - table.add_column("Download List", justify="left") - for file in files: - table.add_row(f"[magenta]{file}[/magenta]") - - console.print(table) - - -def _make_dir(path, folder): - p = pathlib.Path(path).joinpath(folder) - try: - p.mkdir() - logger.info( - f"Creating path:{colorize.blue(str(pathlib.Path(folder).resolve()))}" - ) - - except FileExistsError: - logger.warning(f"File exists: {colorize.blue(str(p.resolve()))}") + for f_name in files: + table.add_row(f"[magenta]{f_name}[/magenta]") - except FileNotFoundError: - logger.warning( - f"One fo the parent directories cannot be found: {colorize.blue(str(p.resolve()))}" - ) + console_.print(table) -def _verify_metadata_file(): - meta_data_path = pathlib.Path(__file__).parent.joinpath( - ".cloudflare/file.download.json" - ) - +def _verify_metadata_file() -> None: + """ + Ensure the metadata file exists or trigger an update. + """ + meta_data_path = _get_metadata_path() if not meta_data_path.exists(): - logger.warning(f"Couldn't find {colorize.blue(str(meta_data_path))}.") + logger.warning(f"Metadata file {meta_data_path} missing. Updating...") update() diff --git a/src/toolviper/utils/display.py b/src/toolviper/utils/display.py index f0c9932..ba50f90 100755 --- a/src/toolviper/utils/display.py +++ b/src/toolviper/utils/display.py @@ -1,10 +1,100 @@ +import re +import operator +from typing import Dict + +from IPython.core.display import HTML + + +class DataDict(dict): + def __init__(self, dictionary): + super().__init__() + self._dict = dictionary + + def __repr__(self, *args, **kwargs): + return f"" + + @classmethod + def from_dict(cls, dictionary): + if isinstance(dictionary, dict): + return cls(dictionary) + + return None + + @property + def data(self): + return self._dict + + def select(self, keys, in_place=False): + _result = list(operator.itemgetter(*keys)(self._dict)) + + if in_place: + self._dict = _result + + return DataDict.from_dict({key: value for key, value in zip(keys, _result)}) + + def get_entries_(self, keys): + return {key: value for key, value in self._dict.items() if key in keys} + + def filter(self, query, in_place=False): + _result = None + + if isinstance(query, list): + _result = self.get_entries_(query) + + if isinstance(query, str): + _result = { + key: value for key, value in self._dict.items() if re.search(query, key) + } + + if in_place: + self._dict = _result + return None + + return DataDict.from_dict(_result) + + def display(self, interactive=True): + import rich + from toolviper.utils.parameter import is_notebook + + if is_notebook() and interactive: + from IPython.display import JSON + + return JSON(self._dict) + + return rich.print_json(data=self._dict) + + @staticmethod + def html(dictionary: Dict, indent: int = 0): + _html = _write_html(dictionary, indent) + + return HTML(_html) + + +def _write_html(d, indent=0): + _html = "" + + for key, value in d.items(): + if isinstance(value, dict): + _html += f"
{key}{_write_html(value, indent + 1)}
" + + else: + _html += f"
{key}: {value}
" + + return _html + + def dict_to_html(d, indent=0): - from IPython.display import HTML html = "" for key, value in d.items(): if isinstance(value, dict): html += f"
{key}{dict_to_html(value, indent + 1)}
" else: - html += f"
{key}: {value}
" + html += f"
{key}: {value}
" + + if indent == 0: + print( + f"THIS FUNCTION WILL BE DEPRECATED SOON, switch to: toolviper.utils.display.DataDict.html(d)" + ) + return html diff --git a/src/toolviper/utils/logger.py b/src/toolviper/utils/logger.py index a076d86..12f39d3 100755 --- a/src/toolviper/utils/logger.py +++ b/src/toolviper/utils/logger.py @@ -12,252 +12,325 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import sys -import dask -import logging - from datetime import datetime -from toolviper.utils.console import Colorize -from toolviper.utils.console import add_verbose_info +from typing import Any, Dict, Optional, Union +import dask +import dask.distributed +from contextvars import ContextVar from dask.distributed import get_worker -from contextvars import ContextVar +from toolviper.utils.console import Colorize, add_verbose_info -from typing import Union +# Global verbosity flag +verbosity: ContextVar[Optional[bool]] = ContextVar("message_verbosity", default=None) + +# Constants for default values +DEFAULT_LOGGER_NAME = "viperlog" +LOGGER_ENV_VAR = "VIPER_LOGGER_NAME" VERBOSE = True DEFAULT = False -# global verbosity flag -verbosity: Union[ContextVar[bool], ContextVar[None]] = ContextVar( - "message_verbosity", default=None -) - -def set_verbosity(state: Union[None, bool] = None): - print(f"Setting verbosity to {state}") +def set_verbosity(state: Optional[bool] = None) -> None: + """ + Set the global verbosity state. + Parameters + ---------- + state : bool, optional + The verbosity state to set. If None, it uses the default. + """ verbosity.set(state) -def info(message: str, verbose: bool = False): - logger_name = os.getenv("LOGGER_NAME") - - if verbosity.get() is True or False: - verbose = verbosity.get() - - if verbose: - message = add_verbose_info(message=message, color="blue") +def _log_message( + level: str, message: str, verbose: bool = False, color: Optional[str] = None +) -> None: + """ + Helper function to process and log a message. + + Parameters + ---------- + level : str + The logging level (e.g., 'info', 'debug', 'warning'). + message : str + The message to log. + verbose : bool, optional + Whether to include verbose information. Defaults to False. + color : str, optional + The color to use for verbose information. + """ + logger_name = os.getenv(LOGGER_ENV_VAR, DEFAULT_LOGGER_NAME) + + current_verbosity = verbosity.get() + if current_verbosity is not None: + verbose = current_verbosity + + if verbose and color: + message = add_verbose_info(message=message, color=color) logger = get_logger(logger_name=logger_name) - logger.info(message) - - -def log(message: str, verbose: bool = False): - logger_name = os.getenv("LOGGER_NAME") - - if verbosity.get() is True or False: - verbose = verbosity.get() + log_func = getattr(logger, level.lower()) + log_func(message) + + +def info(message: str, verbose: bool = False) -> None: + """ + Log an info level message. + + Parameters + ---------- + message : str + The message to log. + verbose : bool, optional + Whether to include verbose information. Defaults to False. + """ + _log_message("info", message, verbose, color="blue") + + +def log(message: str, verbose: bool = False) -> None: + """ + Log a message at the current logger's level. + + Parameters + ---------- + message : str + The message to log. + verbose : bool, optional + Whether to include verbose information. Defaults to False. + """ + logger_name = os.getenv(LOGGER_ENV_VAR, DEFAULT_LOGGER_NAME) + current_verbosity = verbosity.get() + if current_verbosity is not None: + verbose = current_verbosity if verbose: message = add_verbose_info(message=message, color="blue") logger = get_logger(logger_name=logger_name) - logger.log(logger.level, message) -def exception(message: str, verbose: bool = False): - logger_name = os.getenv("LOGGER_NAME") - - if verbosity.get() is True or False: - verbose = verbosity.get() +def exception(message: str, verbose: bool = False) -> None: + """ + Log an exception level message. + + Parameters + ---------- + message : str + The message to log. + verbose : bool, optional + Whether to include verbose information. Defaults to False. + """ + _log_message("exception", message, verbose, color="blue") + + +def debug(message: str, verbose: bool = False) -> None: + """ + Log a debug level message. + + Parameters + ---------- + message : str + The message to log. + verbose : bool, optional + Whether to include verbose information. Defaults to False. + """ + _log_message("debug", message, verbose, color="green") + + +def warning(message: str, verbose: bool = False) -> None: + """ + Log a warning level message. + + Parameters + ---------- + message : str + The message to log. + verbose : bool, optional + Whether to include verbose information. Defaults to False. + """ + _log_message("warning", message, verbose, color="orange") + + +def error(message: str, verbose: bool = True) -> None: + """ + Log an error level message. + + Parameters + ---------- + message : str + The message to log. + verbose : bool, optional + Whether to include verbose information. Defaults to True. + """ + _log_message("error", message, verbose, color="red") + + +def critical(message: str, verbose: bool = True) -> None: + """ + Log a critical level message. + + Parameters + ---------- + message : str + The message to log. + verbose : bool, optional + Whether to include verbose information. Defaults to True. + """ + _log_message("critical", message, verbose, color="alert") - if verbose: - message = add_verbose_info(message=message, color="blue") - - logger = get_logger(logger_name=logger_name) - - logger.exception(message) - - -def debug(message: str, verbose: bool = False): - logger_name = os.getenv("LOGGER_NAME") - - if verbosity.get() is True or False: - verbose = verbosity.get() - - if verbose: - message = add_verbose_info(message=message, color="green") - - logger = get_logger(logger_name=logger_name) - logger.debug(message) - - -def warning(message: str, verbose: bool = False): - logger_name = os.getenv("LOGGER_NAME") - - if verbosity.get() is True or False: - verbose = verbosity.get() - - if verbose: - message = add_verbose_info(message=message, color="orange") - logger = get_logger(logger_name=logger_name) - logger.warning(message) - - -def error(message: str, verbose: bool = True): - logger_name = os.getenv("LOGGER_NAME") - - if verbosity.get() is True or False: - verbose = verbosity.get() - - if verbose: - message = add_verbose_info(message=message, color="red") - - logger = get_logger(logger_name=logger_name) - logger.error(message) - - -def critical(message: str, verbose: bool = True): - logger_name = os.getenv("LOGGER_NAME") +class ColorLoggingFormatter(logging.Formatter): + """ + A logging formatter that adds colors to the output based on the log level. + """ - if verbosity.get() is True or False: - verbose = verbosity.get() + colorize = Colorize() - if verbose: - message = add_verbose_info(message=message, color="alert") + def __init__(self, fmt: Optional[str] = None, datefmt: Optional[str] = None): + super().__init__(fmt, datefmt) + self.start_msg = f"[{self.colorize.purple('%(asctime)s')}] " + + self.FORMATS = { + logging.DEBUG: self.start_msg + + self.colorize.green("%(levelname)8s") + + self.colorize.grey(" %(name)10s: ") + + " %(message)s", + logging.INFO: self.start_msg + + self.colorize.blue("%(levelname)8s") + + self.colorize.grey(" %(name)10s: ") + + " %(message)s ", + logging.WARNING: self.start_msg + + self.colorize.orange("%(levelname)8s") + + self.colorize.grey(" %(name)10s: ") + + " %(message)s ", + logging.ERROR: self.start_msg + + self.colorize.red("%(levelname)8s") + + self.colorize.grey(" %(name)10s: ") + + " %(message)s", + logging.CRITICAL: self.start_msg + + self.colorize.format( + text="%(levelname)8s", color=[220, 60, 20], highlight=True + ) + + self.colorize.grey(" %(name)10s: ") + + " %(message)s", + } + + def format(self, record: logging.LogRecord) -> str: + log_fmt = self.FORMATS.get(record.levelno, self._fmt) + formatter = logging.Formatter(log_fmt, self.datefmt) - logger = get_logger(logger_name=logger_name) - logger.critical(message) + return formatter.format(record) -class ColorLoggingFormatter(logging.Formatter): - colorize = Colorize() +class LoggingFormatter(logging.Formatter): + """ + A standard logging formatter for file output. + """ + + def __init__(self, fmt: Optional[str] = None, datefmt: Optional[str] = None): + super().__init__(fmt, datefmt) + self.start_msg = "[%(asctime)s] " + self.middle_msg = "%(levelname)8s" + + self.FORMATS = { + level: f"{self.start_msg}{self.middle_msg} %(name)10s: %(message)s" + for level in [ + logging.DEBUG, + logging.INFO, + logging.WARNING, + logging.ERROR, + logging.CRITICAL, + ] + } + + def format(self, record: logging.LogRecord) -> str: + log_fmt = self.FORMATS.get(record.levelno, self._fmt) + formatter = logging.Formatter(log_fmt, self.datefmt) - function = " [{function}] ".format(function=colorize.blue("%(funcName)s")) - verbose = " [{exechain}] ".format( - exechain=colorize.blue("%(filename)s:%(lineno)s : %(module)s.%(funcName)s") - ) - - start_msg = "[{time}] ".format(time=colorize.purple("%(asctime)s")) - middle_msg = "{level}".format(level="%(levelname)8s") - execution_msg = " {name} [ {filename} ]: {exec_info}: ".format( - name="%(name)10s", - filename="%(filename)-20s", - exec_info=colorize.blue("%(callchain)-45s"), - ) - - FORMATS = { - logging.DEBUG: start_msg - + colorize.green(middle_msg) - + colorize.grey(" %(name)10s: ") - + " %(message)s", - logging.INFO: start_msg - + colorize.blue(middle_msg) - + colorize.grey(" %(name)10s: ") - + " %(message)s ", - logging.WARNING: start_msg - + colorize.orange(middle_msg) - + colorize.grey(" %(name)10s: ") - + " %(message)s ", - logging.ERROR: start_msg - + colorize.red(middle_msg) - + colorize.grey(" %(name)10s: ") - + " %(message)s", - logging.CRITICAL: start_msg - + colorize.format(text=middle_msg, color=[220, 60, 20], highlight=True) - + colorize.grey(" %(name)10s: ") - + " %(message)s", - } - - def format(self, record): - log_fmt = self.FORMATS.get(record.levelno) - formatter = logging.Formatter(log_fmt) return formatter.format(record) -class LoggingFormatter(logging.Formatter): - function = " [{function}] ".format(function="%(funcName)s") - verbose = " [{exechain}] ".format( - exechain="%(filename)s:%(lineno)s : %(module)s.%(funcName)s" - ) - - start_msg = "[{time}] ".format(time="%(asctime)s") - middle_msg = "{level}".format(level="%(levelname)8s") - execution_msg = " {name} [ {filename} ]: {exec_info}: ".format( - name="%(name)10s", filename="%(filename)-20s", exec_info="%(callchain)-45s" - ) - - FORMATS = { - logging.DEBUG: start_msg + middle_msg + " %(name)10s: " + " %(message)s", - logging.INFO: start_msg + middle_msg + " %(name)10s: " + " %(message)s ", - logging.WARNING: start_msg + middle_msg + " %(name)10s: " + " %(message)s ", - logging.ERROR: start_msg + middle_msg + " %(name)10s: " + " %(message)s", - logging.CRITICAL: start_msg + middle_msg + " %(name)10s: " + " %(message)s", - } - - def format(self, record): - log_fmt = self.FORMATS.get(record.levelno) - formatter = logging.Formatter(log_fmt) - return formatter.format(record) +def get_logger(logger_name: Optional[str] = None) -> logging.Logger: + """ + Get a logger instance by name, with fallback to environment or defaults. + Parameters + ---------- + logger_name : str, optional + The name of the logger to retrieve. -def get_logger(logger_name: Union[str, None] = None): + Returns + ------- + logging.Logger + The logger instance. + + """ if logger_name is None: - if os.getenv("LOGGER_NAME"): - # Return default logger from env if none is specified. - logger_name = os.getenv("LOGGER_NAME") - else: - logger_name = "viperlog" + logger_name = os.getenv(LOGGER_ENV_VAR, DEFAULT_LOGGER_NAME) try: worker = get_worker() + # If we're on a worker, try to get the worker-specific logger from the plugin + if hasattr(worker, "plugins") and "worker_logger" in worker.plugins: + return worker.plugins["worker_logger"].get_logger() - except ValueError: - # Scheduler processes - logger_dict = logging.Logger.manager.loggerDict - if logger_name in logger_dict: - logger = logging.getLogger(logger_name) - else: - # If main logger is not started using client function it defaults to printing to term. - logger = logging.getLogger(logger_name) - stream_handler = logging.StreamHandler(sys.stdout) - stream_handler.setFormatter(ColorLoggingFormatter()) - logger.addHandler(stream_handler) - logger.setLevel(logging.getLevelName("INFO")) - - return logger + except (ValueError, AttributeError, KeyError): + # Not on a worker, or worker logger plugin not available + pass - try: - logger = worker.plugins["worker_logger"].get_logger() - - return logger + logger = logging.getLogger(logger_name) - except Exception as e: - print("Could not load worker logger: {}".format(e)) - print(worker.plugins.keys()) + # If the logger has no handlers, it hasn't been set up yet. + if not logger.handlers: + # Default to a simple stream handler if not explicitly set up + stream_handler = logging.StreamHandler(sys.stdout) + stream_handler.setFormatter(ColorLoggingFormatter()) + logger.addHandler(stream_handler) + logger.setLevel(logging.INFO) - return logging.getLogger() + return logger def setup_logger( - logger_name: Union[str, None] = None, + logger_name: Optional[str] = None, log_to_term: bool = False, log_to_file: bool = True, log_file: str = "logger", log_level: str = "INFO", -): - """To set up as many loggers as you want""" +) -> logging.Logger: + """ + Configure and return a logger. + + Parameters + ---------- + logger_name : str, optional + The name of the logger to set up. + log_to_term : bool, optional + Whether to log to the terminal. Defaults to False. + log_to_file : bool, optional + Whether to log to a file. Defaults to True. + log_file : str, optional + The base name of the log file. + log_level : str, optional + The logging level (e.g., 'DEBUG', 'INFO'). Defaults to 'INFO'. + + Returns + ------- + logging.Logger + The configured logger. + """ if logger_name is None: - logger_name = "viperlog" + logger_name = DEFAULT_LOGGER_NAME logger = logging.getLogger(logger_name) - logger.setLevel(logging.getLevelName(log_level)) - + logger.setLevel(getattr(logging, log_level.upper(), logging.INFO)) logger.handlers.clear() if log_to_term: @@ -266,19 +339,38 @@ def setup_logger( logger.addHandler(stream_handler) if log_to_file: - log_file = log_file + datetime.today().strftime("%Y%m%d_%H%M%S") + ".log" - log_handler = logging.FileHandler(log_file) + timestamp = datetime.today().strftime("%Y%m%d_%H%M%S") + full_log_file = f"{log_file}{timestamp}.log" + log_handler = logging.FileHandler(full_log_file) log_handler.setFormatter(LoggingFormatter()) logger.addHandler(log_handler) return logger -def get_worker_logger_name(logger_name: Union[str, None] = None): +def get_worker_logger_name(logger_name: Optional[str] = None) -> str: + """ + Generate a unique logger name for a Dask worker. + + Parameters + ---------- + logger_name : str, optional + The base logger name. + + Returns + ------- + str + The worker-specific logger name. + """ if logger_name is None: - logger_name = "viperlog" + logger_name = DEFAULT_LOGGER_NAME + + try: + worker_id = get_worker().id + return f"{logger_name}_{worker_id}" - return "_".join((logger_name, str(get_worker().id))) + except (ValueError, AttributeError): + return logger_name def setup_worker_logger( @@ -287,12 +379,36 @@ def setup_worker_logger( log_to_file: bool, log_file: str, log_level: str, - worker: dask.distributed.worker.Worker, -): - parallel_logger_name = "_".join((logger_name, str(worker.name))) + worker: "dask.distributed.worker.Worker", +) -> logging.Logger: + """ + Configure and return a logger for a Dask worker. + + Parameters + ---------- + logger_name : str + The base name of the logger. + log_to_term : bool + Whether to log to the terminal. + log_to_file : bool + Whether to log to a file. + log_file : str + The base name of the log file. + log_level : str + The logging level. + worker : dask.distributed.worker.Worker + The Dask worker instance. + + Returns + ------- + logging.Logger + The configured worker logger. + """ + parallel_logger_name = f"{logger_name}_{worker.name}" logger = logging.getLogger(parallel_logger_name) - logger.setLevel(logging.getLevelName(log_level)) + logger.setLevel(getattr(logging, log_level.upper(), logging.INFO)) + logger.handlers.clear() if log_to_term: stream_handler = logging.StreamHandler(sys.stdout) @@ -300,20 +416,9 @@ def setup_worker_logger( logger.addHandler(stream_handler) if log_to_file: - logger.info(f"log_to_file: {log_file}") - dask.distributed.print(f"log_to_file: {log_to_file}") - - log_file = ( - log_file - + "_" - + str(worker.name) - + "_" - + datetime.today().strftime("%Y%m%d_%H%M%S") - + "_" - + str(worker.ip) - + ".log" - ) - log_handler = logging.FileHandler(log_file) + timestamp = datetime.today().strftime("%Y%m%d_%H%M%S") + full_log_file = f"{log_file}_{worker.name}_{timestamp}_{worker.ip}.log" + log_handler = logging.FileHandler(full_log_file) log_handler.setFormatter(LoggingFormatter()) logger.addHandler(log_handler) diff --git a/src/toolviper/utils/parameter.py b/src/toolviper/utils/parameter.py index 3a2899e..a5cc0b4 100644 --- a/src/toolviper/utils/parameter.py +++ b/src/toolviper/utils/parameter.py @@ -1,13 +1,12 @@ import functools import glob -import importlib import inspect import json import os import pathlib import pkgutil from types import ModuleType -from typing import Any, Callable, Dict, List, NoReturn, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Union import toolviper.utils.console as console import toolviper.utils.logger @@ -43,7 +42,7 @@ def wrapper(*args, **kwargs): meta_data["function"] = function.__name__ meta_data["module"] = function.__module__ - # If this is a class method, drop the self entry. + # If this is a class method, drop the self-entry. if "self" in list(arguments.keys()): class_name = args[0].__class__.__name__ meta_data["function"] = ".".join((class_name, function.__name__)) diff --git a/src/toolviper/utils/profile.py b/src/toolviper/utils/profile.py new file mode 100644 index 0000000..57a7713 --- /dev/null +++ b/src/toolviper/utils/profile.py @@ -0,0 +1,102 @@ +import tracemalloc +import uuid +import csv +import functools +import multiprocessing +import time +import psutil + +import toolviper.utils.logger as logger + + +def cpu_usage_(stop_event, filename): + if filename is None: + filename = f"cpu_usage_{uuid.uuid4()}.csv" + + with open(filename, "w") as csvfile: + number_of_cores = psutil.cpu_count(logical=True) + + core_list = [f"c{core}" for core in range(number_of_cores)] + writer = csv.writer(csvfile, delimiter=",", lineterminator="\n") + writer.writerow(core_list) + while not stop_event.is_set(): + usage = psutil.cpu_percent(percpu=True, interval=1) + writer.writerow(usage) + + +def cpu_usage(filename=None): + def function_wrapper(function): + @functools.wraps(function) + def wrapper(*args, **kwargs): + stop_event = multiprocessing.Event() + + monitor_process = multiprocessing.Process( + target=cpu_usage_, args=(stop_event, filename) + ) + monitor_process.start() + + time.sleep(1) + + try: + results = function(*args, **kwargs) + finally: + stop_event.set() + monitor_process.join(timeout=1) + monitor_process.terminate() + + return results + + return wrapper + + return function_wrapper + + +# Not for production. Yet. +def memory_usage(): + def decorator(function): + @functools.wraps(function) + def wrapper(*args, **kwargs): + import csv + + logger.debug(f"start memory profiling on function {function.__name__}") + + tracemalloc.start() + result = function(*args, **kwargs) + snapshot = tracemalloc.take_snapshot() + snapshot = snapshot.filter_traces( + ( + tracemalloc.Filter(False, ""), + tracemalloc.Filter(False, ""), + ) + ) + stats = snapshot.statistics("lineno") + record = [] + for index, stat in enumerate(stats, 1): + frame = stat.traceback[0] + record.append( + { + "index": index, + "filename": frame.filename, + "lineno": frame.lineno, + "memory": int(stat.size / 1024), + } + ) + + tracemalloc.stop() + field_names = ["index", "filename", "lineno", "memory"] + + with open( + f"memory_usage_{function.__name__}_{uuid.uuid4()}.csv", + "w", + newline="", + encoding="utf-8", + ) as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=field_names) + writer.writeheader() + writer.writerows(record) + + return result + + return wrapper + + return decorator diff --git a/src/toolviper/utils/sd/__init__.py b/src/toolviper/utils/sd/__init__.py new file mode 100644 index 0000000..0622cdf --- /dev/null +++ b/src/toolviper/utils/sd/__init__.py @@ -0,0 +1,5 @@ +from .prototype import * +from .graph import * + +__submodules__ = ["prototype"] +__all__ = __submodules__ + [s for s in dir() if not s.startswith("_")] diff --git a/src/toolviper/utils/sd/graph.py b/src/toolviper/utils/sd/graph.py new file mode 100644 index 0000000..5bdddd9 --- /dev/null +++ b/src/toolviper/utils/sd/graph.py @@ -0,0 +1,58 @@ +import dask +import collections +import toolviper + +import toolviper.utils.logger as logger + + +class Graph: + """A class representing a directed graph for dependency management.""" + + def __init__(self): + self._graph = None + self._results = collections.defaultdict(list) + + def source(self, job, axes, connect=False, type="", node=None): + function_name = job["function"].__name__ + previous = None + + logger.info(f"Adding sink node for function: {function_name}") + if connect: + previous = self._graph + logger.info(f"Connecting to previous node: {previous}") + + if node is not None: + try: + logger.info(f"Connecting to user-supplied node: {node}") + previous = self._results[node] + + except KeyError: + logger.error(f"Node {node} not found in results.") + + logger.info(f"Distributing function: {function_name} on axes: {axes}") + if type == "tree": + for _previous in previous: + self._graph = toolviper.utils.sd.distribute( + job=job, axes=axes, function=job["function"], previous=_previous + ) + else: + self._graph = toolviper.utils.sd.distribute( + job=job, axes=axes, function=job["function"], previous=previous + ) + + self._results[function_name].append(self._graph) + + def sink(self, function, edges=None): + logger.info(f"Adding sink node for function: {function.__name__}") + self._results[function.__name__].append(self._graph) + self._graph = dask.delayed(function)(self._graph) + + def visualize(self): + return dask.visualize(self._graph) + + def compute(self): + return dask.compute(self._graph) + + @property + def nodes(self): + return list(self._results.keys()) diff --git a/src/toolviper/utils/sd/prototype.py b/src/toolviper/utils/sd/prototype.py new file mode 100644 index 0000000..22f5479 --- /dev/null +++ b/src/toolviper/utils/sd/prototype.py @@ -0,0 +1,98 @@ +import dask +import typing +import itertools + +import numpy as np +import xarray as xr + + +# Build a simple Dask dataset based on a given set of axes +def simulate(field, spw, polarization, antenna, row): + data_shape = { + "field": [i for i in range(field)], + "spw": [i for i in range(spw)], + "polarization": polarization, + "antenna": [f"antenna_{i}" for i in range(antenna)], + "row": [i for i in range(row)], + } + + dataset = xr.Dataset( + coords=data_shape, + data_vars=dict( + DATA=( + list(data_shape.keys()), + np.zeros((field, spw, len(polarization), antenna, row)), + ) + ), + ) + + return dataset + + +def distribute( + job: typing.Dict, axes: typing.List[str], function: typing.Callable, previous=None +) -> typing.List[dask.delayed]: + """ + Distribute a function across a dataset along specified axes. + + This function creates a list of delayed dask tasks, where each task + represents a call to the specified function with the dataset or previous + result, and the values of the distribution axes. + + Parameters + ---------- + dataset : xr.Dataset + The input dataset to be distributed. + axes : typing.List[str] + The axes to distribute along. + function : typing.Callable + The function to be applied in a delayed manner. + previous : typing.Any, optional + A previous result to be passed to the function. Defaults to None. + + Returns + ------- + typing.List[dask.delayed] + A list of dask delayed objects. + """ + # Get the coordinate values for each axis + axis_values = [job["dataset"].coords[axis].values for axis in axes] + # arguments = {axis: job["dataset"].coords[axis].values for axis in axes} + # inputs = [dict(zip(axes, combo)) for combo in itertools.product(*arguments.values())] + + if isinstance(previous, list): + axis_values.append(previous) + # inputs.append(previous) + + # Create a delayed version of the function + delayed_func = dask.delayed(function) + + # Use itertools.product to generate all combinations of axis values + # and create a delayed task for each combination. + # The axis values are passed as positional arguments after 'previous'. + + return [ + delayed_func(*values) if previous is not None else delayed_func(*values) + for values in itertools.product(*axis_values) # previously axis_values + ] + # output = [] + # for values in inputs: + # print(values) + # if isinstance(values, dict) and previous is not None: + # print("===== dict") + # output.append(delayed_func(**values)) + + # elif isinstance(values, list): + # print("===== list") + # output.append(delayed_func(*values)) + + # else: + # print("===== idk") + # output.append(delayed_func(values)) + + # return output + + # return [ + # delayed_func(**values) if previous is not None else delayed_func(**values) + # for values in inputs # previously axis_values + # ] diff --git a/src/toolviper/utils/tools.py b/src/toolviper/utils/tools.py index 37462ab..5dd6dca 100644 --- a/src/toolviper/utils/tools.py +++ b/src/toolviper/utils/tools.py @@ -147,7 +147,7 @@ def add_entry( ---------- entries : dict, list - Dictionary or list of metadata info that are needed to build the new entry. + Dictionary, or list of metadata info that are needed to build the new entry. manifest : str Points to the manifest you want to modify. @@ -186,9 +186,9 @@ def add_entry( for entry in entries: process_entry_(**entry, json_file=json_file) - except KeyError as key_error: - logger.error(f"entry not found in metadata ... skipping: {key_error}") - return None + # except KeyError as key_error: + # logger.error(f"entry not found in metadata ... skipping: {key_error}") + # return None except TypeError: logger.error( diff --git a/tests/test_client.py b/tests/test_client.py index 2f5fd88..79a7555 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2,6 +2,7 @@ import re import pathlib import distributed +from unittest.mock import patch, MagicMock from toolviper.dask.client import local_client @@ -172,3 +173,29 @@ def test__set_up_dask(self): _set_up_dask(local_directory=pathlib.Path(".").cwd()) assert dask.config.config["distributed"]["scheduler"]["allowed-failures"] == 10 + + +def test_print_libraries_availability(): + from toolviper.dask.client import print_libraries_availability + import toolviper.utils.logger as logger + + with patch.object(logger, "debug") as mock_debug: + print_libraries_availability({"CUDA": True, "MPI": False}) + mock_debug.assert_called_once() + args, kwargs = mock_debug.call_args + assert "CUDA" in args[0] + assert "MPI" not in args[0] + + +def test_get_client_none(): + from toolviper.dask.client import get_client + + with patch("distributed.Client.current", side_effect=ValueError): + assert get_client() is None + + +def test_get_cluster_none(): + from toolviper.dask.client import get_cluster + + with patch("toolviper.dask.client.get_client", return_value=None): + assert get_cluster() is None diff --git a/tests/test_cloudflare.py b/tests/test_cloudflare.py new file mode 100644 index 0000000..8181450 --- /dev/null +++ b/tests/test_cloudflare.py @@ -0,0 +1,109 @@ +import os +import pathlib +import json +import pytest +import responses +from toolviper.utils.data import cloudflare +import pandas as pd + + +@pytest.fixture +def mock_metadata(tmp_path): + metadata = { + "version": "1.0.0", + "metadata": { + "test_file.zip": { + "file": "test_file.zip", + "path": "test", + "dtype": "ZIP", + "telescope": "ALMA", + "size": "100", + "mode": "test", + } + }, + } + meta_dir = tmp_path / ".cloudflare" + meta_dir.mkdir() + meta_file = meta_dir / "file.download.json" + with open(meta_file, "w") as f: + json.dump(metadata, f) + return meta_file + + +def test_version(mock_metadata, monkeypatch, caplog): + # Mock __file__ in cloudflare to point to our temp directory + monkeypatch.setattr( + cloudflare, "__file__", str(mock_metadata.parent.parent / "cloudflare.py") + ) + + with caplog.at_level("INFO"): + cloudflare.version() + assert "1.0.0" in caplog.text + + +@responses.activate +def test_download(mock_metadata, monkeypatch, tmp_path): + monkeypatch.setattr( + cloudflare, "__file__", str(mock_metadata.parent.parent / "cloudflare.py") + ) + + url = "https://downloadnrao.org/test/test_file.zip" + responses.add( + responses.GET, + url, + body=b"test data", + status=200, + headers={"Content-Length": "9"}, + ) + + dest_folder = tmp_path / "dest" + cloudflare.download("test_file.zip", folder=str(dest_folder), decompress=False) + + assert (dest_folder / "test_file.zip").exists() + with open(dest_folder / "test_file.zip", "rb") as f: + assert f.read() == b"test data" + + +def test_get_files(mock_metadata, monkeypatch): + monkeypatch.setattr( + cloudflare, "__file__", str(mock_metadata.parent.parent / "cloudflare.py") + ) + files = cloudflare.get_files() + assert "test_file.zip" in files + + +def test_get_file_size(tmp_path): + test_file = tmp_path / "test.txt" + test_file.write_text("hello") + + sizes = cloudflare.get_file_size(str(tmp_path)) + assert sizes["test.txt"] == 5 + + +@responses.activate +def test_update(mock_metadata, monkeypatch, tmp_path): + monkeypatch.setattr( + cloudflare, "__file__", str(mock_metadata.parent.parent / "cloudflare.py") + ) + + url = "https://downloadnrao.org/file.download.json" + new_metadata = {"version": "1.1.0", "metadata": {}} + responses.add(responses.GET, url, json=new_metadata, status=200) + + update_path = tmp_path / "update_dir" + update_path.mkdir() + cloudflare.update(path=str(update_path)) + + assert (update_path / "file.download.json").exists() + + +def test_list_files(mock_metadata, monkeypatch): + monkeypatch.setattr( + cloudflare, "__file__", str(mock_metadata.parent.parent / "cloudflare.py") + ) + # list_files returns pd.DataFrame or None (if it prints) + # By default it might try to use itables or tabulate + df = cloudflare.list_files() + if df is not None: + assert isinstance(df, pd.DataFrame) + assert "test_file.zip" in df["file"].values diff --git a/tests/test_console.py b/tests/test_console.py new file mode 100644 index 0000000..51eb0aa --- /dev/null +++ b/tests/test_console.py @@ -0,0 +1,59 @@ +import pytest +import inspect +from toolviper.utils.console import ColorCodes, Colorize, add_verbose_info + + +def test_color_codes(): + codes = ColorCodes() + assert codes.reset == "\033[0m" + assert codes.red == "\033[38;2;220;20;60m" + + +def test_colorize_basic(): + c = Colorize() + text = "hello" + assert c.bold(text) == f"\033[1m{text}\033[0m" + assert c.red(text) == f"\033[38;2;220;20;60m{text}\033[0m" + assert c.blue(text) == f"\033[38;2;50;50;205m{text}\033[0m" + + +def test_colorize_format_list(): + c = Colorize() + # Testing format with RGB list + formatted = c.format("test", color=[255, 0, 0]) + assert "test" in formatted + assert "38;2;255;0;0" in formatted + + +def test_colorize_format_string(): + c = Colorize() + formatted = c.format("test", color="green") + assert "test" in formatted + assert "38;2;46;139;87" in formatted + + +def test_get_color_function(): + c = Colorize() + fn = c.get_color_function("red") + assert fn == c.red + + # Default to black if not found + fn_unknown = c.get_color_function("nonexistent") + assert fn_unknown == c.black + + +def test_add_verbose_info(): + def dummy_caller(): + # result = add_verbose_info("my message") + # In this context, dummy_caller is the direct caller, so PENULTIMATE_FUNCTION (2) + # might refer to the caller of dummy_caller if add_verbose_info is called from it. + # Actually, add_verbose_info uses PENULTIMATE_FUNCTION = 2. + # stack[0] = add_verbose_info + # stack[1] = dummy_caller + # stack[2] = test_add_verbose_info + return add_verbose_info("my message") + + result = dummy_caller() + # It seems in this pytest execution, it gets 'test_add_verbose_info' as PENULTIMATE_FUNCTION + assert "my message" in result + assert "\033[" in result diff --git a/tests/test_dask_plugins.py b/tests/test_dask_plugins.py new file mode 100644 index 0000000..0410743 --- /dev/null +++ b/tests/test_dask_plugins.py @@ -0,0 +1,81 @@ +import pytest +from unittest.mock import MagicMock, patch +from toolviper.dask.plugins.worker import DaskWorker +from toolviper.dask.plugins.scheduler import Scheduler, unravel_deps, get_node_depths + + +def test_dask_worker_init(): + log_params = { + "log_level": "DEBUG", + "log_to_term": False, + "log_to_file": True, + "log_file": "test.log", + } + plugin = DaskWorker(local_cache=True, log_params=log_params) + assert plugin.local_cache is True + assert plugin.log_level == "DEBUG" + assert plugin.log_to_term is False + assert plugin.log_to_file is True + assert plugin.log_file == "test.log" + + +def test_dask_worker_setup(): + plugin = DaskWorker(log_params={"log_level": "INFO"}) + mock_worker = MagicMock() + mock_worker.id = "worker-1" + mock_worker.address = "tcp://127.0.0.1:1234" + mock_worker.state.available_resources = {} + + with patch("toolviper.utils.logger.setup_worker_logger") as mock_setup_logger: + mock_logger = MagicMock() + mock_setup_logger.return_value = mock_logger + + plugin.setup(mock_worker) + + mock_setup_logger.assert_called_once() + assert plugin.worker == mock_worker + # Check if resource for IP was added + assert "127.0.0.1" in mock_worker.state.available_resources + + +def test_scheduler_init(): + scheduler = Scheduler(autorestrictor=True, local_cache=False) + assert scheduler.autorestrictor is True + assert scheduler.local_cache is False + + +def test_unravel_deps(): + hlg_deps = { + "task1": {"task2", "task3"}, + "task2": {"task4"}, + "task3": set(), + "task4": set(), + } + unravelled = unravel_deps(hlg_deps, "task1") + assert unravelled == {"task2", "task3", "task4"} + + +def test_get_node_depths(): + dependencies = {"A": set(), "B": {"A"}, "C": {"B"}, "D": {"A"}} + root_nodes = {"A"} + # metrics[node][-1] is the "depth" of the node from terminal nodes (as calculated by graph_metrics) + # get_node_depths calculates: max(metrics[r][-1] - metrics[k][-1] for r in roots) + # For a simple chain A -> B -> C: + # C is terminal, depth 0 in metrics. + # B depends on A, so B's depth in metrics is 1. + # A is root, depth 2 in metrics. + metrics = { + "A": [0, 2], # depth 2 + "B": [0, 1], # depth 1 + "C": [0, 0], # depth 0 + "D": [0, 1], # depth 1 + } + + node_depths = get_node_depths(dependencies, root_nodes, metrics) + assert node_depths["A"] == 0 + # For B: roots is {'A'}. node_depths['B'] = max(metrics['A'][1] - metrics['B'][1]) = 2 - 1 = 1 + assert node_depths["B"] == 1 + # For C: roots is {'A'}. node_depths['C'] = max(metrics['A'][1] - metrics['C'][1]) = 2 - 0 = 2 + assert node_depths["C"] == 2 + # For D: roots is {'A'}. node_depths['D'] = max(metrics['A'][1] - metrics['D'][1]) = 2 - 1 = 1 + assert node_depths["D"] == 1 diff --git a/tests/test_download.py b/tests/test_download.py index e4815b0..438ac3c 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -148,13 +148,5 @@ def test_private_print_file_queue(self): logger.info("Failure test passed!") return None - # If error isn't as expected, fail the test. + # If the error isn't as expected, fail the test. raise AssertionError() - - def test_private_make_dir(self): - from toolviper.utils.data.cloudflare import _make_dir - - _make_dir(path=str(pathlib.Path.cwd()), folder="data") - - if not pathlib.Path.cwd().joinpath("data").exists(): - raise FileNotFoundError("data") diff --git a/tests/test_logger.py b/tests/test_logger.py new file mode 100644 index 0000000..46567d9 --- /dev/null +++ b/tests/test_logger.py @@ -0,0 +1,205 @@ +import pytest +import logging +from unittest.mock import MagicMock, patch +from toolviper.utils.logger import ( + set_verbosity, + verbosity, + info, + debug, + warning, + error, + critical, + exception, + log, + get_logger, + setup_logger, + ColorLoggingFormatter, + LoggingFormatter, + get_worker_logger_name, + setup_worker_logger, +) + + +@pytest.fixture(autouse=True) +def reset_verbosity(): + token = verbosity.set(None) + yield + verbosity.reset(token) + + +@pytest.fixture +def mock_logger(): + with patch("toolviper.utils.logger.get_logger") as mock: + logger = MagicMock(spec=logging.Logger) + mock.return_value = logger + yield logger + + +def test_set_verbosity(): + set_verbosity(True) + assert verbosity.get() is True + set_verbosity(False) + assert verbosity.get() is False + set_verbosity(None) + assert verbosity.get() is None + + +def test_info_logging(mock_logger): + info("test message") + mock_logger.info.assert_called_with("test message") + + +def test_info_verbose_logging(mock_logger): + with patch( + "toolviper.utils.logger.add_verbose_info", return_value="verbose test" + ) as mock_add: + info("test message", verbose=True) + mock_add.assert_called() + mock_logger.info.assert_called_with("verbose test") + + +def test_info_verbosity_context(mock_logger): + set_verbosity(True) + with patch("toolviper.utils.logger.add_verbose_info", return_value="verbose test"): + info("test message") + mock_logger.info.assert_called_with("verbose test") + + +def test_debug_logging(mock_logger): + debug("debug message") + mock_logger.debug.assert_called_with("debug message") + + +def test_warning_logging(mock_logger): + warning("warning message") + mock_logger.warning.assert_called_with("warning message") + + +def test_error_logging(mock_logger): + with patch( + "toolviper.utils.logger.add_verbose_info", + side_effect=lambda message, color: message, + ): + error("error message") + mock_logger.error.assert_called_with("error message") + + +def test_critical_logging(mock_logger): + with patch( + "toolviper.utils.logger.add_verbose_info", + side_effect=lambda message, color: message, + ): + critical("critical message") + mock_logger.critical.assert_called_with("critical message") + + +def test_exception_logging(mock_logger): + exception("exception message") + mock_logger.exception.assert_called_with("exception message") + + +def test_log_logging(mock_logger): + mock_logger.level = logging.INFO + log("log message") + mock_logger.log.assert_called_with(logging.INFO, "log message") + + +def test_get_logger_no_env_no_worker(monkeypatch): + monkeypatch.delenv("VIPER_LOGGER_NAME", raising=False) + with patch("toolviper.utils.logger.get_worker", side_effect=ValueError): + logger = get_logger() + assert logger.name == "viperlog" + # Since it's a new logger, it should have a StreamHandler + assert any(isinstance(h, logging.StreamHandler) for h in logger.handlers) + + +def test_get_logger_existing_logger(monkeypatch): + monkeypatch.delenv("VIPER_LOGGER_NAME", raising=False) + # Pre-create logger + existing_logger = logging.getLogger("existing_log") + with patch("toolviper.utils.logger.get_worker", side_effect=ValueError): + logger = get_logger("existing_log") + assert logger == existing_logger + + +def test_get_logger_env(): + with ( + patch("os.environ", {"VIPER_LOGGER_NAME": "env_logger"}), + patch("toolviper.utils.logger.get_worker", side_effect=ValueError), + ): + logger = get_logger() + assert logger.name == "env_logger" + + +def test_get_logger_worker(): + mock_worker = MagicMock() + mock_logger_obj = MagicMock() + mock_worker.plugins = {"worker_logger": MagicMock()} + mock_worker.plugins["worker_logger"].get_logger.return_value = mock_logger_obj + + with patch("toolviper.utils.logger.get_worker", return_value=mock_worker): + logger = get_logger("test_logger") + assert logger == mock_logger_obj + + +def test_setup_logger_basic(tmp_path): + log_file_base = str(tmp_path / "test_log") + logger = setup_logger( + logger_name="setup_test", + log_to_term=True, + log_to_file=True, + log_file=log_file_base, + ) + assert logger.name == "setup_test" + assert len(logger.handlers) == 2 + # Cleanup + for handler in logger.handlers: + handler.close() + + +def test_color_logging_formatter(): + formatter = ColorLoggingFormatter() + record = logging.LogRecord("name", logging.INFO, "path", 10, "msg", None, None) + formatted = formatter.format(record) + assert "INFO" in formatted + assert "msg" in formatted + + +def test_logging_formatter(): + formatter = LoggingFormatter() + record = logging.LogRecord( + "name", logging.ERROR, "path", 20, "error msg", None, None + ) + formatted = formatter.format(record) + assert "ERROR" in formatted + assert "error msg" in formatted + + +def test_get_worker_logger_name(): + mock_worker = MagicMock() + mock_worker.id = "worker-123" + with patch("toolviper.utils.logger.get_worker", return_value=mock_worker): + name = get_worker_logger_name("mylog") + assert name == "mylog_worker-123" + + +def test_setup_worker_logger(tmp_path): + mock_worker = MagicMock() + mock_worker.name = "worker-1" + mock_worker.ip = "127.0.0.1" + log_file_base = str(tmp_path / "worker_log") + + with patch("dask.distributed.print"): + logger = setup_worker_logger( + logger_name="worker_test", + log_to_term=True, + log_to_file=True, + log_file=log_file_base, + log_level="DEBUG", + worker=mock_worker, + ) + assert "worker_test_worker-1" == logger.name + assert logger.level == logging.DEBUG + # Cleanup + for handler in logger.handlers: + handler.close() diff --git a/tests/test_menrva.py b/tests/test_menrva.py index a6f89dd..edd23db 100644 --- a/tests/test_menrva.py +++ b/tests/test_menrva.py @@ -2,6 +2,7 @@ import re import pathlib import distributed +from unittest.mock import patch, MagicMock from toolviper.dask import menrva from toolviper.dask.client import local_client @@ -78,3 +79,62 @@ def test_thread_info(self): ) client.shutdown() + + +def test_port_is_free(): + from toolviper.dask.menrva import port_is_free + import socket + + # Test with a definitely free port (hopefully) + # We can use port 0 to let the OS pick a free port, but port_is_free binds it and closes it. + # Let's try to bind a port ourselves and then check if it's free. + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("127.0.0.1", 0)) + port = s.getsockname()[1] + + # Since 's' is holding the port, port_is_free should return False + assert port_is_free(port) is False + + s.close() + # Now it should be free + assert port_is_free(port) is True + + +def test_close_port(): + from toolviper.dask.menrva import close_port, port_is_free + import socket + import time + + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("127.0.0.1", 0)) + port = s.getsockname()[1] + s.listen(1) + + assert port_is_free(port) is False + + # close_port tries to kill the process holding the port. + # Since it's our own process, this might be dangerous if not careful, + # but here it's just a socket in the same process. + # Actually, close_port uses psutil to find processes with that port and SIGKILLs them. + # We should probably mock psutil for this test to avoid killing ourselves. + + with patch("psutil.process_iter") as mock_iter: + mock_proc = MagicMock() + mock_conn = MagicMock() + mock_conn.laddr.port = port + mock_proc.connections.return_value = [mock_conn] + mock_iter.return_value = [mock_proc] + + close_port(port) + + mock_proc.send_signal.assert_called_once() + + +def test_menrva_client_call(): + from toolviper.dask.menrva import MenrvaClient + + def my_func(a, b=1): + return a + b + + assert MenrvaClient.call(my_func, 2, b=3) == 5 + assert MenrvaClient.call(my_func, 2) == 3 diff --git a/tests/test_parameter.py b/tests/test_parameter.py new file mode 100644 index 0000000..52a43a9 --- /dev/null +++ b/tests/test_parameter.py @@ -0,0 +1,141 @@ +import pytest +import os +import json +from unittest.mock import patch, MagicMock +from toolviper.utils.parameter import ( + validate, + get_path, + set_config_directory, + is_notebook, + verify, +) + + +def test_is_notebook(): + # Should be False in normal python environment + assert is_notebook() is False + + +def test_get_path_standard(monkeypatch): + def dummy_func(): + pass + + # Mock inspect.getmodule and inspect.getfile + mock_module = MagicMock() + mock_module.__name__ = "toolviper.utils.dummy" + + with ( + patch("inspect.getmodule", return_value=mock_module), + patch("inspect.getfile", return_value="/abs/path/src/toolviper/utils/dummy.py"), + ): + base, mod = get_path(dummy_func) + assert "src/toolviper" in base + assert mod == "/abs/path/src/toolviper/utils/dummy" + + +def test_set_config_directory(tmp_path): + config_dir = tmp_path / "my_config" + config_dir.mkdir() + + with patch("toolviper.utils.logger.info"): + set_config_directory(str(config_dir)) + assert os.environ["PARAMETER_CONFIG_PATH"] == str(config_dir) + + +def test_validate_decorator_success(tmp_path, monkeypatch): + # Setup a mock config file + config_dir = tmp_path / "config" + config_dir.mkdir() + param_file = config_dir / "test_mod.param.json" + schema = { + "my_func": {"arg1": {"type": "int", "required": True}, "arg2": {"type": "str"}} + } + with open(param_file, "w") as f: + json.dump(schema, f) + + def my_func(arg1, arg2="default"): + return f"{arg1}-{arg2}" + + # Manually wrap with validate and trick it + my_func.__module__ = "toolviper.utils.test_mod" + my_func.__name__ = "my_func" + + wrapped = validate(config_dir=str(config_dir))(my_func) + + # We also need to mock get_path to avoid it searching in /tmp or something + with patch( + "toolviper.utils.parameter.get_path", + return_value=(str(tmp_path), str(tmp_path / "test_mod")), + ): + assert wrapped(10, arg2="hello") == "10-hello" + + +def test_validate_decorator_failure(tmp_path): + config_dir = tmp_path / "config" + config_dir.mkdir() + param_file = config_dir / "test_mod.param.json" + schema = {"fail_func": {"arg1": {"type": "int"}}} + with open(param_file, "w") as f: + json.dump(schema, f) + + def fail_func(arg1): + return arg1 + + fail_func.__module__ = "toolviper.utils.test_mod" + fail_func.__name__ = "fail_func" + + wrapped = validate(config_dir=str(config_dir))(fail_func) + + with patch( + "toolviper.utils.parameter.get_path", + return_value=(str(tmp_path), str(tmp_path / "test_mod")), + ): + # Should raise AssertionError from verify's assert validator.validate(args) + with pytest.raises(AssertionError): + wrapped("not an int") + + +def test_verify_missing_config(): + def no_config_func(): + pass + + no_config_func.__module__ = "ghost_module" + + with ( + patch( + "toolviper.utils.parameter.get_path", + return_value=("/tmp", "/tmp/ghost_module"), + ), + patch("toolviper.utils.logger.error"), + ): + with pytest.raises(FileNotFoundError): + verify( + no_config_func, + {}, + {"function": "no_config_func", "module": "ghost_module"}, + ) + + +def test_verify_function_not_in_schema(tmp_path): + config_dir = tmp_path / "config" + config_dir.mkdir() + param_file = config_dir / "known_mod.param.json" + with open(param_file, "w") as f: + json.dump({"other_func": {}}, f) + + def unknown_func(): + pass + + unknown_func.__module__ = "known_mod" + + with patch( + "toolviper.utils.parameter.get_path", + return_value=(str(tmp_path), str(tmp_path / "known_mod")), + ): + with pytest.raises(KeyError): + verify( + unknown_func, + {}, + {"function": "unknown_func", "module": "known_mod"}, + config_dir=str(config_dir), + ) diff --git a/tests/test_tools.py b/tests/test_tools.py index c3bd8f1..9b047cb 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,85 +1,201 @@ +import pytest +import json import pathlib -import toolviper +import hashlib +import os +from unittest.mock import MagicMock, patch +from toolviper.utils.tools import ( + open_json, + calculate_checksum, + iter_files_, + update_hash, + verify, + process_entry_, + add_entry, + update_version, + ChecksumError, +) -import toolviper.utils.logger as logger +@pytest.fixture +def temp_json_file(tmp_path): + data = {"version": "v1.0.0", "metadata": {}} + file_path = tmp_path / "test.json" + with open(file_path, "w") as f: + json.dump(data, f) + return file_path -class TestToolViperTools: - @classmethod - def setup_class(cls): - """setup any state specific to the execution of the given test class - such as fetching test data""" - pass - @classmethod - def teardown_class(cls): - """teardown any state that was previously setup with a call to setup_class - such as deleting test data""" - # cls.client.shutdown() - pass +@pytest.fixture +def temp_data_file(tmp_path): + file_path = tmp_path / "test_file.txt" + content = b"hello world" + file_path.write_bytes(content) + # sha256 of "hello world" + expected_hash = hashlib.sha256(content).hexdigest() + return file_path, expected_hash - def setup_method(self): - """setup any state specific to all methods of the given class""" - pass - def teardown_method(self): - """teardown any state that was previously setup for all methods of the given class""" - pass +def test_open_json_success(temp_json_file): + data = open_json(str(temp_json_file)) + assert data["version"] == "v1.0.0" - def test_open_json(self): - from toolviper.utils.tools import open_json - try: - open_json("tests/data/test.json") +def test_open_json_not_found(): + with pytest.raises(FileNotFoundError): + open_json("non_existent_file.json") - except FileNotFoundError: - logger.info(f"Function open_json(...) working as expected.") - return None - raise AssertionError +def test_calculate_checksum(temp_data_file): + file_path, expected_hash = temp_data_file + assert calculate_checksum(str(file_path)) == expected_hash - def test_private_iter_files(self): - from toolviper.utils.tools import iter_files_ - try: - for _ in iter_files_("tests/data/"): - pass +def test_iter_files_(tmp_path): + (tmp_path / "file1.txt").write_text("1") + (tmp_path / "file2.txt").write_text("2") + files = list(iter_files_(str(tmp_path))) + assert set(files) == {"file1.txt", "file2.txt"} - except FileNotFoundError: - logger.info(f"Function iter_files_(...) working as expected.") - return None - raise AssertionError +def test_iter_files_not_found(): + with pytest.raises(FileNotFoundError): + list(iter_files_("non_existent_path")) - def test_verify(self): - from toolviper.utils.tools import verify - try: - # Test files that doesn't exist - verify(filename="test.json", folder="tests/data") +def test_update_hash(tmp_path): + # Setup manifest + manifest_path = tmp_path / "manifest.json" + manifest_data = {"metadata": {"test_file": {"hash": "old_hash"}}} + with open(manifest_path, "w") as f: + json.dump(manifest_data, f) - except FileNotFoundError: - logger.info(f"Function verify(...) working as expected.") - return None + # Setup data file + data_file = tmp_path / "test_file" + data_file.write_text("new content") + new_hash = hashlib.sha256(b"new content").hexdigest() - raise AssertionError + update_hash(str(manifest_path), str(tmp_path)) - def test_calculate_checksum(self): - from toolviper.utils.tools import calculate_checksum + updated_manifest = open_json(str(manifest_path)) + assert updated_manifest["metadata"]["test_file"]["hash"] == new_hash - base_address = pathlib.Path(toolviper.__file__).parent - metadata_address = base_address.joinpath( - "utils/data/.cloudflare/file.download.json" - ) - metadata = toolviper.utils.tools.open_json(str(metadata_address)) +def test_verify_success(tmp_path, monkeypatch): + # Setup manifest in a place where verify can find it (mocking toolviper.__file__) + manifest_dir = tmp_path / "utils/data/.cloudflare" + manifest_dir.mkdir(parents=True) + manifest_path = manifest_dir / "file.download.json" + + data_file = tmp_path / "test.zip" + data_file.write_text("zip content") + expected_hash = hashlib.sha256(b"zip content").hexdigest() + + manifest_data = {"metadata": {"test": {"hash": expected_hash}}} + with open(manifest_path, "w") as f: + json.dump(manifest_data, f) + + import toolviper + + monkeypatch.setattr(toolviper, "__file__", str(tmp_path / "__init__.py")) + (tmp_path / "__init__.py").touch() + + # verify(filename, folder) + # verify handles .zip extension by stripping it + verify("test.zip", str(tmp_path)) + + +def test_verify_checksum_error(tmp_path, monkeypatch): + manifest_dir = tmp_path / "utils/data/.cloudflare" + manifest_dir.mkdir(parents=True) + manifest_path = manifest_dir / "file.download.json" + + data_file = tmp_path / "test.zip" + data_file.write_text("wrong content") + + manifest_data = {"metadata": {"test": {"hash": "expected_but_different_hash"}}} + with open(manifest_path, "w") as f: + json.dump(manifest_data, f) + + import toolviper + + monkeypatch.setattr(toolviper, "__file__", str(tmp_path / "__init__.py")) + (tmp_path / "__init__.py").touch() + + with pytest.raises(ChecksumError): + verify("test.zip", str(tmp_path)) + - path = pathlib.Path.cwd().joinpath("data") - path.mkdir(parents=True, exist_ok=True) +def test_update_version(): + with ( + patch("toolviper.utils.tools.open_json") as mock_open, + patch("pathlib.Path.exists", return_value=True), + ): - toolviper.utils.data.download(file="checksum.hash", folder=str(path)) + mock_open.return_value = {"version": "v1.2.3"} - assert ( - toolviper.utils.tools.calculate_checksum(file="data/checksum.hash") - == metadata["metadata"]["checksum.hash"]["hash"] + # current implementation doesn't reset other parts + assert update_version("major") == "v2.2.3" + assert update_version("minor") == "v1.3.3" + assert update_version("patch") == "v1.2.4" + assert update_version("unknown") is None + + +def test_process_entry_(tmp_path): + json_file = {"metadata": {}} + test_file = tmp_path / "test.zip" + test_file.write_text("content") + file_hash = hashlib.sha256(b"content").hexdigest() + + with patch("toolviper.utils.data.get_file_size", return_value={"test": 123}): + process_entry_( + file=str(test_file), + path="verification", + dtype="int", + telescope="VLA", + mode="test", + json_file=json_file, ) + + assert "test" in json_file["metadata"] + assert json_file["metadata"]["test"]["hash"] == file_hash + assert json_file["metadata"]["test"]["size"] == "123" + + +def test_add_entry(tmp_path, monkeypatch): + manifest_dir = tmp_path / "utils/data/.cloudflare" + manifest_dir.mkdir(parents=True) + manifest_path = manifest_dir / "file.download.json" + manifest_data = {"version": "v1.0.0", "metadata": {}} + with open(manifest_path, "w") as f: + json.dump(manifest_data, f) + + import toolviper + + monkeypatch.setattr(toolviper, "__path__", [str(tmp_path)]) + + test_file = tmp_path / "new_file.zip" + test_file.write_text("content") + + entry = { + "file": str(test_file), + "path": "verification", + "dtype": "int", + "telescope": "VLA", + "mode": "test", + } + + with patch("toolviper.utils.data.get_file_size", return_value={"new_file": 456}): + result = add_entry(entries=[entry], manifest=str(manifest_path)) + + assert result["version"] == "v1.0.1" + assert "new_file" in result["metadata"] + assert os.path.exists("file.download.json") + os.remove("file.download.json") + + +def test_checksum_error_str(): + error = ChecksumError("msg", "file.txt", "/folder", 10) + assert "[10]: There was an error verifying the checksum of /folder/file.txt" in str( + error + )