diff --git a/graphdoc/graphdoc/__init__.py b/graphdoc/graphdoc/__init__.py
index 1bcca33..5b90143 100644
--- a/graphdoc/graphdoc/__init__.py
+++ b/graphdoc/graphdoc/__init__.py
@@ -22,7 +22,6 @@
     setup_logging,
 )
 from graphdoc.eval import DocGeneratorEvaluator
-from graphdoc.main import GraphDoc
 from graphdoc.modules import DocGeneratorModule
 from graphdoc.prompts import (
     BadDocGeneratorSignature,
@@ -45,7 +44,6 @@
 )
 
 __all__ = [
-    "GraphDoc",
     "DocGeneratorModule",
    "DocGeneratorEvaluator",
     "DocGeneratorTrainer",
diff --git a/graphdoc/graphdoc/main.py b/graphdoc/graphdoc/main.py
index ea2d93f..a4b534e 100644
--- a/graphdoc/graphdoc/main.py
+++ b/graphdoc/graphdoc/main.py
@@ -1,9 +1,11 @@
 # Copyright 2025-, Semiotic AI, Inc.
 # SPDX-License-Identifier: Apache-2.0
 
 # system packages
+import argparse
 import logging
 import random
+import sys
 from pathlib import Path
 from typing import List, Literal, Optional, Union
 
@@ -633,3 +636,159 @@
         evaluator_prediction_field=evaluator_prediction_field,
         readable_value=readable_value,
     )
+
+
+#######################
+#  Main Entry Point   #
+#######################
+"""Run GraphDoc as a command-line application.
+
+This module can be run directly to train models, generate documentation,
+or evaluate documentation quality.
+
+Usage:
+    python -m graphdoc.main --config CONFIG_FILE [--log-level LEVEL] COMMAND [ARGS]
+
+Global Arguments:
+    --config PATH       Path to YAML configuration file with GraphDoc
+                        and language model settings
+    --log-level LEVEL   Set logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
+                        Default: INFO
+
+Commands:
+    train       Train a prompt using a dataset
+        --trainer-config PATH   Path to trainer YAML configuration
+
+    generate    Generate documentation for schema files
+        --module-config PATH    Path to module YAML configuration
+        --input PATH            Path to input schema file
+        --output PATH           Path to output schema file
+
+    evaluate    Evaluate documentation quality
+        --eval-config PATH      Path to evaluator YAML configuration
+
+Examples:
+    # Train a documentation quality model
+    python -m graphdoc.main \
+        --config config.yaml \
+        train \
+        --trainer-config trainer_config.yaml
+
+    # Generate documentation for schemas
+    python -m graphdoc.main \
+        --config config.yaml \
+        generate \
+        --module-config module_config.yaml \
+        --input schema.graphql \
+        --output documented_schema.graphql
+
+    # Evaluate documentation quality
+    python -m graphdoc.main \
+        --config config.yaml \
+        evaluate \
+        --eval-config eval_config.yaml
+
+Configuration:
+    See example YAML files in the documentation for format details.
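+
+    An illustrative sketch of the expected shape (the field names below are
+    assumptions for illustration, not the authoritative schema; consult the
+    example YAML files shipped with the project for the real format):
+
+        graphdoc:
+          log_level: INFO
+        language_model:
+          model: openai/gpt-4o
+          api_key: <your-api-key>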
+""" # noqa: B950 +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="GraphDoc - Documentation Generator") + parser.add_argument("--config", type=str, help="Path to YAML configuration file") + parser.add_argument( + "--log-level", + type=str, + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Logging level", + ) + subparsers = parser.add_subparsers(dest="command", help="Command to execute") + + ################### + # train # + ################### + train_parser = subparsers.add_parser("train", help="Train a prompt") + train_parser.add_argument( + "--trainer-config", + type=str, + required=True, + help="Path to trainer YAML configuration", + ) + + ################### + # generate # + ################### + generate_parser = subparsers.add_parser("generate", help="Generate documentation") + generate_parser.add_argument( + "--module-config", + type=str, + required=True, + help="Path to module YAML configuration", + ) + generate_parser.add_argument( + "--input", type=str, required=True, help="Path to input schema file" + ) + generate_parser.add_argument( + "--output", type=str, required=True, help="Path to output schema file" + ) + + ################### + # evaluate # + ################### + eval_parser = subparsers.add_parser( + "evaluate", help="Evaluate documentation quality" + ) + eval_parser.add_argument( + "--eval-config", + type=str, + required=True, + help="Path to evaluator YAML configuration", + ) + + args = parser.parse_args() + if not args.config: + parser.print_help() + sys.exit(1) + + graphdoc = GraphDoc.from_yaml(args.config) + + if args.command == "train": + trainer = graphdoc.single_trainer_from_yaml(args.trainer_config) + trained_prompt = trainer.train() + print( + f"Training complete. Saved to MLflow with name: {trainer.mlflow_model_name}" + ) + + elif args.command == "generate": + module = graphdoc.doc_generator_module_from_yaml(args.module_config) + + with open(args.input, "r") as f: + schema = f.read() + + documented_schema = module.document_full_schema(schema) + + with open(args.output, "w") as f: + f.write(documented_schema.documented_schema) + print(f"Generation complete. Documentation saved to {args.output}") + + elif args.command == "evaluate": + evaluator = graphdoc.doc_generator_eval_from_yaml(args.eval_config) + results = evaluator.evaluate() + print( + "Evaluation complete. 
Results saved to MLflow experiment: " + f"{evaluator.mlflow_experiment_name}" + ) + else: + parser.print_help() diff --git a/graphdoc/graphdoc/modules/doc_generator_module.py b/graphdoc/graphdoc/modules/doc_generator_module.py index 7acc132..1b27f42 100644 --- a/graphdoc/graphdoc/modules/doc_generator_module.py +++ b/graphdoc/graphdoc/modules/doc_generator_module.py @@ -64,6 +64,45 @@ def __init__( ) self.prompt.prompt_metric.prompt_metric = "rating" + ####################### + # MLFLOW TRACING # + ####################### + # TODO: we will break this out into a separate class later + # when we have need for it elsewhere + def _start_trace( + self, + client: mlflow.MlflowClient, + expirement_name: str, + trace_name: str, + inputs: dict, + attributes: dict, + ): + # set the experiment name so that everything is logged to the same experiment + mlflow.set_experiment(expirement_name) + + # start the trace + trace = client.start_trace( + name=trace_name, + inputs=inputs, + attributes=attributes, + # experiment_id=expirement_name, + ) + + return trace + + def _end_trace( + self, + client: mlflow.MlflowClient, + trace: Any, # TODO: trace: mlflow.Span, + # E AttributeError: module 'mlflow' has no attribute 'Span' + outputs: dict, + status: Literal["OK", "ERROR"], + ): + client.end_trace(request_id=trace.request_id, outputs=outputs, status=status) + + ####################### + # MODULE FUNCTIONS # + ####################### def _retry_by_rating(self, database_schema: str) -> str: """Retry the generation if the quality check fails. Rating threshold is determined at initialization. @@ -211,42 +250,6 @@ def forward(self, database_schema: str) -> dspy.Prediction: else: return self._predict(database_schema=database_schema) - ####################### - # MLFLOW TRACING # - ####################### - # TODO: we will break this out into a separate class later - # when we have need for it elsewhere - def _start_trace( - self, - client: mlflow.MlflowClient, - expirement_name: str, - trace_name: str, - inputs: dict, - attributes: dict, - ): - # set the experiment name so that everything is logged to the same experiment - mlflow.set_experiment(expirement_name) - - # start the trace - trace = client.start_trace( - name=trace_name, - inputs=inputs, - attributes=attributes, - # experiment_id=expirement_name, - ) - - return trace - - def _end_trace( - self, - client: mlflow.MlflowClient, - trace: Any, # TODO: trace: mlflow.Span, - # E AttributeError: module 'mlflow' has no attribute 'Span' - outputs: dict, - status: Literal["OK", "ERROR"], - ): - client.end_trace(request_id=trace.request_id, outputs=outputs, status=status) - def document_full_schema( self, database_schema: str, diff --git a/graphdoc/runners/eval/eval_doc_generator_module.py b/graphdoc/runners/eval/eval_doc_generator_module.py index 16c17b9..08e760e 100644 --- a/graphdoc/runners/eval/eval_doc_generator_module.py +++ b/graphdoc/runners/eval/eval_doc_generator_module.py @@ -11,7 +11,7 @@ from dotenv import load_dotenv # internal packages -from graphdoc import GraphDoc +from graphdoc.main import GraphDoc # logging log = logging.getLogger(__name__) diff --git a/graphdoc/runners/train/single_prompt_trainer.py b/graphdoc/runners/train/single_prompt_trainer.py index af84f5a..6c4fe64 100644 --- a/graphdoc/runners/train/single_prompt_trainer.py +++ b/graphdoc/runners/train/single_prompt_trainer.py @@ -14,7 +14,7 @@ import mlflow from dotenv import load_dotenv -from graphdoc import GraphDoc, load_yaml_config +from graphdoc.main import GraphDoc, load_yaml_config 
+    def _start_trace(
+        self,
+        client: mlflow.MlflowClient,
+        expirement_name: str,
+        trace_name: str,
+        inputs: dict,
+        attributes: dict,
+    ):
+        # set the experiment name so that everything is logged to the same experiment
+        mlflow.set_experiment(expirement_name)
+
+        # start the trace
+        trace = client.start_trace(
+            name=trace_name,
+            inputs=inputs,
+            attributes=attributes,
+            # experiment_id=expirement_name,
+        )
+
+        return trace
+
+    def _end_trace(
+        self,
+        client: mlflow.MlflowClient,
+        trace: Any,  # TODO: trace: mlflow.Span,
+        # E   AttributeError: module 'mlflow' has no attribute 'Span'
+        outputs: dict,
+        status: Literal["OK", "ERROR"],
+    ):
+        client.end_trace(request_id=trace.request_id, outputs=outputs, status=status)
+
+    #######################
+    #  MODULE FUNCTIONS   #
+    #######################
     def _retry_by_rating(self, database_schema: str) -> str:
         """Retry the generation if the quality check fails.
         Rating threshold is determined at initialization.
@@ -211,42 +250,6 @@
         else:
             return self._predict(database_schema=database_schema)
 
-    #######################
-    #   MLFLOW TRACING    #
-    #######################
-    # TODO: we will break this out into a separate class later
-    # when we have need for it elsewhere
-    def _start_trace(
-        self,
-        client: mlflow.MlflowClient,
-        expirement_name: str,
-        trace_name: str,
-        inputs: dict,
-        attributes: dict,
-    ):
-        # set the experiment name so that everything is logged to the same experiment
-        mlflow.set_experiment(expirement_name)
-
-        # start the trace
-        trace = client.start_trace(
-            name=trace_name,
-            inputs=inputs,
-            attributes=attributes,
-            # experiment_id=expirement_name,
-        )
-
-        return trace
-
-    def _end_trace(
-        self,
-        client: mlflow.MlflowClient,
-        trace: Any,  # TODO: trace: mlflow.Span,
-        # E   AttributeError: module 'mlflow' has no attribute 'Span'
-        outputs: dict,
-        status: Literal["OK", "ERROR"],
-    ):
-        client.end_trace(request_id=trace.request_id, outputs=outputs, status=status)
-
     def document_full_schema(
         self,
         database_schema: str,
diff --git a/graphdoc/runners/eval/eval_doc_generator_module.py b/graphdoc/runners/eval/eval_doc_generator_module.py
index 16c17b9..08e760e 100644
--- a/graphdoc/runners/eval/eval_doc_generator_module.py
+++ b/graphdoc/runners/eval/eval_doc_generator_module.py
@@ -11,7 +11,7 @@
 from dotenv import load_dotenv
 
 # internal packages
-from graphdoc import GraphDoc
+from graphdoc.main import GraphDoc
 
 # logging
 log = logging.getLogger(__name__)
diff --git a/graphdoc/runners/train/single_prompt_trainer.py b/graphdoc/runners/train/single_prompt_trainer.py
index af84f5a..6c4fe64 100644
--- a/graphdoc/runners/train/single_prompt_trainer.py
+++ b/graphdoc/runners/train/single_prompt_trainer.py
@@ -14,7 +14,7 @@
 import mlflow
 from dotenv import load_dotenv
 
-from graphdoc import GraphDoc, load_yaml_config
+from graphdoc.main import GraphDoc, load_yaml_config
 
 # logging
 log = logging.getLogger(__name__)
diff --git a/graphdoc/tests/conftest.py b/graphdoc/tests/conftest.py
index 359554b..bea0ad1 100644
--- a/graphdoc/tests/conftest.py
+++ b/graphdoc/tests/conftest.py
@@ -17,10 +17,10 @@
 from graphdoc import (
     DocGeneratorPrompt,
     DocQualityPrompt,
-    GraphDoc,
     LocalDataHelper,
     Parser,
 )
+from graphdoc.main import GraphDoc
 
 # logging
 log = logging.getLogger(__name__)
diff --git a/graphdoc/tests/test_confest.py b/graphdoc/tests/test_confest.py
index 90c0be5..9225adf 100644
--- a/graphdoc/tests/test_confest.py
+++ b/graphdoc/tests/test_confest.py
@@ -7,10 +7,10 @@
 from graphdoc import (
     DocGeneratorPrompt,
     DocQualityPrompt,
-    GraphDoc,
     LocalDataHelper,
     Parser,
 )
+from graphdoc.main import GraphDoc
 
 from .conftest import (
     OverwriteSchemaCategory,
diff --git a/graphdoc/tests/test_graphdoc.py b/graphdoc/tests/test_graphdoc.py
index 3218385..3484761 100644
--- a/graphdoc/tests/test_graphdoc.py
+++ b/graphdoc/tests/test_graphdoc.py
@@ -16,10 +16,10 @@
     DocGeneratorTrainer,
     DocQualityPrompt,
     DocQualityTrainer,
-    GraphDoc,
     SinglePromptTrainer,
     load_yaml_config,
 )
+from graphdoc.main import GraphDoc
 
 # logging
 log = logging.getLogger(__name__)