diff --git a/.env.example b/.env.example index 20247c2..b26de13 100644 --- a/.env.example +++ b/.env.example @@ -1,5 +1,5 @@ OPENAI_API_KEY= -HF_DATASET_KEY= +HF_DATASET_KEY= MLFLOW_TRACKING_URI= MLFLOW_TRACKING_USERNAME=admin MLFLOW_TRACKING_PASSWORD=password \ No newline at end of file diff --git a/README.md b/README.md index 4b60c45..81475cf 100644 --- a/README.md +++ b/README.md @@ -11,11 +11,13 @@ We utilize [poetry](https://python-poetry.org/) for dependency management. Pleas We utilize [commitizen](https://commitizen-tools.github.io/commitizen/) for commit messages and semantic versioning. Please run `cz commit` to commit your changes. Commitizen can be installed with `pip install commitizen` or `brew install commitizen`. +We utilize [docker](https://www.docker.com/) for managing the tracking of our service and associated experiments through [mlflow](https://mlflow.org/). In our docker image, we spin up a [mlflow](https://mlflow.org/), [postgres](https://www.postgresql.org/), and [minio](https://min.io/) instance. This is very similar to our production setup, and allows for a pretty smooth development flow between local and prod. Please ensure you have downloaded and are running docker in the background of your machine. + Here are some quick commands for getting started: ```bash -brew add poetry -brew add commitizen +brew install poetry +brew install commitizen ``` ```bash @@ -26,11 +28,35 @@ cd ../mlflow-manager poetry install ``` +### .env + +There are two `.env` files that we expect the user to set up. They are divided between `mlflow-manager` and `graphdoc`. First, let's set up the `mlflow-manager` `.env` file. You can leave these values as they are, or modify them as you see fit: + +```bash +# navigate to the docker root +cd mlflow-manager +cd docker + +# copy the .env.example for setup +cp .env.example .env # set values directly in your newly created .env file +``` + +Next, let's set up the `.env` file to be used by our `graphdoc` program. 
+ +```bash +# navigate to the graphdoc root +cd ../.. + +# copy the .env.example for setup +cp .env.example .env # set values directly in your newly created .env file +``` + ### run.sh The `run.sh` script is a convenience script for development. It provides a few shortcuts for running useful commands. ```bash +# make sure you are in the root of the repository # ensure that the script is executable chmod +x run.sh @@ -41,6 +67,8 @@ chmod +x run.sh To setup the mlflow-manager services, run the following command: ```bash +# default username: admin +# default password: password ./run.sh mlflow-setup ``` diff --git a/graphdoc/assets/configs/single_prompt_doc_generator_module.yaml b/graphdoc/assets/configs/single_prompt_doc_generator_module.yaml index 34aa38e..3c4c84c 100644 --- a/graphdoc/assets/configs/single_prompt_doc_generator_module.yaml +++ b/graphdoc/assets/configs/single_prompt_doc_generator_module.yaml @@ -1,8 +1,5 @@ graphdoc: log_level: INFO # The log level to use (DEBUG, INFO, WARNING, ERROR, CRITICAL) - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow - mlflow_tracking_username: !env MLFLOW_TRACKING_USERNAME # The username for the mlflow tracking server - mlflow_tracking_password: !env MLFLOW_TRACKING_PASSWORD # The password for the mlflow tracking server mlflow: mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow @@ -26,6 +23,7 @@ data: evalset_ratio: 0.1 # The proportionate size of the evalset data_helper_type: generation # Type of data helper to use (quality, generation) seed: 42 # The seed for the random number generator + prompt: prompt: base_doc_gen # Which prompt signature to use class: DocGeneratorPrompt # Must be a child of SinglePrompt (we will use an enum to map this) @@ -50,18 +48,15 @@ prompt_metric: trainer: class: DocGeneratorTrainer # The type of trainer to use (DocQualityTrainer) optimizer_type: miprov2 # The type of optimizer to use (miprov2, BootstrapFewShotWithRandomSearch) - 
mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow mlflow_model_name: doc_generator_model # The name of the model in MLflow mlflow_experiment_name: doc_generator_experiment # The name of the experiment in MLflow optimizer: optimizer_type: miprov2 # BootstrapFewShotWithRandomSearch, miprov2 auto: light # miprov2 setting - # student: this is the prompt.infer object - # trainset: this is the dataset we are working with - max_labeled_demos: 2 - max_bootstrapped_demos: 4 - num_trials: 2 + max_labeled_demos: 2 # max number of labeled demonstrations + max_bootstrapped_demos: 4 # max number of bootstrapped demonstrations + num_trials: 2 # number of trials minibatch: true # default true module: diff --git a/graphdoc/assets/configs/single_prompt_doc_generator_module_eval.yaml b/graphdoc/assets/configs/single_prompt_doc_generator_module_eval.yaml index 2aeb572..92a52b2 100644 --- a/graphdoc/assets/configs/single_prompt_doc_generator_module_eval.yaml +++ b/graphdoc/assets/configs/single_prompt_doc_generator_module_eval.yaml @@ -1,8 +1,5 @@ graphdoc: log_level: INFO # The log level to use (DEBUG, INFO, WARNING, ERROR, CRITICAL) - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow - mlflow_tracking_username: !env MLFLOW_TRACKING_USERNAME # The username for the mlflow tracking server - mlflow_tracking_password: !env MLFLOW_TRACKING_PASSWORD # The password for the mlflow tracking server mlflow: mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow @@ -51,18 +48,15 @@ prompt_metric: trainer: class: DocGeneratorTrainer # The type of trainer to use (DocQualityTrainer) optimizer_type: miprov2 # The type of optimizer to use (miprov2, BootstrapFewShotWithRandomSearch) - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow mlflow_model_name: doc_generator_model # The name of the model in MLflow mlflow_experiment_name: doc_generator_experiment # The name of the experiment in MLflow optimizer: 
optimizer_type: miprov2 # BootstrapFewShotWithRandomSearch, miprov2 auto: light # miprov2 setting - # student: this is the prompt.infer object - # trainset: this is the dataset we are working with - max_labeled_demos: 2 - max_bootstrapped_demos: 4 - num_trials: 2 + max_labeled_demos: 2 # max number of labeled demonstrations + max_bootstrapped_demos: 4 # max number of bootstrapped demonstrations + num_trials: 2 # number of trials minibatch: true # default true module: @@ -72,7 +66,6 @@ module: fill_empty_descriptions: true # Whether to fill the empty descriptions in the schema eval: - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow mlflow_experiment_name: doc_generator_eval # The name of the experiment in MLflow generator_prediction_field: documented_schema # The field in the generator prediction to use evaluator_prediction_field: rating # The field in the evaluator prediction to use diff --git a/graphdoc/assets/configs/single_prompt_doc_generator_trainer.yaml b/graphdoc/assets/configs/single_prompt_doc_generator_trainer.yaml index afa26c3..8e4ce24 100644 --- a/graphdoc/assets/configs/single_prompt_doc_generator_trainer.yaml +++ b/graphdoc/assets/configs/single_prompt_doc_generator_trainer.yaml @@ -1,17 +1,14 @@ graphdoc: - log_level: INFO # The log level to use (DEBUG, INFO, WARNING, ERROR, CRITICAL) - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow - mlflow_tracking_username: !env MLFLOW_TRACKING_USERNAME # The username for the mlflow tracking server - mlflow_tracking_password: !env MLFLOW_TRACKING_PASSWORD # The password for the mlflow tracking server + log_level: INFO # The log level to use (DEBUG, INFO, WARNING, ERROR, CRITICAL) mlflow: - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow + mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow mlflow_tracking_username: !env MLFLOW_TRACKING_USERNAME # The username for the mlflow tracking server 
mlflow_tracking_password: !env MLFLOW_TRACKING_PASSWORD # The password for the mlflow tracking server language_model: - model: openai/gpt-4o # Must be a valid dspy language model - api_key: !env OPENAI_API_KEY # Must be a valid dspy language model API key + model: openai/gpt-4o # Must be a valid dspy language model + api_key: !env OPENAI_API_KEY # Must be a valid dspy language model API key cache: true # Whether to cache the calls to the language model data: @@ -26,6 +23,7 @@ data: evalset_ratio: 0.1 # The proportionate size of the evalset data_helper_type: generation # Type of data helper to use (quality, generation) seed: 42 # The seed for the random number generator + prompt: prompt: base_doc_gen # Which prompt signature to use class: DocGeneratorPrompt # Must be a child of SinglePrompt (we will use an enum to map this) @@ -50,16 +48,13 @@ prompt_metric: trainer: class: DocGeneratorTrainer # The type of trainer to use (DocQualityTrainer) optimizer_type: miprov2 # The type of optimizer to use (miprov2, BootstrapFewShotWithRandomSearch) - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow mlflow_model_name: doc_generator_model # The name of the model in MLflow mlflow_experiment_name: doc_generator_experiment # The name of the experiment in MLflow optimizer: optimizer_type: miprov2 # BootstrapFewShotWithRandomSearch, miprov2 auto: light # miprov2 setting - # student: this is the prompt.infer object - # trainset: this is the dataset we are working with - max_labeled_demos: 2 - max_bootstrapped_demos: 4 - num_trials: 2 + max_labeled_demos: 2 # max number of labeled demonstrations + max_bootstrapped_demos: 4 # max number of bootstrapped demonstrations + num_trials: 2 # number of trials minibatch: true # default true \ No newline at end of file diff --git a/graphdoc/assets/configs/single_prompt_doc_quality_trainer.yaml b/graphdoc/assets/configs/single_prompt_doc_quality_trainer.yaml index 582bafe..3450478 100644 --- 
a/graphdoc/assets/configs/single_prompt_doc_quality_trainer.yaml +++ b/graphdoc/assets/configs/single_prompt_doc_quality_trainer.yaml @@ -1,8 +1,5 @@ graphdoc: log_level: INFO # The log level to use (DEBUG, INFO, WARNING, ERROR, CRITICAL) - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow - mlflow_tracking_username: !env MLFLOW_TRACKING_USERNAME # The username for the mlflow tracking server - mlflow_tracking_password: !env MLFLOW_TRACKING_PASSWORD # The password for the mlflow tracking server mlflow: mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow @@ -26,6 +23,7 @@ data: evalset_ratio: 0.1 # The proportionate size of the evalset data_helper_type: quality # Type of data helper to use (quality, generation) seed: 42 # The seed for the random number generator + prompt: prompt: doc_quality # Which prompt signature to use class: DocQualityPrompt # Must be a child of SinglePrompt (we will use an enum to map this) @@ -50,16 +48,13 @@ prompt_metric: trainer: class: DocQualityTrainer # The type of trainer to use (DocQualityTrainer) optimizer_type: miprov2 # The type of optimizer to use (miprov2, BootstrapFewShotWithRandomSearch) - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow mlflow_model_name: doc_quality_model # The name of the model in MLflow mlflow_experiment_name: doc_quality_experiment # The name of the experiment in MLflow optimizer: optimizer_type: miprov2 # BootstrapFewShotWithRandomSearch, miprov2 auto: light # miprov2 setting - # student: this is the prompt.infer object - # trainset: this is the dataset we are working with - max_labeled_demos: 2 - max_bootstrapped_demos: 4 - num_trials: 2 + max_labeled_demos: 2 # max number of labeled demonstrations + max_bootstrapped_demos: 4 # max number of bootstrapped demonstrations + num_trials: 2 # number of trials minibatch: true # default true \ No newline at end of file diff --git a/graphdoc/docs/graphdoc.modules.token_tracker.rst 
b/graphdoc/docs/graphdoc.modules.token_tracker.rst new file mode 100644 index 0000000..8ac5a49 --- /dev/null +++ b/graphdoc/docs/graphdoc.modules.token_tracker.rst @@ -0,0 +1,8 @@ +graphdoc.modules.token\_tracker module +====================================== + +.. automodule:: graphdoc.modules.token_tracker + :members: + :undoc-members: + :show-inheritance: + :noindex: diff --git a/graphdoc/docs/index.rst b/graphdoc/docs/index.rst index 73eee52..1985497 100644 --- a/graphdoc/docs/index.rst +++ b/graphdoc/docs/index.rst @@ -26,6 +26,5 @@ Indices and tables ================== * :ref:`genindex` -* :ref:`modindex` * :ref:`search` diff --git a/graphdoc/graphdoc/config.py b/graphdoc/graphdoc/config.py index 5a30493..29ab43d 100644 --- a/graphdoc/graphdoc/config.py +++ b/graphdoc/graphdoc/config.py @@ -116,7 +116,7 @@ def mlflow_data_helper_from_yaml(yaml_path: Union[str, Path]) -> MlflowDataHelpe .. code-block:: yaml mlflow: - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow + mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow mlflow_tracking_username: !env MLFLOW_TRACKING_USERNAME # The username for the mlflow tracking server mlflow_tracking_password: !env MLFLOW_TRACKING_PASSWORD # The password for the mlflow tracking server @@ -429,11 +429,15 @@ def single_trainer_from_dict( .. 
code-block:: python { + "mlflow": { + "mlflow_tracking_uri": "http://localhost:5000", + "mlflow_tracking_username": "admin", + "mlflow_tracking_password": "password", + }, "trainer": { "class": "DocQualityTrainer", "mlflow_model_name": "doc_quality_model", "mlflow_experiment_name": "doc_quality_experiment", - "mlflow_tracking_uri": "http://localhost:5000" }, "optimizer": { "optimizer_type": "miprov2", @@ -465,7 +469,7 @@ def single_trainer_from_dict( optimizer_kwargs=trainer_dict["optimizer"], mlflow_model_name=trainer_dict["trainer"]["mlflow_model_name"], mlflow_experiment_name=trainer_dict["trainer"]["mlflow_experiment_name"], - mlflow_tracking_uri=trainer_dict["trainer"]["mlflow_tracking_uri"], + mlflow_tracking_uri=trainer_dict["mlflow"]["mlflow_tracking_uri"], trainset=trainset, evalset=evalset, ) @@ -631,7 +635,7 @@ def doc_generator_eval_from_yaml(yaml_path: Union[str, Path]) -> DocGeneratorEva .. code-block:: yaml mlflow: - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow + mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow mlflow_tracking_username: !env MLFLOW_TRACKING_USERNAME # The username for the mlflow tracking server mlflow_tracking_password: !env MLFLOW_TRACKING_PASSWORD # The password for the mlflow tracking server @@ -663,7 +667,6 @@ def doc_generator_eval_from_yaml(yaml_path: Union[str, Path]) -> DocGeneratorEva fill_empty_descriptions: true # Whether to fill the empty descriptions in the schema eval: - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow mlflow_experiment_name: doc_generator_eval # The name of the experiment in MLflow generator_prediction_field: documented_schema # The field in the generator prediction to use evaluator_prediction_field: rating # The field in the evaluator prediction to use diff --git a/graphdoc/graphdoc/prompts/schema_doc_quality.py b/graphdoc/graphdoc/prompts/schema_doc_quality.py index 5ab8a69..0ff4acc 100644 --- 
a/graphdoc/graphdoc/prompts/schema_doc_quality.py +++ b/graphdoc/graphdoc/prompts/schema_doc_quality.py @@ -30,9 +30,9 @@ class DocQualitySignature(dspy.Signature): """ # noqa: B950 database_schema: str = dspy.InputField() - category: Literal[ - "perfect", "almost perfect", "poor but correct", "incorrect" - ] = dspy.OutputField() + category: Literal["perfect", "almost perfect", "poor but correct", "incorrect"] = ( + dspy.OutputField() + ) rating: Literal[4, 3, 2, 1] = dspy.OutputField() @@ -69,9 +69,9 @@ class DocQualityDemonstrationSignature(dspy.Signature): """ # noqa: B950 database_schema: str = dspy.InputField() - category: Literal[ - "perfect", "almost perfect", "poor but correct", "incorrect" - ] = dspy.OutputField() + category: Literal["perfect", "almost perfect", "poor but correct", "incorrect"] = ( + dspy.OutputField() + ) rating: Literal[4, 3, 2, 1] = dspy.OutputField() diff --git a/graphdoc/poetry.lock b/graphdoc/poetry.lock index 95099ee..30df110 100644 --- a/graphdoc/poetry.lock +++ b/graphdoc/poetry.lock @@ -244,7 +244,7 @@ version = "25.1.0" description = "Classes Without Boilerplate" optional = false python-versions = ">=3.8" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "attrs-25.1.0-py3-none-any.whl", hash = "sha256:c75a69e28a550a7e93789579c22aa26b0f5b83b75dc4e08fe092980051e1090a"}, {file = "attrs-25.1.0.tar.gz", hash = "sha256:1c97078a80c814273a76b2a298a932eb681c87415c11dee0a6921de7f1b02c3e"}, @@ -452,7 +452,7 @@ version = "3.4.1" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." 
optional = false python-versions = ">=3.7" -groups = ["main", "docs"] +groups = ["main", "dev", "docs"] files = [ {file = "charset_normalizer-3.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de"}, {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176"}, @@ -971,7 +971,7 @@ version = "1.7.5" description = "Formats docstrings to follow PEP 257" optional = false python-versions = ">=3.7,<4.0" -groups = ["main"] +groups = ["dev"] files = [ {file = "docformatter-1.7.5-py3-none-any.whl", hash = "sha256:a24f5545ed1f30af00d106f5d85dc2fce4959295687c24c8f39f5263afaf9186"}, {file = "docformatter-1.7.5.tar.gz", hash = "sha256:ffed3da0daffa2e77f80ccba4f0e50bfa2755e1c10e130102571c890a61b246e"}, @@ -1152,7 +1152,7 @@ version = "7.1.2" description = "the modular source code checker: pep8 pyflakes and co" optional = false python-versions = ">=3.8.1" -groups = ["main"] +groups = ["dev"] files = [ {file = "flake8-7.1.2-py2.py3-none-any.whl", hash = "sha256:1cbc62e65536f65e6d754dfe6f1bada7f5cf392d6f5db3c2b85892466c3e7c1a"}, {file = "flake8-7.1.2.tar.gz", hash = "sha256:c586ffd0b41540951ae41af572e6790dbd49fc12b3aa2541685d253d9bd504bd"}, @@ -1169,7 +1169,7 @@ version = "24.12.12" description = "A plugin for flake8 finding likely bugs and design problems in your program. Contains warnings that don't belong in pyflakes and pycodestyle." 
optional = false python-versions = ">=3.8.1" -groups = ["main"] +groups = ["dev"] files = [ {file = "flake8_bugbear-24.12.12-py3-none-any.whl", hash = "sha256:1b6967436f65ca22a42e5373aaa6f2d87966ade9aa38d4baf2a1be550767545e"}, {file = "flake8_bugbear-24.12.12.tar.gz", hash = "sha256:46273cef0a6b6ff48ca2d69e472f41420a42a46e24b2a8972e4f0d6733d12a64"}, @@ -1799,6 +1799,22 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "isort" +version = "6.0.1" +description = "A Python utility / library to sort Python imports." +optional = false +python-versions = ">=3.9.0" +groups = ["dev"] +files = [ + {file = "isort-6.0.1-py3-none-any.whl", hash = "sha256:2dc5d7f65c9678d94c88dfc29161a320eec67328bc97aad576874cb4be1e9615"}, + {file = "isort-6.0.1.tar.gz", hash = "sha256:1cb5df28dfbc742e490c5e41bad6da41b805b0a8be7bc93cd0fb2a8a890ac450"}, +] + +[package.extras] +colors = ["colorama"] +plugins = ["setuptools"] + [[package]] name = "itsdangerous" version = "2.2.0" @@ -2291,7 +2307,7 @@ version = "0.7.0" description = "McCabe checker, plugin for flake8" optional = false python-versions = ">=3.6" -groups = ["main"] +groups = ["dev"] files = [ {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, @@ -3251,7 +3267,7 @@ version = "2.12.1" description = "Python style guide checker" optional = false python-versions = ">=3.8" -groups = ["main"] +groups = ["dev"] files = [ {file = "pycodestyle-2.12.1-py2.py3-none-any.whl", hash = "sha256:46f0fb92069a7c28ab7bb558f05bfc0110dac69a0cd23c61ea0040283a9d78b3"}, {file = "pycodestyle-2.12.1.tar.gz", hash = "sha256:6838eae08bbce4f6accd5d5572075c63626a15ee3e6f842df996bf62f6d73521"}, @@ -3410,7 +3426,7 @@ version = "3.2.0" description = "passive checker of Python 
programs" optional = false python-versions = ">=3.8" -groups = ["main"] +groups = ["dev"] files = [ {file = "pyflakes-3.2.0-py2.py3-none-any.whl", hash = "sha256:84b5be138a2dfbb40689ca07e2152deb896a65c3a3e24c251c5c62489568074a"}, {file = "pyflakes-3.2.0.tar.gz", hash = "sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f"}, @@ -4792,7 +4808,7 @@ version = "0.1.1" description = "Transforms tokens into original source code (while preserving whitespace)." optional = false python-versions = "*" -groups = ["main"] +groups = ["dev"] files = [ {file = "untokenize-0.1.1.tar.gz", hash = "sha256:3865dbbbb8efb4bb5eaa72f1be7f3e0be00ea8b7f125c69cbd1f5fda926f37a2"}, ] @@ -5263,4 +5279,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = "3.13.0" -content-hash = "a4d89a17053ed168d8cff47c9b46c50febb503bb62548e73af8fd98cc18f422f" +content-hash = "7e5fca102366df431659ef683a7463902ab0d23918a6819310c306bfad51c308" diff --git a/graphdoc/pyproject.toml b/graphdoc/pyproject.toml index 77a8312..0211945 100644 --- a/graphdoc/pyproject.toml +++ b/graphdoc/pyproject.toml @@ -13,8 +13,6 @@ dspy = "2.6.3" mlflow = "2.20.0" litellm = "1.61.6" responses = "^0.25.6" -flake8-bugbear = "^24.12.12" -docformatter = "^1.7.5" [tool.poetry.group.dev.dependencies] @@ -22,6 +20,9 @@ pytest = "8.3.4" black = "24.10.0" pyright = "1.1.392.post0" pytest-testmon = "2.1.3" +flake8-bugbear = "^24.12.12" +docformatter = "^1.7.5" +isort = "^6.0.1" [tool.poetry.group.docs] optional = true diff --git a/graphdoc/requirements.txt b/graphdoc/requirements.txt index 552b147..0de8417 100644 --- a/graphdoc/requirements.txt +++ b/graphdoc/requirements.txt @@ -61,6 +61,7 @@ idna==3.10 ; python_full_version == "3.13.0" imagesize==1.4.1 ; python_full_version == "3.13.0" importlib-metadata==8.6.1 ; python_full_version == "3.13.0" iniconfig==2.0.0 ; python_full_version == "3.13.0" +isort==6.0.1 ; python_full_version == "3.13.0" itsdangerous==2.2.0 ; python_full_version == "3.13.0" 
jinja2==3.1.5 ; python_full_version == "3.13.0" jiter==0.8.2 ; python_full_version == "3.13.0" diff --git a/graphdoc/run.sh b/graphdoc/run.sh index 26c4f46..60e9ee2 100755 --- a/graphdoc/run.sh +++ b/graphdoc/run.sh @@ -14,7 +14,7 @@ install_command() { } dev_command() { - poetry install --with dev + poetry install --with dev,docs } requirements_command() { @@ -68,7 +68,7 @@ docs_generate() { docs() { echo "Building documentation..." - cd docs && make clean html + cd docs && poetry run make clean html echo "Documentation built in docs/_build/html" } diff --git a/graphdoc/runners/train/single_prompt_trainer.py b/graphdoc/runners/train/single_prompt_trainer.py index 7383d5b..f50c146 100644 --- a/graphdoc/runners/train/single_prompt_trainer.py +++ b/graphdoc/runners/train/single_prompt_trainer.py @@ -51,7 +51,7 @@ def main(): report_config = copy.deepcopy(config) report_config["language_model"]["api_key"] = "REDACTED" report_config["data"]["hf_api_key"] = "REDACTED" - report_config["trainer"]["mlflow_tracking_uri"] = "REDACTED" + report_config["mlflow"]["mlflow_tracking_uri"] = "REDACTED" mlflow.log_params(report_config) diff --git a/graphdoc/tests/README.md b/graphdoc/tests/README.md index 9c9c25e..5329b18 100644 --- a/graphdoc/tests/README.md +++ b/graphdoc/tests/README.md @@ -11,8 +11,10 @@ cp .env.example .env # (optional) set environment variables directly in the shell export OPENAI_API_KEY="" -export MLFLOW_TRACKING_URI="" export HF_DATASET_KEY="" +export MLFLOW_TRACKING_URI="" +export MLFLOW_TRACKING_USERNAME="" +export MLFLOW_TRACKING_PASSWORD="" # navigate to the graphdoc package root cd.. 
diff --git a/graphdoc/tests/assets/configs/single_prompt_doc_generator_module.yaml b/graphdoc/tests/assets/configs/single_prompt_doc_generator_module.yaml index 34aa38e..3c4c84c 100644 --- a/graphdoc/tests/assets/configs/single_prompt_doc_generator_module.yaml +++ b/graphdoc/tests/assets/configs/single_prompt_doc_generator_module.yaml @@ -1,8 +1,5 @@ graphdoc: log_level: INFO # The log level to use (DEBUG, INFO, WARNING, ERROR, CRITICAL) - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow - mlflow_tracking_username: !env MLFLOW_TRACKING_USERNAME # The username for the mlflow tracking server - mlflow_tracking_password: !env MLFLOW_TRACKING_PASSWORD # The password for the mlflow tracking server mlflow: mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow @@ -26,6 +23,7 @@ data: evalset_ratio: 0.1 # The proportionate size of the evalset data_helper_type: generation # Type of data helper to use (quality, generation) seed: 42 # The seed for the random number generator + prompt: prompt: base_doc_gen # Which prompt signature to use class: DocGeneratorPrompt # Must be a child of SinglePrompt (we will use an enum to map this) @@ -50,18 +48,15 @@ prompt_metric: trainer: class: DocGeneratorTrainer # The type of trainer to use (DocQualityTrainer) optimizer_type: miprov2 # The type of optimizer to use (miprov2, BootstrapFewShotWithRandomSearch) - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow mlflow_model_name: doc_generator_model # The name of the model in MLflow mlflow_experiment_name: doc_generator_experiment # The name of the experiment in MLflow optimizer: optimizer_type: miprov2 # BootstrapFewShotWithRandomSearch, miprov2 auto: light # miprov2 setting - # student: this is the prompt.infer object - # trainset: this is the dataset we are working with - max_labeled_demos: 2 - max_bootstrapped_demos: 4 - num_trials: 2 + max_labeled_demos: 2 # max number of labeled demonstrations + 
max_bootstrapped_demos: 4 # max number of bootstrapped demonstrations + num_trials: 2 # number of trials minibatch: true # default true module: diff --git a/graphdoc/tests/assets/configs/single_prompt_doc_generator_module_eval.yaml b/graphdoc/tests/assets/configs/single_prompt_doc_generator_module_eval.yaml index e46a538..92a52b2 100644 --- a/graphdoc/tests/assets/configs/single_prompt_doc_generator_module_eval.yaml +++ b/graphdoc/tests/assets/configs/single_prompt_doc_generator_module_eval.yaml @@ -1,11 +1,8 @@ graphdoc: log_level: INFO # The log level to use (DEBUG, INFO, WARNING, ERROR, CRITICAL) - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow - mlflow_tracking_username: !env MLFLOW_TRACKING_USERNAME # The username for the mlflow tracking server - mlflow_tracking_password: !env MLFLOW_TRACKING_PASSWORD # The password for the mlflow tracking server mlflow: - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow + mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow mlflow_tracking_username: !env MLFLOW_TRACKING_USERNAME # The username for the mlflow tracking server mlflow_tracking_password: !env MLFLOW_TRACKING_PASSWORD # The password for the mlflow tracking server @@ -20,12 +17,13 @@ data: load_from_local: true # Whether to load the dataset from a local directory load_local_specific_category: false # Whether to load all categories or a specific category (if load_from_local is true) local_specific_category: perfect # The specific category to load from the dataset (if load_from_local is true) - local_parse_objects: True # Whether to parse the objects in the dataset (if load_from_local is true) - split_for_eval: True # Whether to split the dataset into trainset and evalset + local_parse_objects: false # Whether to parse the objects in the dataset (if load_from_local is true) + split_for_eval: false # Whether to split the dataset into trainset and evalset trainset_size: 10 # The size of the 
trainset evalset_ratio: 0.1 # The proportionate size of the evalset data_helper_type: generation # Type of data helper to use (quality, generation) seed: 42 # The seed for the random number generator + prompt: prompt: base_doc_gen # Which prompt signature to use class: DocGeneratorPrompt # Must be a child of SinglePrompt (we will use an enum to map this) @@ -50,18 +48,15 @@ prompt_metric: trainer: class: DocGeneratorTrainer # The type of trainer to use (DocQualityTrainer) optimizer_type: miprov2 # The type of optimizer to use (miprov2, BootstrapFewShotWithRandomSearch) - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow mlflow_model_name: doc_generator_model # The name of the model in MLflow mlflow_experiment_name: doc_generator_experiment # The name of the experiment in MLflow optimizer: optimizer_type: miprov2 # BootstrapFewShotWithRandomSearch, miprov2 auto: light # miprov2 setting - # student: this is the prompt.infer object - # trainset: this is the dataset we are working with - max_labeled_demos: 2 - max_bootstrapped_demos: 4 - num_trials: 2 + max_labeled_demos: 2 # max number of labeled demonstrations + max_bootstrapped_demos: 4 # max number of bootstrapped demonstrations + num_trials: 2 # number of trials minibatch: true # default true module: @@ -71,8 +66,7 @@ module: fill_empty_descriptions: true # Whether to fill the empty descriptions in the schema eval: - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow - mlflow_experiment_name: doc_generator_eval # The name of the experiment in MLflow - generator_prediction_field: documented_schema # The field in the generator prediction to use + mlflow_experiment_name: doc_generator_eval # The name of the experiment in MLflow + generator_prediction_field: documented_schema # The field in the generator prediction to use evaluator_prediction_field: rating # The field in the evaluator prediction to use - readable_value: 25 # The value to make the ratings readable + 
readable_value: 25 # The value to make the ratings readable diff --git a/graphdoc/tests/assets/configs/single_prompt_doc_generator_trainer.yaml b/graphdoc/tests/assets/configs/single_prompt_doc_generator_trainer.yaml index 7e94e14..8e4ce24 100644 --- a/graphdoc/tests/assets/configs/single_prompt_doc_generator_trainer.yaml +++ b/graphdoc/tests/assets/configs/single_prompt_doc_generator_trainer.yaml @@ -1,11 +1,8 @@ graphdoc: - log_level: INFO # The log level to use (DEBUG, INFO, WARNING, ERROR, CRITICAL) - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow - mlflow_tracking_username: !env MLFLOW_TRACKING_USERNAME # The username for the mlflow tracking server - mlflow_tracking_password: !env MLFLOW_TRACKING_PASSWORD # The password for the mlflow tracking server + log_level: INFO # The log level to use (DEBUG, INFO, WARNING, ERROR, CRITICAL) mlflow: - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow + mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow mlflow_tracking_username: !env MLFLOW_TRACKING_USERNAME # The username for the mlflow tracking server mlflow_tracking_password: !env MLFLOW_TRACKING_PASSWORD # The password for the mlflow tracking server @@ -22,10 +19,11 @@ data: local_specific_category: perfect # The specific category to load from the dataset (if load_from_local is true) local_parse_objects: True # Whether to parse the objects in the dataset (if load_from_local is true) split_for_eval: True # Whether to split the dataset into trainset and evalset - trainset_size: 1000 # The size of the trainset + trainset_size: 10 # The size of the trainset evalset_ratio: 0.1 # The proportionate size of the evalset data_helper_type: generation # Type of data helper to use (quality, generation) seed: 42 # The seed for the random number generator + prompt: prompt: base_doc_gen # Which prompt signature to use class: DocGeneratorPrompt # Must be a child of SinglePrompt (we will use an enum to map 
this) @@ -48,18 +46,15 @@ prompt_metric: model_version: null # The version of the model in MLflow trainer: - class: DocGeneratorTrainer # The type of trainer to use (DocQualityTrainer) + class: DocGeneratorTrainer # The type of trainer to use (DocQualityTrainer) optimizer_type: miprov2 # The type of optimizer to use (miprov2, BootstrapFewShotWithRandomSearch) - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow - mlflow_model_name: doc_quality_model # The name of the model in MLflow - mlflow_experiment_name: doc_quality_experiment # The name of the experiment in MLflow + mlflow_model_name: doc_generator_model # The name of the model in MLflow + mlflow_experiment_name: doc_generator_experiment # The name of the experiment in MLflow optimizer: optimizer_type: miprov2 # BootstrapFewShotWithRandomSearch, miprov2 - auto: medium # miprov2 setting - # student: this is the prompt.infer object - # trainset: this is the dataset we are working with - max_labeled_demos: 2 - max_bootstrapped_demos: 4 - num_trials: 25 + auto: light # miprov2 setting + max_labeled_demos: 2 # max number of labeled demonstrations + max_bootstrapped_demos: 4 # max number of bootstrapped demonstrations + num_trials: 2 # number of trials minibatch: true # default true \ No newline at end of file diff --git a/graphdoc/tests/assets/configs/single_prompt_doc_quality_trainer.yaml b/graphdoc/tests/assets/configs/single_prompt_doc_quality_trainer.yaml index 463e19f..3450478 100644 --- a/graphdoc/tests/assets/configs/single_prompt_doc_quality_trainer.yaml +++ b/graphdoc/tests/assets/configs/single_prompt_doc_quality_trainer.yaml @@ -1,8 +1,5 @@ graphdoc: log_level: INFO # The log level to use (DEBUG, INFO, WARNING, ERROR, CRITICAL) - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow - mlflow_tracking_username: !env MLFLOW_TRACKING_USERNAME # The username for the mlflow tracking server - mlflow_tracking_password: !env MLFLOW_TRACKING_PASSWORD # The password 
for the mlflow tracking server mlflow: mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow @@ -22,7 +19,7 @@ data: local_specific_category: perfect # The specific category to load from the dataset (if load_from_local is true) local_parse_objects: True # Whether to parse the objects in the dataset (if load_from_local is true) split_for_eval: True # Whether to split the dataset into trainset and evalset - trainset_size: 1000 # The size of the trainset + trainset_size: 10 # The size of the trainset evalset_ratio: 0.1 # The proportionate size of the evalset data_helper_type: quality # Type of data helper to use (quality, generation) seed: 42 # The seed for the random number generator @@ -51,16 +48,13 @@ prompt_metric: trainer: class: DocQualityTrainer # The type of trainer to use (DocQualityTrainer) optimizer_type: miprov2 # The type of optimizer to use (miprov2, BootstrapFewShotWithRandomSearch) - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow mlflow_model_name: doc_quality_model # The name of the model in MLflow mlflow_experiment_name: doc_quality_experiment # The name of the experiment in MLflow optimizer: optimizer_type: miprov2 # BootstrapFewShotWithRandomSearch, miprov2 - auto: medium # miprov2 setting - # student: this is the prompt.infer object - # trainset: this is the dataset we are working with - max_labeled_demos: 2 - max_bootstrapped_demos: 4 - num_trials: 25 + auto: light # miprov2 setting + max_labeled_demos: 2 # max number of labeled demonstrations + max_bootstrapped_demos: 4 # max number of bootstrapped demonstrations + num_trials: 2 # number of trials minibatch: true # default true \ No newline at end of file diff --git a/graphdoc/tests/assets/configs/single_prompt_schema_doc_generator_trainer.yaml b/graphdoc/tests/assets/configs/single_prompt_schema_doc_generator_trainer.yaml deleted file mode 100644 index 942f1b1..0000000 --- 
a/graphdoc/tests/assets/configs/single_prompt_schema_doc_generator_trainer.yaml +++ /dev/null @@ -1,43 +0,0 @@ -language_model: - lm_model_name: openai/gpt-4o # Must be a valid dspy language model - lm_api_key: !env OPENAI_API_KEY # Must be a valid dspy language model API key - cache: true # Whether to cache the calls to the language model - -mlflow: - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow - mlflow_tracking_username: !env MLFLOW_TRACKING_USERNAME # The username for the mlflow tracking server - mlflow_tracking_password: !env MLFLOW_TRACKING_PASSWORD # The password for the mlflow tracking server - -data: - hf_api_key: !env HF_DATASET_KEY # Must be a valid Hugging Face API key (with permission to access graphdoc) # TODO: we may make this public in the future - load_from_hf: true # Whether to load the dataset from Hugging Face - load_from_local: false # Whether to load the dataset from a local directory - load_local_specific_category: false # Whether to load all categories or a specific category (if load_from_local is true) - local_specific_category: perfect # The specific category to load from the dataset (if load_from_local is true) - local_parse_objects: True # Whether to parse the objects in the dataset (if load_from_local is true) - seed: 42 # The seed for the random number generator - -prompt: - prompt: zero_shot_doc_gen # Which prompt signature to use - class: DocGeneratorPrompt # Must be a child of SinglePrompt (we will use an enum to map this) - type: chain_of_thought # The type of prompt to use (predict, chain_of_thought) - metric: DocQualityPrompt # The type of metric to use (rating, category) - load_from_uri: true # Whether to load the prompt from an MLFlow URI - mlflow_uri: file:///Users/denver/Documents/code/graph/graphdoc/mlruns/513408250948216117/976d330558344c41b30bd1531571de18/artifacts/model # The tracking URI for MLflow - -trainer: - class: DocGeneratorTrainer # The type of trainer to use (DocQualityTrainer) - 
optimizer_type: miprov2 # The type of optimizer to use (miprov2) - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow - mlflow_model_name: doc_generator_model_zero_shot # The name of the model in MLflow - mlflow_experiment_name: doc_generator_experiment_zero_shot # The name of the experiment in MLflow - mlflow_load_model: true # Whether to load the most recent model from MLflow - -optimizer: - optimizer_type: miprov2 - # metric: this is set in the prompt - auto: light # miprov2 setting - # student: this is the prompt.infer object - # trainset: this is the dataset we are working with - max_labeled_demos: 0 - max_bootstrapped_demos: 4 \ No newline at end of file diff --git a/graphdoc/tests/assets/configs/single_prompt_schema_doc_generator_version.yaml b/graphdoc/tests/assets/configs/single_prompt_schema_doc_generator_version.yaml deleted file mode 100644 index 4e93614..0000000 --- a/graphdoc/tests/assets/configs/single_prompt_schema_doc_generator_version.yaml +++ /dev/null @@ -1,45 +0,0 @@ -language_model: - lm_model_name: openai/gpt-4o # Must be a valid dspy language model - lm_api_key: !env OPENAI_API_KEY # Must be a valid dspy language model API key - cache: true # Whether to cache the calls to the language model - -mlflow: - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow - mlflow_tracking_username: !env MLFLOW_TRACKING_USERNAME # The username for the mlflow tracking server - mlflow_tracking_password: !env MLFLOW_TRACKING_PASSWORD # The password for the mlflow tracking server - -data: - hf_api_key: !env HF_DATASET_KEY # Must be a valid Hugging Face API key (with permission to access graphdoc) # TODO: we may make this public in the future - load_from_hf: true # Whether to load the dataset from Hugging Face - load_from_local: false # Whether to load the dataset from a local directory - load_local_specific_category: false # Whether to load all categories or a specific category (if load_from_local is true) - 
local_specific_category: perfect # The specific category to load from the dataset (if load_from_local is true) - local_parse_objects: True # Whether to parse the objects in the dataset (if load_from_local is true) - seed: 42 # The seed for the random number generator - -prompt: - prompt: zero_shot_doc_gen # Which prompt signature to use - class: DocGeneratorPrompt # Must be a child of SinglePrompt (we will use an enum to map this) - type: chain_of_thought # The type of prompt to use (predict, chain_of_thought) - metric: DocQualityPrompt # The type of metric to use (rating, category) - load_from_uri: true # Whether to load the prompt from an MLFlow URI - mlflow_uri: null # The tracking URI for MLflow - mlflow_model_name: doc_generator_model # The name of the model in MLflow - mlflow_model_version: 1 # The version of the model in MLflow - -trainer: - class: DocGeneratorTrainer # The type of trainer to use (DocQualityTrainer) - optimizer_type: miprov2 # The type of optimizer to use (miprov2) - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow - mlflow_model_name: doc_generator_model_zero_shot # The name of the model in MLflow - mlflow_experiment_name: doc_generator_experiment_zero_shot # The name of the experiment in MLflow - mlflow_load_model: true # Whether to load the most recent model from MLflow - -optimizer: - optimizer_type: miprov2 - # metric: this is set in the prompt - auto: light # miprov2 setting - # student: this is the prompt.infer object - # trainset: this is the dataset we are working with - max_labeled_demos: 0 - max_bootstrapped_demos: 4 \ No newline at end of file diff --git a/graphdoc/tests/assets/configs/single_prompt_schema_doc_quality_trainer.yaml b/graphdoc/tests/assets/configs/single_prompt_schema_doc_quality_trainer.yaml deleted file mode 100644 index 708344e..0000000 --- a/graphdoc/tests/assets/configs/single_prompt_schema_doc_quality_trainer.yaml +++ /dev/null @@ -1,42 +0,0 @@ -language_model: - lm_model_name: 
openai/gpt-4o # Must be a valid dspy language model - lm_api_key: !env OPENAI_API_KEY # Must be a valid dspy language model API key - cache: true # Whether to cache the calls to the language model - -mlflow: - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow - mlflow_tracking_username: !env MLFLOW_TRACKING_USERNAME # The username for the mlflow tracking server - mlflow_tracking_password: !env MLFLOW_TRACKING_PASSWORD # The password for the mlflow tracking server - -data: - hf_api_key: !env HF_DATASET_KEY # Must be a valid Hugging Face API key (with permission to access graphdoc) # TODO: we may make this public in the future - load_from_hf: true # Whether to load the dataset from Hugging Face - load_from_local: false # Whether to load the dataset from a local directory - load_local_specific_category: false # Whether to load all categories or a specific category (if load_from_local is true) - local_specific_category: perfect # The specific category to load from the dataset (if load_from_local is true) - local_parse_objects: True # Whether to parse the objects in the dataset (if load_from_local is true) - seed: 42 # The seed for the random number generator -prompt: - prompt: doc_quality # Which prompt signature to use - class: SchemaDocQualityPrompt # Must be a child of SinglePrompt (we will use an enum to map this) - type: predict # The type of prompt to use (predict, chain_of_thought) - metric: rating # The type of metric to use (rating, category) - load_from_uri: true # Whether to load the prompt from an MLFlow URI - mlflow_uri: file:///Users/denver/Documents/code/graph/graphdoc/mlruns/113281354219570660/639710d056054cdea5c86459f2357df2/artifacts/model # The tracking URI for MLflow - -trainer: - class: DocQualityTrainer # The type of trainer to use (DocQualityTrainer) - optimizer_type: miprov2 # The type of optimizer to use (miprov2) - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow - mlflow_model_name: 
doc_quality_model_zero_shot # The name of the model in MLflow - mlflow_experiment_name: doc_quality_experiment_zero_shot # The name of the experiment in MLflow - mlflow_load_model: true # Whether to load the most recent model from MLflow - -optimizer: - optimizer_type: miprov2 - # metric: this is set in the prompt - auto: light # miprov2 setting - # student: this is the prompt.infer object - # trainset: this is the dataset we are working with - max_labeled_demos: 0 - max_bootstrapped_demos: 4 \ No newline at end of file diff --git a/graphdoc/tests/assets/configs/single_prompt_schema_doc_quality_version.yaml b/graphdoc/tests/assets/configs/single_prompt_schema_doc_quality_version.yaml deleted file mode 100644 index 55ad936..0000000 --- a/graphdoc/tests/assets/configs/single_prompt_schema_doc_quality_version.yaml +++ /dev/null @@ -1,45 +0,0 @@ -language_model: - lm_model_name: openai/gpt-4o # Must be a valid dspy language model - lm_api_key: !env OPENAI_API_KEY # Must be a valid dspy language model API key - cache: true # Whether to cache the calls to the language model - -mlflow: - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow - mlflow_tracking_username: !env MLFLOW_TRACKING_USERNAME # The username for the mlflow tracking server - mlflow_tracking_password: !env MLFLOW_TRACKING_PASSWORD # The password for the mlflow tracking server - -data: - hf_api_key: !env HF_DATASET_KEY # Must be a valid Hugging Face API key (with permission to access graphdoc) # TODO: we may make this public in the future - load_from_hf: true # Whether to load the dataset from Hugging Face - load_from_local: false # Whether to load the dataset from a local directory - load_local_specific_category: false # Whether to load all categories or a specific category (if load_from_local is true) - local_specific_category: perfect # The specific category to load from the dataset (if load_from_local is true) - local_parse_objects: True # Whether to parse the objects in the 
dataset (if load_from_local is true) - seed: 42 # The seed for the random number generator - -prompt: - prompt: doc_quality # Which prompt signature to use - class: SchemaDocQualityPrompt # Must be a child of SinglePrompt (we will use an enum to map this) - type: predict # The type of prompt to use (predict, chain_of_thought) - metric: rating # The type of metric to use (rating, category) - load_from_uri: true # Whether to load the prompt from an MLFlow URI - mlflow_uri: null # The tracking URI for MLflow - mlflow_model_name: doc_quality_model # The name of the model in MLflow - mlflow_model_version: 1 # The version of the model in MLflow - -trainer: - class: DocQualityTrainer # The type of trainer to use (DocQualityTrainer) - optimizer_type: miprov2 # The type of optimizer to use (miprov2) - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow - mlflow_model_name: doc_quality_model_zero_shot # The name of the model in MLflow - mlflow_experiment_name: doc_quality_experiment_zero_shot # The name of the experiment in MLflow - mlflow_load_model: true # Whether to load the most recent model from MLflow - -optimizer: - optimizer_type: miprov2 - # metric: this is set in the prompt - auto: light # miprov2 setting - # student: this is the prompt.infer object - # trainset: this is the dataset we are working with - max_labeled_demos: 0 - max_bootstrapped_demos: 4 \ No newline at end of file diff --git a/graphdoc/tests/assets/configs/single_prompt_trainer.yaml b/graphdoc/tests/assets/configs/single_prompt_trainer.yaml deleted file mode 100644 index 4b71afe..0000000 --- a/graphdoc/tests/assets/configs/single_prompt_trainer.yaml +++ /dev/null @@ -1,43 +0,0 @@ -language_model: - lm_model_name: openai/gpt-4o # Must be a valid dspy language model - lm_api_key: !env OPENAI_API_KEY # Must be a valid dspy language model API key - cache: true # Whether to cache the calls to the language model - -mlflow: - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The 
tracking URI for MLflow - mlflow_tracking_username: !env MLFLOW_TRACKING_USERNAME # The username for the mlflow tracking server - mlflow_tracking_password: !env MLFLOW_TRACKING_PASSWORD # The password for the mlflow tracking server - -data: - hf_api_key: !env HF_DATASET_KEY # Must be a valid Hugging Face API key (with permission to access graphdoc) # TODO: we may make this public in the future - load_from_hf: true # Whether to load the dataset from Hugging Face - load_from_local: false # Whether to load the dataset from a local directory - load_local_specific_category: false # Whether to load all categories or a specific category (if load_from_local is true) - local_specific_category: perfect # The specific category to load from the dataset (if load_from_local is true) - local_parse_objects: True # Whether to parse the objects in the dataset (if load_from_local is true) - seed: 42 # The seed for the random number generator - -prompt: - prompt: doc_quality # Which prompt signature to use - class: SchemaDocQualityPrompt # Must be a child of SinglePrompt (we will use an enum to map this) - type: predict # The type of prompt to use (predict, chain_of_thought) - metric: rating # The type of metric to use (rating, category) - load_from_uri: true # Whether to load the prompt from an MLFlow URI - mlflow_uri: file:///Users/denver/Documents/code/graph/graphdoc/mlruns/113281354219570660/639710d056054cdea5c86459f2357df2/artifacts/model # The tracking URI for MLflow - -trainer: - class: DocQualityTrainer # The type of trainer to use (DocQualityTrainer) - optimizer_type: miprov2 # The type of optimizer to use (miprov2) - mlflow_tracking_uri: !env MLFLOW_TRACKING_URI # The tracking URI for MLflow - mlflow_model_name: doc_quality_model_zero_shot # The name of the model in MLflow - mlflow_experiment_name: doc_quality_experiment_zero_shot # The name of the experiment in MLflow - mlflow_load_model: true # Whether to load the most recent model from MLflow - -optimizer: - 
optimizer_type: miprov2 - # metric: this is set in the prompt - auto: light # miprov2 setting - # student: this is the prompt.infer object - # trainset: this is the dataset we are working with - max_labeled_demos: 0 - max_bootstrapped_demos: 4 \ No newline at end of file diff --git a/graphdoc/tests/conftest.py b/graphdoc/tests/conftest.py index 2409082..d298825 100644 --- a/graphdoc/tests/conftest.py +++ b/graphdoc/tests/conftest.py @@ -32,6 +32,11 @@ MLRUNS_DIR = ASSETS_DIR / "mlruns" ENV_PATH = TEST_DIR / ".env" +# set the environment variables +os.environ["MLFLOW_TRACKING_URI"] = str(MLRUNS_DIR) +os.environ["MLFLOW_TRACKING_USERNAME"] = "admin" +os.environ["MLFLOW_TRACKING_PASSWORD"] = "password" + # Check if .env file exists if not ENV_PATH.exists(): log.error(f".env file not found at {ENV_PATH}") @@ -123,33 +128,6 @@ def overwrite_ldh() -> LocalDataHelper: ) -# @fixture -# def gd() -> GraphDoc: -# """Fixture for GraphDoc with proper environment setup.""" -# # Ensure environment is set up correctly -# if ENV_PATH.exists(): -# load_dotenv(dotenv_path=ENV_PATH, override=True) -# ensure_env_vars() - -# api_key = os.environ.get("OPENAI_API_KEY") -# mlflow_tracking_username = os.environ.get("MLFLOW_TRACKING_USERNAME") -# mlflow_tracking_password = os.environ.get("MLFLOW_TRACKING_PASSWORD") -# if not api_key: -# log.error("OPENAI_API_KEY still not available after loading .env file") - -# return GraphDoc( -# model_args={ -# "model": "gpt-4o-mini", -# "api_key": api_key, -# "cache": True, -# }, -# mlflow_tracking_uri=MLRUNS_DIR, -# mlflow_tracking_username=mlflow_tracking_username, -# mlflow_tracking_password=mlflow_tracking_password, -# log_level="INFO", -# ) - - @fixture def dqp(): return DocQualityPrompt( @@ -170,3 +148,12 @@ def dgp(): prompt_metric="rating", ), ) + + +@fixture +def mlflow_dict(): + return { + "mlflow_tracking_uri": MLRUNS_DIR, + "mlflow_tracking_username": "admin", + "mlflow_tracking_password": "password", + } diff --git 
a/graphdoc/tests/data/test_helper.py b/graphdoc/tests/data/test_helper.py index 1f1224d..468d80d 100644 --- a/graphdoc/tests/data/test_helper.py +++ b/graphdoc/tests/data/test_helper.py @@ -50,21 +50,27 @@ def test_load_yaml_config(self): OPENAI_API_KEY = "test" HF_DATASET_KEY = "test" MLFLOW_TRACKING_URI = "test" + MLFLOW_TRACKING_USERNAME = "test" + MLFLOW_TRACKING_PASSWORD = "test" os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY os.environ["HF_DATASET_KEY"] = HF_DATASET_KEY os.environ["MLFLOW_TRACKING_URI"] = MLFLOW_TRACKING_URI - config_path = CONFIG_DIR / "single_prompt_trainer.yaml" + os.environ["MLFLOW_TRACKING_USERNAME"] = MLFLOW_TRACKING_USERNAME + os.environ["MLFLOW_TRACKING_PASSWORD"] = MLFLOW_TRACKING_PASSWORD + config_path = CONFIG_DIR / "single_prompt_doc_quality_trainer.yaml" config = load_yaml_config(str(config_path)) assert config is not None - assert config["language_model"]["lm_api_key"] is not None - assert config["language_model"]["lm_api_key"] == OPENAI_API_KEY + assert config["language_model"]["api_key"] is not None + assert config["language_model"]["api_key"] == OPENAI_API_KEY assert config["data"]["hf_api_key"] is not None assert config["data"]["hf_api_key"] == HF_DATASET_KEY - assert config["trainer"]["mlflow_tracking_uri"] is not None - assert config["trainer"]["mlflow_tracking_uri"] == MLFLOW_TRACKING_URI + assert config["mlflow"]["mlflow_tracking_uri"] is not None + assert config["mlflow"]["mlflow_tracking_uri"] == MLFLOW_TRACKING_URI del os.environ["OPENAI_API_KEY"] del os.environ["HF_DATASET_KEY"] del os.environ["MLFLOW_TRACKING_URI"] + del os.environ["MLFLOW_TRACKING_USERNAME"] + del os.environ["MLFLOW_TRACKING_PASSWORD"] def test_setup_logging(self): root_logger = logging.getLogger() diff --git a/graphdoc/tests/test_config.py b/graphdoc/tests/test_config.py index ac33333..2d23787 100644 --- a/graphdoc/tests/test_config.py +++ b/graphdoc/tests/test_config.py @@ -1,8 +1,9 @@ # Copyright 2025-, Semiotic AI, Inc. 
# SPDX-License-Identifier: Apache-2.0 -# system packages import logging + +# system packages from pathlib import Path # external packages @@ -83,10 +84,10 @@ def test_trainset_and_evalset_from_yaml(self): config_path = CONFIG_DIR / "single_prompt_doc_quality_trainer.yaml" trainset, evalset = trainset_and_evalset_from_yaml(config_path) assert isinstance(trainset, list) - assert len(trainset) == 900 + assert len(trainset) == 9 assert isinstance(trainset[0], dspy.Example) assert isinstance(evalset, list) - assert len(evalset) == 100 + assert len(evalset) == 1 ############################################################ # prompt tests # @@ -106,7 +107,7 @@ def test_single_prompt_from_dict(self): assert isinstance(generator_prompt, DocGeneratorPrompt) assert isinstance(generator_prompt.prompt_metric, DocQualityPrompt) - def test_single_prompt_by_version_from_dict(self): + def test_single_prompt_by_version_from_dict(self, mlflow_dict): config_path = CONFIG_DIR / "single_prompt_doc_quality_trainer.yaml" config_dict = load_yaml_config(config_path) prompt_dict = config_dict["prompt"] @@ -115,8 +116,7 @@ def test_single_prompt_by_version_from_dict(self): prompt_dict["model_version"] = "1" prompt_dict["type"] = "predict" prompt_metric = prompt_dict["metric"] - mlfow_dict = config_dict["mlflow"] - prompt = single_prompt_from_dict(prompt_dict, prompt_metric, mlfow_dict) + prompt = single_prompt_from_dict(prompt_dict, prompt_metric, mlflow_dict) assert isinstance(prompt, DocQualityPrompt) def test_single_prompt_from_yaml(self):