diff --git a/graphdoc/graphdoc/__init__.py b/graphdoc/graphdoc/__init__.py index 22d85aa..221013a 100644 --- a/graphdoc/graphdoc/__init__.py +++ b/graphdoc/graphdoc/__init__.py @@ -37,7 +37,7 @@ setup_logging, ) from graphdoc.eval import DocGeneratorEvaluator -from graphdoc.modules import DocGeneratorModule +from graphdoc.modules import DocGeneratorModule, TokenTracker from graphdoc.prompts import ( BadDocGeneratorSignature, DocGeneratorHelperSignature, @@ -60,6 +60,7 @@ __all__ = [ "DocGeneratorModule", + "TokenTracker", "DocGeneratorEvaluator", "DocGeneratorTrainer", "DocQualityTrainer", diff --git a/graphdoc/graphdoc/config.py b/graphdoc/graphdoc/config.py index afe813a..5a30493 100644 --- a/graphdoc/graphdoc/config.py +++ b/graphdoc/graphdoc/config.py @@ -34,6 +34,50 @@ ####################### +def lm_from_dict(lm_config: dict): + """Load a language model from a dictionary of parameters. + + :param lm_config: Dictionary containing language model parameters. + :type lm_config: dict + + """ + return dspy.LM(**lm_config) + + +def lm_from_yaml(yaml_path: Union[str, Path]): + """Load a language model from a YAML file. + + :param yaml_path: Path to the YAML file containing the language model configuration. + :type yaml_path: Union[str, Path] + + """ + config = load_yaml_config(yaml_path) + return lm_from_dict(config["language_model"]) + + +def dspy_lm_from_dict(lm_config: dict): + """Load a language model from a dictionary of parameters. Set the dspy language + model. + + :param lm_config: Dictionary containing language model parameters. + :type lm_config: dict + + """ + lm = lm_from_dict(lm_config) + dspy.configure(lm=lm) + + +def dspy_lm_from_yaml(yaml_path: Union[str, Path]): + """Load a language model from a YAML file. Set the dspy language model. + + :param yaml_path: Path to the YAML file containing the language model configuration. 
+ :type yaml_path: Union[str, Path] + + """ + config = load_yaml_config(yaml_path) + dspy_lm_from_dict(config["language_model"]) + + def mlflow_data_helper_from_dict(mlflow_config: dict) -> MlflowDataHelper: """Load a MLflow data helper from a dictionary of parameters. @@ -354,6 +398,9 @@ def single_prompt_from_yaml(yaml_path: Union[str, Path]) -> SinglePrompt: :rtype: SinglePrompt """ + # set the dspy language model + dspy_lm_from_yaml(yaml_path) + config = load_yaml_config(yaml_path) mlflow_config = config.get("mlflow", None) if config["prompt"]["prompt_metric"]: @@ -471,6 +518,9 @@ def single_trainer_from_yaml(yaml_path: Union[str, Path]) -> SinglePromptTrainer :rtype: SinglePromptTrainer """ + # set the dspy language model + dspy_lm_from_yaml(yaml_path) + try: config = load_yaml_config(yaml_path) prompt = single_prompt_from_yaml(yaml_path) @@ -564,6 +614,9 @@ def doc_generator_module_from_yaml(yaml_path: Union[str, Path]) -> DocGeneratorM :rtype: DocGeneratorModule """ + # set the dspy language model + dspy_lm_from_yaml(yaml_path) + config = load_yaml_config(yaml_path)["module"] prompt = single_prompt_from_yaml(yaml_path) return doc_generator_module_from_dict(config, prompt) @@ -622,6 +675,9 @@ def doc_generator_eval_from_yaml(yaml_path: Union[str, Path]) -> DocGeneratorEva :rtype: DocGeneratorEvaluator """ # noqa: B950 + # set the dspy language model + dspy_lm_from_yaml(yaml_path) + # load the generator generator = doc_generator_module_from_yaml(yaml_path) config = load_yaml_config(yaml_path) @@ -630,9 +686,10 @@ def doc_generator_eval_from_yaml(yaml_path: Union[str, Path]) -> DocGeneratorEva metric_config = config["prompt_metric"] evaluator = single_prompt_from_dict(metric_config, metric_config["metric"]) - # load the eval config + # load the mlflow data helper mdh = mlflow_data_helper_from_yaml(yaml_path) # noqa: F841 + + # load the eval config mlflow_experiment_name = config["eval"]["mlflow_experiment_name"] 
generator_prediction_field = config["eval"]["generator_prediction_field"] evaluator_prediction_field = config["eval"]["evaluator_prediction_field"] @@ -646,7 +703,7 @@ def doc_generator_eval_from_yaml(yaml_path: Union[str, Path]) -> DocGeneratorEva generator=generator, evaluator=evaluator, evalset=evalset, - mlflow_tracking_uri=mlflow_tracking_uri, + mlflow_helper=mdh, mlflow_experiment_name=mlflow_experiment_name, generator_prediction_field=generator_prediction_field, evaluator_prediction_field=evaluator_prediction_field, diff --git a/graphdoc/graphdoc/eval/doc_generator_eval.py b/graphdoc/graphdoc/eval/doc_generator_eval.py index ffd189c..76b3a3a 100644 --- a/graphdoc/graphdoc/eval/doc_generator_eval.py +++ b/graphdoc/graphdoc/eval/doc_generator_eval.py @@ -3,7 +3,6 @@ # system packages import logging -from pathlib import Path from typing import Any, List, Union # external packages @@ -30,7 +29,7 @@ def __init__( DocQualityPrompt, SinglePrompt, Any ], # we have type hints, but accept any type for flexibility evalset: Union[List[dspy.Example], Any], - mlflow_tracking_uri: Union[str, Path], + mlflow_helper: MlflowDataHelper, mlflow_experiment_name: str = "doc_generator_eval", generator_prediction_field: str = "documented_schema", evaluator_prediction_field: str = "rating", @@ -47,10 +46,9 @@ def __init__( self.generator = generator self.evaluator = evaluator self.evalset = evalset - self.mlflow_tracking_uri = mlflow_tracking_uri + self.mlflow_helper = mlflow_helper self.generator_prediction_field = generator_prediction_field self.evaluator_prediction_field = evaluator_prediction_field - self.mlflow_helper = MlflowDataHelper(mlflow_tracking_uri) self.mlflow_experiment_name = mlflow_experiment_name self.readable_value = readable_value @@ -58,7 +56,13 @@ def forward(self, database_schema: str) -> dict[str, Any]: """Takes a database schema, documents it, and then evaluates each component and the aggregate.""" # (we assume we are using DocGeneratorModule) - 
generator_result = self.generator.forward(database_schema) # type: ignore + generator_result = self.generator.document_full_schema( # type: ignore + database_schema=database_schema, + trace=True, + client=self.mlflow_helper.mlflow_client, + expirement_name=self.mlflow_experiment_name, + api_key="temp", + ) # TODO: let's decide if this is how we want to handle this in the future. # Alternatively, we could return the documented schema from forward, # not as a prediction object. @@ -103,7 +107,6 @@ def evaluate(self): """Batches the evaluation set and logs the results to mlflow.""" mlflow.set_experiment(self.mlflow_experiment_name) with mlflow.start_run(): - # evalset = [x.database_schema for x in self.evalset] evaluation_results = self.batch(self.evalset, num_threads=32) avg_overall_rating = sum( [x["overall_rating"] for x in evaluation_results] diff --git a/graphdoc/graphdoc/modules/__init__.py b/graphdoc/graphdoc/modules/__init__.py index 067c1ae..84dc804 100644 --- a/graphdoc/graphdoc/modules/__init__.py +++ b/graphdoc/graphdoc/modules/__init__.py @@ -2,7 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 from graphdoc.modules.doc_generator_module import DocGeneratorModule +from graphdoc.modules.token_tracker import TokenTracker __all__ = [ "DocGeneratorModule", + "TokenTracker", ] diff --git a/graphdoc/graphdoc/modules/doc_generator_module.py b/graphdoc/graphdoc/modules/doc_generator_module.py index ee41c5b..200aa13 100644 --- a/graphdoc/graphdoc/modules/doc_generator_module.py +++ b/graphdoc/graphdoc/modules/doc_generator_module.py @@ -3,6 +3,7 @@ # system packages import logging +import queue from typing import Any, Literal, Optional, Union # external packages @@ -12,6 +13,7 @@ # internal packages from graphdoc.data import Parser +from graphdoc.modules.token_tracker import TokenTracker from graphdoc.prompts import DocGeneratorPrompt, SinglePrompt # logging @@ -26,6 +28,7 @@ def __init__( retry_limit: int = 1, rating_threshold: int = 3, fill_empty_descriptions: bool 
= True, + token_tracker: Optional[TokenTracker] = None, ) -> None: """Initialize the DocGeneratorModule. A module for generating documentation for a given GraphQL schema. Schemas are decomposed and individually used to generate @@ -59,6 +62,7 @@ def __init__( # we should move to a dict like structure for passing in those parameters self.fill_empty_descriptions = fill_empty_descriptions self.par = Parser() + self.token_tracker = TokenTracker() if token_tracker is None else token_tracker # ensure that the doc generator prompt metric is set to rating if self.prompt.prompt_metric.prompt_metric != "rating": @@ -250,11 +254,21 @@ def forward(self, database_schema: str) -> dspy.Prediction: :rtype: dspy.Prediction """ + + def _update_active_tasks(): + with self.token_tracker.callback_lock: + self.token_tracker.active_tasks -= 1 + if self.token_tracker.active_tasks == 0: + self.token_tracker.all_tasks_done.set() + if self.retry: database_schema = self._retry_by_rating(database_schema=database_schema) + _update_active_tasks() return dspy.Prediction(documented_schema=database_schema) else: - return self._predict(database_schema=database_schema) + prediction = self._predict(database_schema=database_schema) + _update_active_tasks() + return prediction def document_full_schema( self, @@ -310,6 +324,10 @@ def document_full_schema( ) log.info("created trace: " + str(root_trace)) + # token tracker details + self.token_tracker.active_tasks = len(examples) + self.token_tracker.all_tasks_done.clear() + # batch generate the documentation documented_examples = self.batch(examples, num_threads=32) document_ast.definitions = tuple( @@ -318,6 +336,26 @@ def document_full_schema( # TODO: we should have better type handling, but we know this works ) + # token tracker details + self.token_tracker.all_tasks_done.wait() + callbacks_during_run = 0 + while True: + try: + data = self.token_tracker.callback_queue.get(timeout=2) + with self.token_tracker.callback_lock: + 
self.token_tracker.api_call_count += 1 + self.token_tracker.model_name = data.get("model", "unknown") + self.token_tracker.completion_tokens += data.get( + "completion_tokens", 0 + ) + self.token_tracker.prompt_tokens += data.get("prompt_tokens", 0) + self.token_tracker.total_tokens += data.get("total_tokens", 0) + callbacks_during_run += 1 + self.token_tracker.callback_queue.task_done() + except queue.Empty: + log.info("Queue empty after timeout, assuming all callbacks processed") + break + # check that the generated schema matches the original schema if self.par.schema_equality_check(parse(database_schema), document_ast): log.info("Schema equality check passed, returning documented schema") @@ -340,9 +378,16 @@ def document_full_schema( trace=root_trace, # type: ignore # TODO: we should have better type handling, but i believe we will get an # error if root_trace has an issue during the start_trace call - outputs={"documented_schema": return_schema}, + outputs={ + "documented_schema": return_schema, + "token_tracker": self.token_tracker.stats(), + }, status=status, ) log.info("ended trace: " + str(root_trace)) # type: ignore # TODO: we should have better type handling, but we check at the top + + # clear the token tracker + self.token_tracker.clear() + return dspy.Prediction(documented_schema=return_schema) diff --git a/graphdoc/graphdoc/modules/token_tracker.py b/graphdoc/graphdoc/modules/token_tracker.py new file mode 100644 index 0000000..fefbacc --- /dev/null +++ b/graphdoc/graphdoc/modules/token_tracker.py @@ -0,0 +1,72 @@ +# Copyright 2025-, Semiotic AI, Inc. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging + +# system packages +import queue +import threading + +# external packages +import litellm + +# internal packages + +# logging +log = logging.getLogger(__name__) + + +class TokenTracker: + """A class to track the number of tokens used.""" + + def __init__(self): + self.model_name = "" + self.api_call_count = 0 + self.completion_tokens = 0 + self.prompt_tokens = 0 + self.total_tokens = 0 + self.active_tasks = 0 + self.callback_lock = threading.Lock() + self.callback_queue = queue.Queue() + self.all_tasks_done = threading.Event() + + if self.global_token_callback not in litellm.callbacks: + litellm.callbacks.append(self.global_token_callback) + + def clear(self): + """Clear the token tracker.""" + self.api_call_count = 0 + self.model_name = "" + self.completion_tokens = 0 + self.prompt_tokens = 0 + self.total_tokens = 0 + self.active_tasks = 0 + + def stats(self): + """Get the stats of the token tracker.""" + return { + "model_name": self.model_name, + "api_call_count": self.api_call_count, + "completion_tokens": self.completion_tokens, + "prompt_tokens": self.prompt_tokens, + "total_tokens": self.total_tokens, + } + + def global_token_callback( + self, kwargs, response, start_time, end_time, **callback_kwargs + ): + """A global callback to track the number of tokens used. + + Intended to be used with the litellm ModelResponse object. 
+ + """ + data = { + "model": response.get("model", "unknown"), + "completion_tokens": response.get("usage", {}).get("completion_tokens", 0), + "prompt_tokens": response.get("usage", {}).get("prompt_tokens", 0), + "total_tokens": response.get("usage", {}).get("total_tokens", 0), + } + self.callback_queue.put(data) + log.info( + f"Callback triggered, queued data, thread: {threading.current_thread().name}" + ) diff --git a/graphdoc/graphdoc/prompts/schema_doc_generation.py b/graphdoc/graphdoc/prompts/schema_doc_generation.py index 7636933..6a2f89e 100644 --- a/graphdoc/graphdoc/prompts/schema_doc_generation.py +++ b/graphdoc/graphdoc/prompts/schema_doc_generation.py @@ -23,17 +23,18 @@ ################### class DocGeneratorSignature(dspy.Signature): """ - ### TASK: Given a GraphQL Schema, generate a precise description for the columns of the tables in the database. + ### TASK: + Analyze the provided GraphQL Schema and generate detailed yet concise descriptions for each field within the database tables and enums. ### Requirements: - - Focus solely on confirmed details from the provided schema. - - Keep the description concise and factual. - - Exclude any speculative or additional commentary. - - DO NOT return the phrase "in the { table } table" in your description. - - ### Formatting - - Ensure that the schema maintains proper documentation formatting, as is provided. + - Utilize only the verified information from the schema to ensure accuracy. + - Descriptions should be factual, straightforward, and avoid any speculative language. + - Refrain from using the phrase "in the { table } table" within your descriptions. + - Ensure that the documentation adheres to standard schema formatting without modifying the underlying schema structure. + ### Formatting: + - Maintain consistency with the existing documentation style and structure. + - Focus on clarity and precision to aid developers and system architects in understanding the schema's components effectively. 
""" # noqa: B950 database_schema: str = dspy.InputField() diff --git a/graphdoc/graphdoc/prompts/schema_doc_quality.py b/graphdoc/graphdoc/prompts/schema_doc_quality.py index 620d6c7..5ab8a69 100644 --- a/graphdoc/graphdoc/prompts/schema_doc_quality.py +++ b/graphdoc/graphdoc/prompts/schema_doc_quality.py @@ -20,15 +20,13 @@ # DSPy Signatures # ################### class DocQualitySignature(dspy.Signature): - """You are evaluating the output of an LLM program, expect hallucinations. Given a GraphQL Schema, evaluate the quality of documentation for that schema and provide a category rating. - - The categories are described as: - - perfect (4): The documentation contains enough information so that the interpretation of the schema and its database content is completely free of ambiguity. - - almost perfect (3): The documentation is almost perfect and free from ambiguity, but there is room for improvement. - - poor but correct (2): The documentation is poor but correct and has room for improvement due to missing information. The documentation is not incorrect. - - incorrect (1): The documentation is incorrect and contains inaccurate or misleading information. Any incorrect information automatically leads to an incorrect rating, even if some correct information is present. - Output a number rating that corresponds to the categories described above. - + """ + You are a documentation quality evaluator specializing in GraphQL schemas. Your task is to assess the quality of documentation provided for a given database schema. Carefully analyze the schema's descriptions for clarity, accuracy, and completeness. Categorize the documentation into one of the following ratings based on your evaluation: + - perfect (4): The documentation is comprehensive and leaves no room for ambiguity in understanding the schema and its database content. + - almost perfect (3): The documentation is clear and mostly free of ambiguity, but there is potential for further improvement. 
+ - poor but correct (2): The documentation is correct but lacks detail, resulting in some ambiguity. It requires enhancement to be more informative. + - incorrect (1): The documentation contains errors or misleading information, regardless of any correct segments present. Such inaccuracies necessitate an incorrect rating. + Provide a step-by-step reasoning to support your evaluation, along with the appropriate category label and numerical rating. """ # noqa: B950 database_schema: str = dspy.InputField()