diff --git a/docs/modules.rst b/docs/modules.rst index 467d763..6aecad8 100644 --- a/docs/modules.rst +++ b/docs/modules.rst @@ -4,6 +4,7 @@ Rush Modules :maxdepth: 1 rush.exess + rush.admet_ai_rex rush.nnxtb rush.prepare_protein rush.prepare_complex diff --git a/docs/rush.admet_ai_rex.rst b/docs/rush.admet_ai_rex.rst new file mode 100644 index 0000000..1581f96 --- /dev/null +++ b/docs/rush.admet_ai_rex.rst @@ -0,0 +1,7 @@ +ADMET AI Rex +============ + +.. automodule:: rush.admet_ai_rex + :members: + :undoc-members: + :show-inheritance: diff --git a/src/rush/admet_ai_rex.py b/src/rush/admet_ai_rex.py new file mode 100644 index 0000000..8935ebf --- /dev/null +++ b/src/rush/admet_ai_rex.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 +""" +Raw-config Rex wrappers for the ADMET AI Tengu module repo. +""" + +import sys +from pathlib import Path +from string import Template + +from gql.transport.exceptions import TransportQueryError + +from .client import ( + RunOpts, + RunSpec, + _get_project_id, + _submit_rex, + collect_run, + upload_object, +) + + +def _upload_json(input_json: Path | str) -> str: + if isinstance(input_json, str): + input_json = Path(input_json) + obj = upload_object(input_json) + return obj["path"] + + +def talo_admet_ai_rex( + input_json: Path | str, + config_rex: str, + run_spec: RunSpec = RunSpec(), + run_opts: RunOpts = RunOpts(), + collect=False, +): + """ + Run talo_admet_ai_rex with a raw config Rex expression and JSON input. + """ + input_path = _upload_json(input_json) + rex = Template("""let + obj_j = λ j → + VirtualObject { path = j, format = ObjectFormat::json, size = 0 }, + cfg = $config_rex, + input = obj_j "$input_path", + result = talo_admet_ai_rex_s + ($run_spec) + cfg + input +in + result +""").substitute( + run_spec=run_spec._to_rex(), + config_rex=config_rex, + input_path=input_path, + ) + try: + run_id = _submit_rex(_get_project_id(), rex, run_opts) + if collect: + return collect_run(run_id) + return run_id + except TransportQueryError as e: + if e.errors: + for error in e.errors: + print(f"Error: {error['message']}", file=sys.stderr) + + +def talo_admet_ai_plot_drugbank_rex( + input_json: Path | str, + config_rex: str, + run_spec: RunSpec = RunSpec(), + run_opts: RunOpts = RunOpts(), + collect=False, +): + """ + Run talo_admet_ai_plot_drugbank_rex with a raw config Rex expression and JSON input. + """ + input_path = _upload_json(input_json) + rex = Template("""let + obj_j = λ j → + VirtualObject { path = j, format = ObjectFormat::json, size = 0 }, + cfg = $config_rex, + preds = obj_j "$input_path", + result = talo_admet_ai_plot_drugbank_rex_s + ($run_spec) + cfg + preds +in + result +""").substitute( + run_spec=run_spec._to_rex(), + config_rex=config_rex, + input_path=input_path, + ) + try: + run_id = _submit_rex(_get_project_id(), rex, run_opts) + if collect: + return collect_run(run_id) + return run_id + except TransportQueryError as e: + if e.errors: + for error in e.errors: + print(f"Error: {error['message']}", file=sys.stderr) + + +def talo_admet_ai_plot_radial_rex( + input_json: Path | str, + config_rex: str, + run_spec: RunSpec = RunSpec(), + run_opts: RunOpts = RunOpts(), + collect=False, +): + """ + Run talo_admet_ai_plot_radial_rex with a raw config Rex expression and JSON input. + """ + input_path = _upload_json(input_json) + rex = Template("""let + obj_j = λ j → + VirtualObject { path = j, format = ObjectFormat::json, size = 0 }, + cfg = $config_rex, + preds = obj_j "$input_path", + result = talo_admet_ai_plot_radial_rex_s + ($run_spec) + cfg + preds +in + result +""").substitute( + run_spec=run_spec._to_rex(), + config_rex=config_rex, + input_path=input_path, + ) + try: + run_id = _submit_rex(_get_project_id(), rex, run_opts) + if collect: + return collect_run(run_id) + return run_id + except TransportQueryError as e: + if e.errors: + for error in e.errors: + print(f"Error: {error['message']}", file=sys.stderr) + + +def talo_admet_ai_web_rex( + input_json: Path | str | None, + config_rex: str, + run_spec: RunSpec = RunSpec(), + run_opts: RunOpts = RunOpts(), + collect=False, +): + """ + Run talo_admet_ai_web_rex with a raw config Rex expression. + + The web entrypoint does not accept an input object; input_json is ignored. + """ + rex = Template("""let + cfg = $config_rex, + result = talo_admet_ai_web_rex_s + ($run_spec) + cfg +in + result +""").substitute( + run_spec=run_spec._to_rex(), + config_rex=config_rex, + ) + try: + run_id = _submit_rex(_get_project_id(), rex, run_opts) + if collect: + return collect_run(run_id) + return run_id + except TransportQueryError as e: + if e.errors: + for error in e.errors: + print(f"Error: {error['message']}", file=sys.stderr) diff --git a/src/rush/client.py b/src/rush/client.py index 27c67bf..a707557 100644 --- a/src/rush/client.py +++ b/src/rush/client.py @@ -90,6 +90,10 @@ def _get_project_id() -> str: # staging "auto3d_rex": "github:talo/tengu-auto3d/ce81cfb6f4f2628cee07400992650c15ccec790e#auto3d_rex", "boltz2_rex": "github:talo/tengu-boltz2/76df0b4b4fa42e88928a430a54a28620feef8ea8#boltz2_rex", + "talo_admet_ai_plot_drugbank_rex": "github:talo/admet_ai_rex/9757825d6f7a3bac632344a9af5cfdd0249ce8f0#talo_admet_ai_plot_drugbank_rex", + "talo_admet_ai_plot_radial_rex": "github:talo/admet_ai_rex/9757825d6f7a3bac632344a9af5cfdd0249ce8f0#talo_admet_ai_plot_radial_rex", + "talo_admet_ai_rex": "github:talo/admet_ai_rex/9757825d6f7a3bac632344a9af5cfdd0249ce8f0#talo_admet_ai_rex", + "talo_admet_ai_web_rex": "github:talo/admet_ai_rex/9757825d6f7a3bac632344a9af5cfdd0249ce8f0#talo_admet_ai_web_rex", "exess_rex": "github:talo/tengu-exess/ac24fadc935aa66b398aad3bacffc30f6cf3116a#exess_rex", "exess_geo_opt_rex": "github:talo/tengu-exess/f64f752732d89c47731085f1a688bfd2dee6dfc7#exess_geo_opt_rex", "exess_qmmm_rex": "github:talo/tengu-exess/61b1874f8df65a083e9170082250473fd8e46978#exess_qmmm_rex", diff --git a/tests/data/admet_ai_rex/preds.json b/tests/data/admet_ai_rex/preds.json new file mode 100644 index 0000000..146aa95 --- /dev/null +++ b/tests/data/admet_ai_rex/preds.json @@ -0,0 +1,20 @@ +[ + { + "columns": [ + {"key": "BBB_Martins_drugbank_approved_percentile", "value": "10"}, + {"key": "ClinTox_drugbank_approved_percentile", "value": "20"}, + {"key": "Solubility_AqSolDB_drugbank_approved_percentile", "value": "30"}, + {"key": "Bioavailability_Ma_drugbank_approved_percentile", "value": "40"}, + {"key": "hERG_drugbank_approved_percentile", "value": "50"} + ] + }, + { + "columns": [ + {"key": "BBB_Martins_drugbank_approved_percentile", "value": "25"}, + {"key": "ClinTox_drugbank_approved_percentile", "value": "35"}, + {"key": "Solubility_AqSolDB_drugbank_approved_percentile", "value": "45"}, + {"key": "Bioavailability_Ma_drugbank_approved_percentile", "value": "55"}, + {"key": "hERG_drugbank_approved_percentile", "value": "65"} + ] + } +] diff --git a/tests/test_onboarding_admet_ai_rex.py b/tests/test_onboarding_admet_ai_rex.py new file mode 100644 index 0000000..12832ee --- /dev/null +++ b/tests/test_onboarding_admet_ai_rex.py @@ -0,0 +1,52 @@ +import sys +from pathlib import Path + +import pytest + +from rush import admet_ai_rex +from rush.client import ( + GRAPHQL_ENDPOINT, + MODULE_LOCK, + RunError, + RunOpts, + RunSpec, + _get_env, + collect_run, + set_opts, +) + + +MODULE_KEYS = [ + "talo_admet_ai_rex", + "talo_admet_ai_plot_drugbank_rex", + "talo_admet_ai_plot_radial_rex", + "talo_admet_ai_web_rex", +] + + +def test_admet_ai_rex_imports(): + assert hasattr(admet_ai_rex, "talo_admet_ai_rex") + assert hasattr(admet_ai_rex, "talo_admet_ai_plot_drugbank_rex") + assert hasattr(admet_ai_rex, "talo_admet_ai_plot_radial_rex") + assert hasattr(admet_ai_rex, "talo_admet_ai_web_rex") + + +def test_admet_ai_rex_module_lock(): + if "staging" in GRAPHQL_ENDPOINT: + for key in MODULE_KEYS: + assert key in MODULE_LOCK + else: + pytest.xfail("Prod endpoint in use and update_prod=False.") + + +def _assert_run_ok(result): + if isinstance(result, RunError): + pytest.fail(f"RunError: {result.message}") + if isinstance(result, dict) and "Err" in result: + pytest.fail(f"Run returned Err: {result['Err']}") + if isinstance(result, (list, tuple)): + for item in result: + if isinstance(item, dict) and "Err" in item: + pytest.fail(f"Run returned Err: {item['Err']}") + +