From 2942abc5c39888696f6896ae5de9e38c361a07bd Mon Sep 17 00:00:00 2001 From: denver Date: Tue, 11 Mar 2025 12:57:57 -0500 Subject: [PATCH 01/14] feat(config): add dspy.LM functionality for proper builds --- graphdoc/graphdoc/config.py | 51 +++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/graphdoc/graphdoc/config.py b/graphdoc/graphdoc/config.py index 9f4dcd0..bb713a0 100644 --- a/graphdoc/graphdoc/config.py +++ b/graphdoc/graphdoc/config.py @@ -33,6 +33,45 @@ # Resource Setup # ####################### +def lm_from_dict(lm_config: dict): + """Load a language model from a dictionary of parameters. + + :param lm_config: Dictionary containing language model parameters. + :type lm_config: dict + + """ + return dspy.LM(**lm_config) + + +def lm_from_yaml(yaml_path: Union[str, Path]): + """Load a language model from a YAML file. + + :param yaml_path: Path to the YAML file containing language model parameters. + :type yaml_path: Union[str, Path] + + """ + config = load_yaml_config(yaml_path) + return lm_from_dict(config["language_model"]) + +def dspy_lm_from_dict(lm_config: dict): + """Load a language model from a dictionary of parameters. Set the dspy language model. + + :param lm_config: Dictionary containing language model parameters. + :type lm_config: dict + + """ + lm = lm_from_dict(lm_config) + dspy.configure(lm=lm) + +def dspy_lm_from_yaml(yaml_path: Union[str, Path]): + """Load a language model from a YAML file. Set the dspy language model. + + :param yaml_path: Path to the YAML file containing language model parameters. + :type yaml_path: Union[str, Path] + + """ + config = load_yaml_config(yaml_path) + dspy_lm_from_dict(config["language_model"]) def mlflow_data_helper_from_dict(mlflow_config: dict) -> MlflowDataHelper: """Load a mlflow data helper from a dictionary of parameters. 
@@ -342,6 +381,9 @@ def single_prompt_from_yaml(yaml_path: Union[str, Path]) -> SinglePrompt: :rtype: SinglePrompt """ + # set the dspy language model + dspy_lm_from_yaml(yaml_path) + config = load_yaml_config(yaml_path) mlflow_config = config.get("mlflow", None) if config["prompt"]["prompt_metric"]: @@ -472,6 +514,9 @@ def single_trainer_from_yaml(yaml_path: Union[str, Path]) -> SinglePromptTrainer :rtype: SinglePromptTrainer """ + # set the dspy language model + dspy_lm_from_yaml(yaml_path) + try: config = load_yaml_config(yaml_path) prompt = single_prompt_from_yaml(yaml_path) @@ -563,6 +608,9 @@ def doc_generator_module_from_yaml(yaml_path: Union[str, Path]) -> DocGeneratorM :rtype: DocGeneratorModule """ + # set the dspy language model + dspy_lm_from_yaml(yaml_path) + config = load_yaml_config(yaml_path)["module"] prompt = single_prompt_from_yaml(yaml_path) return doc_generator_module_from_dict(config, prompt) @@ -620,6 +668,9 @@ def doc_generator_eval_from_yaml(yaml_path: Union[str, Path]) -> DocGeneratorEva :rtype: DocGeneratorEvaluator """ # noqa: B950 + # set the dspy language model + dspy_lm_from_yaml(yaml_path) + # load the generator generator = doc_generator_module_from_yaml(yaml_path) config = load_yaml_config(yaml_path) From 3588694eb94e7dea4f9e10eb08ba07fcc398b820 Mon Sep 17 00:00:00 2001 From: denver Date: Thu, 13 Mar 2025 21:55:59 -0500 Subject: [PATCH 02/14] docs: add temporary scratch_queue.py file demonstrating handling callback with litellm --- graphdoc/scratch_queue.py | 127 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 127 insertions(+) create mode 100644 graphdoc/scratch_queue.py diff --git a/graphdoc/scratch_queue.py b/graphdoc/scratch_queue.py new file mode 100644 index 0000000..f39d117 --- /dev/null +++ b/graphdoc/scratch_queue.py @@ -0,0 +1,127 @@ +from datetime import datetime +import os +import random +import threading +from dotenv import load_dotenv +import litellm +import uuid +import logging +import dspy +import 
concurrent.futures +from graphdoc.config import dspy_lm_from_yaml +import queue + +logger = logging.getLogger(__name__) + +# global variables +api_call_count = 0 +model_name = "gpt-4o-2024-08-06" +completion_tokens = 0 +prompt_tokens = 0 +total_tokens = 0 +callback_lock = threading.Lock() + +callback_queue = queue.Queue() +all_tasks_done = threading.Event() + +def global_token_callback(kwargs, response, start_time, end_time, **callback_kwargs): + data = { + "model": response.get("model", "unknown"), + "completion_tokens": response.get("usage", {}).get("completion_tokens", 0), + "prompt_tokens": response.get("usage", {}).get("prompt_tokens", 0), + "total_tokens": response.get("usage", {}).get("total_tokens", 0), + } + callback_queue.put(data) + logger.info(f"Callback triggered, queued data, thread: {threading.current_thread().name}") + +def math_chain(task_id, active_tasks): + math = dspy.Predict("question -> answer: float") + num_requests = random.randint(1, 3) + logger.info(f"math_chain: num_requests: {num_requests}") + for i in range(num_requests): + value = random.randint(1, 10) + result = math(question=f"What is 2+{value}? 
ID: {uuid.uuid4()}") + logger.info(f"Task {task_id}, Request {i+1}: Result: {result.answer}") + with callback_lock: + active_tasks[0] -= 1 + if active_tasks[0] == 0: + all_tasks_done.set() + logger.info(f"Task {task_id} completed, remaining active tasks: {active_tasks[0]}") + +def math_chain_multi(): + global api_call_count, model_name, completion_tokens, prompt_tokens, total_tokens + + with callback_lock: + start_count = api_call_count + logger.info(f"math_chain_multi started, initial api_call_count: {start_count}") + + num_tasks = random.randint(3, 7) + logger.info(f"math_chain_multi: num_tasks: {num_tasks}") + active_tasks = [num_tasks] + all_tasks_done.clear() + + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: + futures = [executor.submit(math_chain, i, active_tasks) for i in range(num_tasks)] + logger.info(f"total futures: {len(futures)}") + for future in concurrent.futures.as_completed(futures): + future.result() + + all_tasks_done.wait() + logger.info("All tasks completed, now draining callback queue...") + + callbacks_during_run = 0 + while True: + try: + data = callback_queue.get(timeout=2) + with callback_lock: + api_call_count += 1 + model_name = data["model"] + completion_tokens += data["completion_tokens"] + prompt_tokens += data["prompt_tokens"] + total_tokens += data["total_tokens"] + callbacks_during_run += 1 + callback_queue.task_done() + except queue.Empty: + logger.info("Queue empty after timeout, assuming all callbacks processed") + break + + logger.info(f"math_chain_multi: model_name: {model_name}") + logger.info(f"math_chain_multi: completion_tokens: {completion_tokens}") + logger.info(f"math_chain_multi: prompt_tokens: {prompt_tokens}") + logger.info(f"math_chain_multi: total_tokens: {total_tokens}") + logger.info(f"math_chain_multi: total api_call_count: {api_call_count}") + logger.info(f"math_chain_multi: callbacks during this run: {callbacks_during_run}") + +def main(): + print("hello, world!") + 
load_dotenv("../.env") + if global_token_callback not in litellm.callbacks: + litellm.callbacks.append(global_token_callback) + + log_dir = "logs" + os.makedirs(log_dir, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + log_file = os.path.join(log_dir, f"run_{timestamp}.log") + + root_logger = logging.getLogger() + root_logger.setLevel(logging.DEBUG) + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) + root_logger.addHandler(file_handler) + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + console_handler.setFormatter(logging.Formatter('%(levelname)s - %(message)s')) + root_logger.addHandler(console_handler) + + print(f"logging to file: {log_file}") + logger.info("logging initialized") + os.environ["LITELLM_LOG"] = "DEBUG" + + config_path = "/Users/denver/Documents/code/graph/graphdoc-mono/graphdoc/graphdoc/assets/configs/single_prompt_doc_generator_module_eval.yaml" + dspy_lm_from_yaml(config_path) + + math_chain_multi() + +if __name__ == "__main__": + main() \ No newline at end of file From 6779d19f1678684ad70cfed95ef5a1804a09b7be Mon Sep 17 00:00:00 2001 From: denver Date: Thu, 13 Mar 2025 22:24:58 -0500 Subject: [PATCH 03/14] refactor(prompts): update base prompt to match prior optimization runs --- .../graphdoc/prompts/schema_doc_generation.py | 23 ++++++++++--------- .../graphdoc/prompts/schema_doc_quality.py | 15 +++++------- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/graphdoc/graphdoc/prompts/schema_doc_generation.py b/graphdoc/graphdoc/prompts/schema_doc_generation.py index ef70ff5..1d2d967 100644 --- a/graphdoc/graphdoc/prompts/schema_doc_generation.py +++ b/graphdoc/graphdoc/prompts/schema_doc_generation.py @@ -23,17 +23,18 @@ ################### class DocGeneratorSignature(dspy.Signature): """ - ### TASK: Given a GraphQL Schema, generate a 
precise description for the columns of the tables in the database. - - ### Requirements: - - Focus solely on confirmed details from the provided schema. - - Keep the description concise and factual. - - Exclude any speculative or additional commentary. - - DO NOT return the phrase "in the { table } table" in your description. - - ### Formatting - - Ensure that the schema maintains proper documentation formatting, as is provided. - + ### TASK: + Analyze the provided GraphQL Schema and generate detailed yet concise descriptions for each field within the database tables and enums. + + ### Requirements: + - Utilize only the verified information from the schema to ensure accuracy. + - Descriptions should be factual, straightforward, and avoid any speculative language. + - Refrain from using the phrase "in the { table } table" within your descriptions. + - Ensure that the documentation adheres to standard schema formatting without modifying the underlying schema structure. + + ### Formatting: + - Maintain consistency with the existing documentation style and structure. + - Focus on clarity and precision to aid developers and system architects in understanding the schema's components effectively. """ # noqa: B950 database_schema: str = dspy.InputField() diff --git a/graphdoc/graphdoc/prompts/schema_doc_quality.py b/graphdoc/graphdoc/prompts/schema_doc_quality.py index a975196..44fef03 100644 --- a/graphdoc/graphdoc/prompts/schema_doc_quality.py +++ b/graphdoc/graphdoc/prompts/schema_doc_quality.py @@ -21,15 +21,12 @@ ################### class DocQualitySignature(dspy.Signature): """ - You are evaluating the output of an LLM program, expect hallucinations. Given a GraphQL Schema, evaluate the quality of documentation for that schema and provide a category rating. - - The categories are described as: - - perfect (4): The documentation contains enough information so that the interpretation of the schema and its database content is completely free of ambiguity. 
- - almost perfect (3): The documentation is almost perfect and free from ambiguity, but there is room for improvement. - - poor but correct (2): The documentation is poor but correct and has room for improvement due to missing information. The documentation is not incorrect. - - incorrect (1): The documentation is incorrect and contains inaccurate or misleading information. Any incorrect information automatically leads to an incorrect rating, even if some correct information is present. - Output a number rating that corresponds to the categories described above. - + You are a documentation quality evaluator specializing in GraphQL schemas. Your task is to assess the quality of documentation provided for a given database schema. Carefully analyze the schema's descriptions for clarity, accuracy, and completeness. Categorize the documentation into one of the following ratings based on your evaluation: + - perfect (4): The documentation is comprehensive and leaves no room for ambiguity in understanding the schema and its database content. + - almost perfect (3): The documentation is clear and mostly free of ambiguity, but there is potential for further improvement. + - poor but correct (2): The documentation is correct but lacks detail, resulting in some ambiguity. It requires enhancement to be more informative. + - incorrect (1): The documentation contains errors or misleading information, regardless of any correct segments present. Such inaccuracies necessitate an incorrect rating. + Provide a step-by-step reasoning to support your evaluation, along with the appropriate category label and numerical rating. 
""" # noqa: B950 database_schema: str = dspy.InputField() From cda722dc0331fa5cb2b2e6c9d735311231b0f361 Mon Sep 17 00:00:00 2001 From: denver Date: Fri, 14 Mar 2025 10:13:14 -0500 Subject: [PATCH 04/14] style(prompts): remove unused doc generator signature indent --- .../graphdoc/prompts/schema_doc_generation.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/graphdoc/graphdoc/prompts/schema_doc_generation.py b/graphdoc/graphdoc/prompts/schema_doc_generation.py index 1d2d967..c3b4431 100644 --- a/graphdoc/graphdoc/prompts/schema_doc_generation.py +++ b/graphdoc/graphdoc/prompts/schema_doc_generation.py @@ -23,18 +23,18 @@ ################### class DocGeneratorSignature(dspy.Signature): """ - ### TASK: - Analyze the provided GraphQL Schema and generate detailed yet concise descriptions for each field within the database tables and enums. - - ### Requirements: - - Utilize only the verified information from the schema to ensure accuracy. - - Descriptions should be factual, straightforward, and avoid any speculative language. - - Refrain from using the phrase "in the { table } table" within your descriptions. - - Ensure that the documentation adheres to standard schema formatting without modifying the underlying schema structure. - - ### Formatting: - - Maintain consistency with the existing documentation style and structure. - - Focus on clarity and precision to aid developers and system architects in understanding the schema's components effectively. + ### TASK: + Analyze the provided GraphQL Schema and generate detailed yet concise descriptions for each field within the database tables and enums. + + ### Requirements: + - Utilize only the verified information from the schema to ensure accuracy. + - Descriptions should be factual, straightforward, and avoid any speculative language. + - Refrain from using the phrase "in the { table } table" within your descriptions. 
+ - Ensure that the documentation adheres to standard schema formatting without modifying the underlying schema structure. + + ### Formatting: + - Maintain consistency with the existing documentation style and structure. + - Focus on clarity and precision to aid developers and system architects in understanding the schema's components effectively. """ # noqa: B950 database_schema: str = dspy.InputField() From 359c9999686ceb22f8451109e0e96973c1aa5554 Mon Sep 17 00:00:00 2001 From: denver Date: Fri, 14 Mar 2025 12:30:22 -0500 Subject: [PATCH 05/14] feat(module): add a token tracker class for managing mlflow callbacks --- graphdoc/graphdoc/modules/token_tracker.py | 54 ++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 graphdoc/graphdoc/modules/token_tracker.py diff --git a/graphdoc/graphdoc/modules/token_tracker.py b/graphdoc/graphdoc/modules/token_tracker.py new file mode 100644 index 0000000..a37c2b2 --- /dev/null +++ b/graphdoc/graphdoc/modules/token_tracker.py @@ -0,0 +1,54 @@ +# Copyright 2025-, Semiotic AI, Inc. 
+# SPDX-License-Identifier: Apache-2.0 + +# system packages +import queue +import logging +import threading + +# external packages + +# internal packages + +# logging +log = logging.getLogger(__name__) + +class TokenTracker: + + def __init__(self): + self.model_name = "" + self.api_call_count = 0 + self.completion_tokens = 0 + self.prompt_tokens = 0 + self.total_tokens = 0 + self.active_tasks = 0 + self.callback_lock = threading.Lock() + self.callback_queue = queue.Queue() + self.all_tasks_done = threading.Event() + + def clear(self): + self.api_call_count = 0 + self.model_name = "" + self.completion_tokens = 0 + self.prompt_tokens = 0 + self.total_tokens = 0 + self.active_tasks = 0 + + def stats(self): + return { + "model_name": self.model_name, + "api_call_count": self.api_call_count, + "completion_tokens": self.completion_tokens, + "prompt_tokens": self.prompt_tokens, + "total_tokens": self.total_tokens, + } + + def global_token_callback(self, kwargs, response, start_time, end_time, **callback_kwargs): + data = { + "model": response.get("model", "unknown"), + "completion_tokens": response.get("usage", {}).get("completion_tokens", 0), + "prompt_tokens": response.get("usage", {}).get("prompt_tokens", 0), + "total_tokens": response.get("usage", {}).get("total_tokens", 0), + } + self.callback_queue.put(data) + log.info(f"Callback triggered, queued data, thread: {threading.current_thread().name}") \ No newline at end of file From bf4e8f35100edb5093783ce84b44dea5e7286e6d Mon Sep 17 00:00:00 2001 From: denver Date: Fri, 14 Mar 2025 12:30:54 -0500 Subject: [PATCH 06/14] feat: implements token tracker in the doc generator module --- graphdoc/graphdoc/modules/__init__.py | 2 + .../graphdoc/modules/doc_generator_module.py | 45 ++++++++++++++++--- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/graphdoc/graphdoc/modules/__init__.py b/graphdoc/graphdoc/modules/__init__.py index 067c1ae..50b44ce 100644 --- a/graphdoc/graphdoc/modules/__init__.py +++ 
b/graphdoc/graphdoc/modules/__init__.py @@ -1,8 +1,10 @@ # Copyright 2025-, Semiotic AI, Inc. # SPDX-License-Identifier: Apache-2.0 +from graphdoc.modules.token_tracker import TokenTracker from graphdoc.modules.doc_generator_module import DocGeneratorModule __all__ = [ "DocGeneratorModule", + "TokenTracker", ] diff --git a/graphdoc/graphdoc/modules/doc_generator_module.py b/graphdoc/graphdoc/modules/doc_generator_module.py index 1b27f42..e7bb0a8 100644 --- a/graphdoc/graphdoc/modules/doc_generator_module.py +++ b/graphdoc/graphdoc/modules/doc_generator_module.py @@ -3,6 +3,7 @@ # system packages import logging +import queue from typing import Any, Literal, Optional, Union # external packages @@ -12,6 +13,7 @@ # internal packages from graphdoc.data import Parser +from graphdoc.modules.token_tracker import TokenTracker from graphdoc.prompts import DocGeneratorPrompt, SinglePrompt # logging @@ -26,6 +28,7 @@ def __init__( retry_limit: int = 1, rating_threshold: int = 3, fill_empty_descriptions: bool = True, + token_tracker: TokenTracker = None, ) -> None: """Initialize the DocGeneratorModule. A module for generating documentation for a given GraphQL schema. 
Schemas are decomposed and individually used to generate @@ -56,6 +59,7 @@ def __init__( # we should move to a dict like structure for passing in those parameters self.fill_empty_descriptions = fill_empty_descriptions self.par = Parser() + self.token_tracker = TokenTracker() if token_tracker is None else token_tracker # ensure that the doc generator prompt metric is set to rating if self.prompt.prompt_metric.prompt_metric != "rating": @@ -244,11 +248,20 @@ def forward(self, database_schema: str) -> dspy.Prediction: dspy.Prediction """ + def _update_active_tasks(): + with self.token_tracker.callback_lock: + self.token_tracker.active_tasks -= 1 + if self.token_tracker.active_tasks == 0: + self.token_tracker.all_tasks_done.set() + if self.retry: database_schema = self._retry_by_rating(database_schema=database_schema) + _update_active_tasks() return dspy.Prediction(documented_schema=database_schema) else: - return self._predict(database_schema=database_schema) + prediction = self._predict(database_schema=database_schema) + _update_active_tasks() + return prediction def document_full_schema( self, @@ -261,9 +274,9 @@ def document_full_schema( """Given a database schema, parse out the underlying components and document on a per-component basis. - :param database_schema: The database schema to generate documentation for. :type - database_schema: str :return: The generated documentation. :rtype: - dspy.Prediction + :param database_schema: The database schema to generate documentation for. + :type database_schema: str :return: The generated documentation. 
+ :rtype: dspy.Prediction """ # if we are tracing, make sure make sure we have everything needed to log to mlflow @@ -303,6 +316,10 @@ def document_full_schema( ) log.info("created trace: " + str(root_trace)) + # token tracker details + self.token_tracker.active_tasks = len(examples) + self.token_tracker.all_tasks_done.clear() + # batch generate the documentation documented_examples = self.batch(examples, num_threads=32) document_ast.definitions = tuple( @@ -311,6 +328,24 @@ def document_full_schema( # TODO: we should have better type handling, but we know this works ) + # token tracker details + self.token_tracker.all_tasks_done.wait() + callbacks_during_run = 0 + while True: + try: + data = self.token_tracker.callback_queue.get(timeout=2) + with self.token_tracker.callback_lock: + self.token_tracker.api_call_count += 1 + self.token_tracker.model_name = data.get("model", "unknown") + self.token_tracker.completion_tokens += data.get("completion_tokens", 0) + self.token_tracker.prompt_tokens += data.get("prompt_tokens", 0) + self.token_tracker.total_tokens += data.get("total_tokens", 0) + callbacks_during_run += 1 + self.token_tracker.callback_queue.task_done() + except queue.Empty: + log.info("Queue empty after timeout, assuming all callbacks processed") + break + # check that the generated schema matches the original schema if self.par.schema_equality_check(parse(database_schema), document_ast): log.info("Schema equality check passed, returning documented schema") @@ -333,7 +368,7 @@ def document_full_schema( trace=root_trace, # type: ignore # TODO: we should have better type handling, but i believe we will get an # error if root_trace has an issue during the start_trace call - outputs={"documented_schema": return_schema}, + outputs={"documented_schema": return_schema, "token_tracker": self.token_tracker.stats()}, status=status, ) log.info("ended trace: " + str(root_trace)) # type: ignore From 6aef125f3e6deb8c426501d1da24ebab3afaaa08 Mon Sep 17 00:00:00 2001 From: 
denver Date: Fri, 14 Mar 2025 12:31:50 -0500 Subject: [PATCH 07/14] refactor(temp): updates doc gen eval to utilize document_full_schema --- graphdoc/graphdoc/__init__.py | 3 ++- graphdoc/graphdoc/config.py | 6 ++++-- graphdoc/graphdoc/eval/doc_generator_eval.py | 16 ++++++++++++---- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/graphdoc/graphdoc/__init__.py b/graphdoc/graphdoc/__init__.py index 22d85aa..221013a 100644 --- a/graphdoc/graphdoc/__init__.py +++ b/graphdoc/graphdoc/__init__.py @@ -37,7 +37,7 @@ setup_logging, ) from graphdoc.eval import DocGeneratorEvaluator -from graphdoc.modules import DocGeneratorModule +from graphdoc.modules import DocGeneratorModule, TokenTracker from graphdoc.prompts import ( BadDocGeneratorSignature, DocGeneratorHelperSignature, @@ -60,6 +60,7 @@ __all__ = [ "DocGeneratorModule", + "TokenTracker", "DocGeneratorEvaluator", "DocGeneratorTrainer", "DocQualityTrainer", diff --git a/graphdoc/graphdoc/config.py b/graphdoc/graphdoc/config.py index bb713a0..ba489d6 100644 --- a/graphdoc/graphdoc/config.py +++ b/graphdoc/graphdoc/config.py @@ -679,8 +679,10 @@ def doc_generator_eval_from_yaml(yaml_path: Union[str, Path]) -> DocGeneratorEva metric_config = config["prompt_metric"] evaluator = single_prompt_from_dict(metric_config, metric_config["metric"]) - # load the eval config + # load the mlflow data helper mdh = mlflow_data_helper_from_yaml(yaml_path) # noqa: F841 + + # load the eval config mlflow_tracking_uri = config["mlflow"]["mlflow_tracking_uri"] mlflow_experiment_name = config["eval"]["mlflow_experiment_name"] generator_prediction_field = config["eval"]["generator_prediction_field"] @@ -695,7 +697,7 @@ def doc_generator_eval_from_yaml(yaml_path: Union[str, Path]) -> DocGeneratorEva generator=generator, evaluator=evaluator, evalset=evalset, - mlflow_tracking_uri=mlflow_tracking_uri, + mlflow_helper=mdh, mlflow_experiment_name=mlflow_experiment_name, generator_prediction_field=generator_prediction_field, 
evaluator_prediction_field=evaluator_prediction_field, diff --git a/graphdoc/graphdoc/eval/doc_generator_eval.py b/graphdoc/graphdoc/eval/doc_generator_eval.py index ffd189c..3785ee9 100644 --- a/graphdoc/graphdoc/eval/doc_generator_eval.py +++ b/graphdoc/graphdoc/eval/doc_generator_eval.py @@ -30,7 +30,8 @@ def __init__( DocQualityPrompt, SinglePrompt, Any ], # we have type hints, but accept any type for flexibility evalset: Union[List[dspy.Example], Any], - mlflow_tracking_uri: Union[str, Path], + # mlflow_tracking_uri: Union[str, Path], + mlflow_helper: MlflowDataHelper, mlflow_experiment_name: str = "doc_generator_eval", generator_prediction_field: str = "documented_schema", evaluator_prediction_field: str = "rating", @@ -47,10 +48,11 @@ def __init__( self.generator = generator self.evaluator = evaluator self.evalset = evalset - self.mlflow_tracking_uri = mlflow_tracking_uri + # self.mlflow_tracking_uri = mlflow_tracking_uri + self.mlflow_helper = mlflow_helper self.generator_prediction_field = generator_prediction_field self.evaluator_prediction_field = evaluator_prediction_field - self.mlflow_helper = MlflowDataHelper(mlflow_tracking_uri) + # self.mlflow_helper = MlflowDataHelper(mlflow_tracking_uri) self.mlflow_experiment_name = mlflow_experiment_name self.readable_value = readable_value @@ -58,7 +60,13 @@ def forward(self, database_schema: str) -> dict[str, Any]: """Takes a database schema, documents it, and then evaluates each component and the aggregate.""" # (we assume we are using DocGeneratorModule) - generator_result = self.generator.forward(database_schema) # type: ignore + generator_result = self.generator.document_full_schema( + database_schema=database_schema, + trace=True, + client=self.mlflow_helper.mlflow_client, + expirement_name=self.mlflow_experiment_name, + api_key="temp", + ) # type: ignore # TODO: let's decide if this is how we want to handle this in the future. 
# Alternatively, we could return the documented schema from forward, # not as a prediction object. From 3f1290425549243c1e1a59df5093d573f63a9c65 Mon Sep 17 00:00:00 2001 From: denver Date: Fri, 14 Mar 2025 14:03:31 -0500 Subject: [PATCH 08/14] refactor: ensure litellm callback handler is included --- graphdoc/graphdoc/modules/token_tracker.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/graphdoc/graphdoc/modules/token_tracker.py b/graphdoc/graphdoc/modules/token_tracker.py index a37c2b2..23485f2 100644 --- a/graphdoc/graphdoc/modules/token_tracker.py +++ b/graphdoc/graphdoc/modules/token_tracker.py @@ -7,6 +7,7 @@ import threading # external packages +import litellm # internal packages @@ -14,6 +15,9 @@ log = logging.getLogger(__name__) class TokenTracker: + """ + A class to track the number of tokens used. + """ def __init__(self): self.model_name = "" @@ -26,7 +30,13 @@ def __init__(self): self.callback_queue = queue.Queue() self.all_tasks_done = threading.Event() + if self.global_token_callback not in litellm.callbacks: + litellm.callbacks.append(self.global_token_callback) + def clear(self): + """ + Clear the token tracker. + """ self.api_call_count = 0 self.model_name = "" self.completion_tokens = 0 @@ -35,6 +45,9 @@ def clear(self): self.active_tasks = 0 def stats(self): + """ + Get the stats of the token tracker. + """ return { "model_name": self.model_name, "api_call_count": self.api_call_count, @@ -44,6 +57,9 @@ def stats(self): } def global_token_callback(self, kwargs, response, start_time, end_time, **callback_kwargs): + """ + A global callback to track the number of tokens used. Intended to be used with the litellm ModelResponse object. 
+ """ data = { "model": response.get("model", "unknown"), "completion_tokens": response.get("usage", {}).get("completion_tokens", 0), From efda95462e02bac2c9eeba820ae1b42e29c06b95 Mon Sep 17 00:00:00 2001 From: denver Date: Fri, 14 Mar 2025 14:04:22 -0500 Subject: [PATCH 09/14] refactor: clear token tracker variables at the end of execution --- graphdoc/graphdoc/modules/doc_generator_module.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/graphdoc/graphdoc/modules/doc_generator_module.py b/graphdoc/graphdoc/modules/doc_generator_module.py index e7bb0a8..c541118 100644 --- a/graphdoc/graphdoc/modules/doc_generator_module.py +++ b/graphdoc/graphdoc/modules/doc_generator_module.py @@ -373,4 +373,8 @@ def document_full_schema( ) log.info("ended trace: " + str(root_trace)) # type: ignore # TODO: we should have better type handling, but we check at the top + + # clear the token tracker + self.token_tracker.clear() + return dspy.Prediction(documented_schema=return_schema) From 85a34689790504903b67d5acac6c16f13e9fae23 Mon Sep 17 00:00:00 2001 From: denver Date: Fri, 14 Mar 2025 14:10:02 -0500 Subject: [PATCH 10/14] style: format with black and other built in checks --- graphdoc/graphdoc/config.py | 7 ++- graphdoc/graphdoc/eval/doc_generator_eval.py | 2 +- graphdoc/graphdoc/modules/__init__.py | 2 +- .../graphdoc/modules/doc_generator_module.py | 23 ++++++--- graphdoc/graphdoc/modules/token_tracker.py | 44 +++++++++-------- .../graphdoc/prompts/schema_doc_generation.py | 8 +-- .../graphdoc/prompts/schema_doc_quality.py | 16 +++--- graphdoc/scratch_queue.py | 49 ++++++++++++------- 8 files changed, 89 insertions(+), 62 deletions(-) diff --git a/graphdoc/graphdoc/config.py b/graphdoc/graphdoc/config.py index ba489d6..225dd62 100644 --- a/graphdoc/graphdoc/config.py +++ b/graphdoc/graphdoc/config.py @@ -33,6 +33,7 @@ # Resource Setup # ####################### + def lm_from_dict(lm_config: dict): """Load a language model from a dictionary of parameters. 
@@ -53,8 +54,10 @@ def lm_from_yaml(yaml_path: Union[str, Path]): config = load_yaml_config(yaml_path) return lm_from_dict(config["language_model"]) + def dspy_lm_from_dict(lm_config: dict): - """Load a language model from a dictionary of parameters. Set the dspy language model. + """Load a language model from a dictionary of parameters. Set the dspy language + model. :param lm_config: Dictionary containing language model parameters. :type lm_config: dict @@ -63,6 +66,7 @@ def dspy_lm_from_dict(lm_config: dict): lm = lm_from_dict(lm_config) dspy.configure(lm=lm) + def dspy_lm_from_yaml(yaml_path: Union[str, Path]): """Load a language model from a YAML file. Set the dspy language model. @@ -73,6 +77,7 @@ def dspy_lm_from_yaml(yaml_path: Union[str, Path]): config = load_yaml_config(yaml_path) dspy_lm_from_dict(config["language_model"]) + def mlflow_data_helper_from_dict(mlflow_config: dict) -> MlflowDataHelper: """Load a mlflow data helper from a dictionary of parameters. diff --git a/graphdoc/graphdoc/eval/doc_generator_eval.py b/graphdoc/graphdoc/eval/doc_generator_eval.py index 3785ee9..cd08a45 100644 --- a/graphdoc/graphdoc/eval/doc_generator_eval.py +++ b/graphdoc/graphdoc/eval/doc_generator_eval.py @@ -61,7 +61,7 @@ def forward(self, database_schema: str) -> dict[str, Any]: the aggregate.""" # (we assume we are using DocGeneratorModule) generator_result = self.generator.document_full_schema( - database_schema=database_schema, + database_schema=database_schema, trace=True, client=self.mlflow_helper.mlflow_client, expirement_name=self.mlflow_experiment_name, diff --git a/graphdoc/graphdoc/modules/__init__.py b/graphdoc/graphdoc/modules/__init__.py index 50b44ce..84dc804 100644 --- a/graphdoc/graphdoc/modules/__init__.py +++ b/graphdoc/graphdoc/modules/__init__.py @@ -1,8 +1,8 @@ # Copyright 2025-, Semiotic AI, Inc. 
# SPDX-License-Identifier: Apache-2.0 -from graphdoc.modules.token_tracker import TokenTracker from graphdoc.modules.doc_generator_module import DocGeneratorModule +from graphdoc.modules.token_tracker import TokenTracker __all__ = [ "DocGeneratorModule", diff --git a/graphdoc/graphdoc/modules/doc_generator_module.py b/graphdoc/graphdoc/modules/doc_generator_module.py index c541118..f8df262 100644 --- a/graphdoc/graphdoc/modules/doc_generator_module.py +++ b/graphdoc/graphdoc/modules/doc_generator_module.py @@ -248,12 +248,13 @@ def forward(self, database_schema: str) -> dspy.Prediction: dspy.Prediction """ + def _update_active_tasks(): with self.token_tracker.callback_lock: self.token_tracker.active_tasks -= 1 if self.token_tracker.active_tasks == 0: self.token_tracker.all_tasks_done.set() - + if self.retry: database_schema = self._retry_by_rating(database_schema=database_schema) _update_active_tasks() @@ -274,8 +275,9 @@ def document_full_schema( """Given a database schema, parse out the underlying components and document on a per-component basis. - :param database_schema: The database schema to generate documentation for. - :type database_schema: str :return: The generated documentation. + :param database_schema: The database schema to generate documentation for. + :type database_schema: str + :return: The generated documentation. 
:rtype: dspy.Prediction """ @@ -331,18 +333,20 @@ def document_full_schema( # token tracker details self.token_tracker.all_tasks_done.wait() callbacks_during_run = 0 - while True: - try: + while True: + try: data = self.token_tracker.callback_queue.get(timeout=2) with self.token_tracker.callback_lock: self.token_tracker.api_call_count += 1 self.token_tracker.model_name = data.get("model", "unknown") - self.token_tracker.completion_tokens += data.get("completion_tokens", 0) + self.token_tracker.completion_tokens += data.get( + "completion_tokens", 0 + ) self.token_tracker.prompt_tokens += data.get("prompt_tokens", 0) self.token_tracker.total_tokens += data.get("total_tokens", 0) callbacks_during_run += 1 self.token_tracker.callback_queue.task_done() - except queue.Empty: + except queue.Empty: log.info("Queue empty after timeout, assuming all callbacks processed") break @@ -368,7 +372,10 @@ def document_full_schema( trace=root_trace, # type: ignore # TODO: we should have better type handling, but i believe we will get an # error if root_trace has an issue during the start_trace call - outputs={"documented_schema": return_schema, "token_tracker": self.token_tracker.stats()}, + outputs={ + "documented_schema": return_schema, + "token_tracker": self.token_tracker.stats(), + }, status=status, ) log.info("ended trace: " + str(root_trace)) # type: ignore diff --git a/graphdoc/graphdoc/modules/token_tracker.py b/graphdoc/graphdoc/modules/token_tracker.py index 23485f2..fefbacc 100644 --- a/graphdoc/graphdoc/modules/token_tracker.py +++ b/graphdoc/graphdoc/modules/token_tracker.py @@ -1,9 +1,10 @@ # Copyright 2025-, Semiotic AI, Inc. # SPDX-License-Identifier: Apache-2.0 +import logging + # system packages import queue -import logging import threading # external packages @@ -14,17 +15,16 @@ # logging log = logging.getLogger(__name__) -class TokenTracker: - """ - A class to track the number of tokens used. 
- """ + +class TokenTracker: + """A class to track the number of tokens used.""" def __init__(self): self.model_name = "" - self.api_call_count = 0 + self.api_call_count = 0 self.completion_tokens = 0 - self.prompt_tokens = 0 - self.total_tokens = 0 + self.prompt_tokens = 0 + self.total_tokens = 0 self.active_tasks = 0 self.callback_lock = threading.Lock() self.callback_queue = queue.Queue() @@ -34,20 +34,16 @@ def __init__(self): litellm.callbacks.append(self.global_token_callback) def clear(self): - """ - Clear the token tracker. - """ - self.api_call_count = 0 + """Clear the token tracker.""" + self.api_call_count = 0 self.model_name = "" self.completion_tokens = 0 - self.prompt_tokens = 0 - self.total_tokens = 0 + self.prompt_tokens = 0 + self.total_tokens = 0 self.active_tasks = 0 def stats(self): - """ - Get the stats of the token tracker. - """ + """Get the stats of the token tracker.""" return { "model_name": self.model_name, "api_call_count": self.api_call_count, @@ -56,9 +52,13 @@ def stats(self): "total_tokens": self.total_tokens, } - def global_token_callback(self, kwargs, response, start_time, end_time, **callback_kwargs): - """ - A global callback to track the number of tokens used. Intended to be used with the litellm ModelResponse object. + def global_token_callback( + self, kwargs, response, start_time, end_time, **callback_kwargs + ): + """A global callback to track the number of tokens used. + + Intended to be used with the litellm ModelResponse object. 
+ """ data = { "model": response.get("model", "unknown"), @@ -67,4 +67,6 @@ def global_token_callback(self, kwargs, response, start_time, end_time, **callba "total_tokens": response.get("usage", {}).get("total_tokens", 0), } self.callback_queue.put(data) - log.info(f"Callback triggered, queued data, thread: {threading.current_thread().name}") \ No newline at end of file + log.info( + f"Callback triggered, queued data, thread: {threading.current_thread().name}" + ) diff --git a/graphdoc/graphdoc/prompts/schema_doc_generation.py b/graphdoc/graphdoc/prompts/schema_doc_generation.py index c3b4431..d8bacc3 100644 --- a/graphdoc/graphdoc/prompts/schema_doc_generation.py +++ b/graphdoc/graphdoc/prompts/schema_doc_generation.py @@ -24,17 +24,17 @@ class DocGeneratorSignature(dspy.Signature): """ ### TASK: - Analyze the provided GraphQL Schema and generate detailed yet concise descriptions for each field within the database tables and enums. - + Analyze the provided GraphQL Schema and generate detailed yet concise descriptions for each field within the database tables and enums. + ### Requirements: - Utilize only the verified information from the schema to ensure accuracy. - Descriptions should be factual, straightforward, and avoid any speculative language. - Refrain from using the phrase "in the { table } table" within your descriptions. - Ensure that the documentation adheres to standard schema formatting without modifying the underlying schema structure. - + ### Formatting: - Maintain consistency with the existing documentation style and structure. - - Focus on clarity and precision to aid developers and system architects in understanding the schema's components effectively. + - Focus on clarity and precision to aid developers and system architects in understanding the schema's components effectively. 
""" # noqa: B950 database_schema: str = dspy.InputField() diff --git a/graphdoc/graphdoc/prompts/schema_doc_quality.py b/graphdoc/graphdoc/prompts/schema_doc_quality.py index 44fef03..9765b6c 100644 --- a/graphdoc/graphdoc/prompts/schema_doc_quality.py +++ b/graphdoc/graphdoc/prompts/schema_doc_quality.py @@ -21,18 +21,18 @@ ################### class DocQualitySignature(dspy.Signature): """ - You are a documentation quality evaluator specializing in GraphQL schemas. Your task is to assess the quality of documentation provided for a given database schema. Carefully analyze the schema's descriptions for clarity, accuracy, and completeness. Categorize the documentation into one of the following ratings based on your evaluation: + You are a documentation quality evaluator specializing in GraphQL schemas. Your task is to assess the quality of documentation provided for a given database schema. Carefully analyze the schema's descriptions for clarity, accuracy, and completeness. Categorize the documentation into one of the following ratings based on your evaluation: - perfect (4): The documentation is comprehensive and leaves no room for ambiguity in understanding the schema and its database content. - almost perfect (3): The documentation is clear and mostly free of ambiguity, but there is potential for further improvement. - poor but correct (2): The documentation is correct but lacks detail, resulting in some ambiguity. It requires enhancement to be more informative. - incorrect (1): The documentation contains errors or misleading information, regardless of any correct segments present. Such inaccuracies necessitate an incorrect rating. - Provide a step-by-step reasoning to support your evaluation, along with the appropriate category label and numerical rating. + Provide a step-by-step reasoning to support your evaluation, along with the appropriate category label and numerical rating. 
""" # noqa: B950 database_schema: str = dspy.InputField() - category: Literal["perfect", "almost perfect", "poor but correct", "incorrect"] = ( - dspy.OutputField() - ) + category: Literal[ + "perfect", "almost perfect", "poor but correct", "incorrect" + ] = dspy.OutputField() rating: Literal[4, 3, 2, 1] = dspy.OutputField() @@ -70,9 +70,9 @@ class DocQualityDemonstrationSignature(dspy.Signature): """ # noqa: B950 database_schema: str = dspy.InputField() - category: Literal["perfect", "almost perfect", "poor but correct", "incorrect"] = ( - dspy.OutputField() - ) + category: Literal[ + "perfect", "almost perfect", "poor but correct", "incorrect" + ] = dspy.OutputField() rating: Literal[4, 3, 2, 1] = dspy.OutputField() diff --git a/graphdoc/scratch_queue.py b/graphdoc/scratch_queue.py index f39d117..0c4c271 100644 --- a/graphdoc/scratch_queue.py +++ b/graphdoc/scratch_queue.py @@ -1,15 +1,17 @@ -from datetime import datetime +import concurrent.futures +import logging import os +import queue import random import threading -from dotenv import load_dotenv -import litellm import uuid -import logging +from datetime import datetime + import dspy -import concurrent.futures +import litellm +from dotenv import load_dotenv + from graphdoc.config import dspy_lm_from_yaml -import queue logger = logging.getLogger(__name__) @@ -24,6 +26,7 @@ callback_queue = queue.Queue() all_tasks_done = threading.Event() + def global_token_callback(kwargs, response, start_time, end_time, **callback_kwargs): data = { "model": response.get("model", "unknown"), @@ -32,7 +35,10 @@ def global_token_callback(kwargs, response, start_time, end_time, **callback_kwa "total_tokens": response.get("usage", {}).get("total_tokens", 0), } callback_queue.put(data) - logger.info(f"Callback triggered, queued data, thread: {threading.current_thread().name}") + logger.info( + f"Callback triggered, queued data, thread: {threading.current_thread().name}" + ) + def math_chain(task_id, active_tasks): math = 
dspy.Predict("question -> answer: float") @@ -48,9 +54,10 @@ def math_chain(task_id, active_tasks): all_tasks_done.set() logger.info(f"Task {task_id} completed, remaining active tasks: {active_tasks[0]}") + def math_chain_multi(): global api_call_count, model_name, completion_tokens, prompt_tokens, total_tokens - + with callback_lock: start_count = api_call_count logger.info(f"math_chain_multi started, initial api_call_count: {start_count}") @@ -59,16 +66,18 @@ def math_chain_multi(): logger.info(f"math_chain_multi: num_tasks: {num_tasks}") active_tasks = [num_tasks] all_tasks_done.clear() - + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: - futures = [executor.submit(math_chain, i, active_tasks) for i in range(num_tasks)] + futures = [ + executor.submit(math_chain, i, active_tasks) for i in range(num_tasks) + ] logger.info(f"total futures: {len(futures)}") for future in concurrent.futures.as_completed(futures): future.result() - + all_tasks_done.wait() logger.info("All tasks completed, now draining callback queue...") - + callbacks_during_run = 0 while True: try: @@ -84,7 +93,7 @@ def math_chain_multi(): except queue.Empty: logger.info("Queue empty after timeout, assuming all callbacks processed") break - + logger.info(f"math_chain_multi: model_name: {model_name}") logger.info(f"math_chain_multi: completion_tokens: {completion_tokens}") logger.info(f"math_chain_multi: prompt_tokens: {prompt_tokens}") @@ -92,12 +101,13 @@ def math_chain_multi(): logger.info(f"math_chain_multi: total api_call_count: {api_call_count}") logger.info(f"math_chain_multi: callbacks during this run: {callbacks_during_run}") + def main(): print("hello, world!") load_dotenv("../.env") if global_token_callback not in litellm.callbacks: litellm.callbacks.append(global_token_callback) - + log_dir = "logs" os.makedirs(log_dir, exist_ok=True) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") @@ -107,13 +117,15 @@ def main(): root_logger.setLevel(logging.DEBUG) 
file_handler = logging.FileHandler(log_file) file_handler.setLevel(logging.DEBUG) - file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) + file_handler.setFormatter( + logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + ) root_logger.addHandler(file_handler) console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) - console_handler.setFormatter(logging.Formatter('%(levelname)s - %(message)s')) + console_handler.setFormatter(logging.Formatter("%(levelname)s - %(message)s")) root_logger.addHandler(console_handler) - + print(f"logging to file: {log_file}") logger.info("logging initialized") os.environ["LITELLM_LOG"] = "DEBUG" @@ -123,5 +135,6 @@ def main(): math_chain_multi() + if __name__ == "__main__": - main() \ No newline at end of file + main() From ed6606bfc2bdeeb3983fd7a696d3969abb6ee6e9 Mon Sep 17 00:00:00 2001 From: denver Date: Fri, 14 Mar 2025 14:10:53 -0500 Subject: [PATCH 11/14] refactor: remove unused values based on pep8 check --- graphdoc/graphdoc/config.py | 1 - graphdoc/graphdoc/eval/doc_generator_eval.py | 1 - 2 files changed, 2 deletions(-) diff --git a/graphdoc/graphdoc/config.py b/graphdoc/graphdoc/config.py index 225dd62..07968cf 100644 --- a/graphdoc/graphdoc/config.py +++ b/graphdoc/graphdoc/config.py @@ -688,7 +688,6 @@ def doc_generator_eval_from_yaml(yaml_path: Union[str, Path]) -> DocGeneratorEva mdh = mlflow_data_helper_from_yaml(yaml_path) # noqa: F841 # load the eval config - mlflow_tracking_uri = config["mlflow"]["mlflow_tracking_uri"] mlflow_experiment_name = config["eval"]["mlflow_experiment_name"] generator_prediction_field = config["eval"]["generator_prediction_field"] evaluator_prediction_field = config["eval"]["evaluator_prediction_field"] diff --git a/graphdoc/graphdoc/eval/doc_generator_eval.py b/graphdoc/graphdoc/eval/doc_generator_eval.py index cd08a45..166daff 100644 --- a/graphdoc/graphdoc/eval/doc_generator_eval.py +++ 
b/graphdoc/graphdoc/eval/doc_generator_eval.py @@ -3,7 +3,6 @@ # system packages import logging -from pathlib import Path from typing import Any, List, Union # external packages From f79a4c757a060c3ba3be41132d647e95bdd4f3c5 Mon Sep 17 00:00:00 2001 From: denver Date: Fri, 14 Mar 2025 14:12:26 -0500 Subject: [PATCH 12/14] refactor: correct types for pyright --- graphdoc/graphdoc/eval/doc_generator_eval.py | 4 ++-- graphdoc/graphdoc/modules/doc_generator_module.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/graphdoc/graphdoc/eval/doc_generator_eval.py b/graphdoc/graphdoc/eval/doc_generator_eval.py index 166daff..12a7f0e 100644 --- a/graphdoc/graphdoc/eval/doc_generator_eval.py +++ b/graphdoc/graphdoc/eval/doc_generator_eval.py @@ -59,13 +59,13 @@ def forward(self, database_schema: str) -> dict[str, Any]: """Takes a database schema, documents it, and then evaluates each component and the aggregate.""" # (we assume we are using DocGeneratorModule) - generator_result = self.generator.document_full_schema( + generator_result = self.generator.document_full_schema( # type: ignore database_schema=database_schema, trace=True, client=self.mlflow_helper.mlflow_client, expirement_name=self.mlflow_experiment_name, api_key="temp", - ) # type: ignore + ) # TODO: let's decide if this is how we want to handle this in the future. # Alternatively, we could return the documented schema from forward, # not as a prediction object. diff --git a/graphdoc/graphdoc/modules/doc_generator_module.py b/graphdoc/graphdoc/modules/doc_generator_module.py index f8df262..16ef84b 100644 --- a/graphdoc/graphdoc/modules/doc_generator_module.py +++ b/graphdoc/graphdoc/modules/doc_generator_module.py @@ -28,7 +28,7 @@ def __init__( retry_limit: int = 1, rating_threshold: int = 3, fill_empty_descriptions: bool = True, - token_tracker: TokenTracker = None, + token_tracker: Optional[TokenTracker] = None, ) -> None: """Initialize the DocGeneratorModule. 
A module for generating documentation for a given GraphQL schema. Schemas are decomposed and individually used to generate From 81ea6033507479752ecc4ddbf36c072935158ec8 Mon Sep 17 00:00:00 2001 From: denver Date: Fri, 14 Mar 2025 14:13:09 -0500 Subject: [PATCH 13/14] refactor: remove scratch_queue.py file --- graphdoc/scratch_queue.py | 140 -------------------------------------- 1 file changed, 140 deletions(-) delete mode 100644 graphdoc/scratch_queue.py diff --git a/graphdoc/scratch_queue.py b/graphdoc/scratch_queue.py deleted file mode 100644 index 0c4c271..0000000 --- a/graphdoc/scratch_queue.py +++ /dev/null @@ -1,140 +0,0 @@ -import concurrent.futures -import logging -import os -import queue -import random -import threading -import uuid -from datetime import datetime - -import dspy -import litellm -from dotenv import load_dotenv - -from graphdoc.config import dspy_lm_from_yaml - -logger = logging.getLogger(__name__) - -# global variables -api_call_count = 0 -model_name = "gpt-4o-2024-08-06" -completion_tokens = 0 -prompt_tokens = 0 -total_tokens = 0 -callback_lock = threading.Lock() - -callback_queue = queue.Queue() -all_tasks_done = threading.Event() - - -def global_token_callback(kwargs, response, start_time, end_time, **callback_kwargs): - data = { - "model": response.get("model", "unknown"), - "completion_tokens": response.get("usage", {}).get("completion_tokens", 0), - "prompt_tokens": response.get("usage", {}).get("prompt_tokens", 0), - "total_tokens": response.get("usage", {}).get("total_tokens", 0), - } - callback_queue.put(data) - logger.info( - f"Callback triggered, queued data, thread: {threading.current_thread().name}" - ) - - -def math_chain(task_id, active_tasks): - math = dspy.Predict("question -> answer: float") - num_requests = random.randint(1, 3) - logger.info(f"math_chain: num_requests: {num_requests}") - for i in range(num_requests): - value = random.randint(1, 10) - result = math(question=f"What is 2+{value}? 
ID: {uuid.uuid4()}") - logger.info(f"Task {task_id}, Request {i+1}: Result: {result.answer}") - with callback_lock: - active_tasks[0] -= 1 - if active_tasks[0] == 0: - all_tasks_done.set() - logger.info(f"Task {task_id} completed, remaining active tasks: {active_tasks[0]}") - - -def math_chain_multi(): - global api_call_count, model_name, completion_tokens, prompt_tokens, total_tokens - - with callback_lock: - start_count = api_call_count - logger.info(f"math_chain_multi started, initial api_call_count: {start_count}") - - num_tasks = random.randint(3, 7) - logger.info(f"math_chain_multi: num_tasks: {num_tasks}") - active_tasks = [num_tasks] - all_tasks_done.clear() - - with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: - futures = [ - executor.submit(math_chain, i, active_tasks) for i in range(num_tasks) - ] - logger.info(f"total futures: {len(futures)}") - for future in concurrent.futures.as_completed(futures): - future.result() - - all_tasks_done.wait() - logger.info("All tasks completed, now draining callback queue...") - - callbacks_during_run = 0 - while True: - try: - data = callback_queue.get(timeout=2) - with callback_lock: - api_call_count += 1 - model_name = data["model"] - completion_tokens += data["completion_tokens"] - prompt_tokens += data["prompt_tokens"] - total_tokens += data["total_tokens"] - callbacks_during_run += 1 - callback_queue.task_done() - except queue.Empty: - logger.info("Queue empty after timeout, assuming all callbacks processed") - break - - logger.info(f"math_chain_multi: model_name: {model_name}") - logger.info(f"math_chain_multi: completion_tokens: {completion_tokens}") - logger.info(f"math_chain_multi: prompt_tokens: {prompt_tokens}") - logger.info(f"math_chain_multi: total_tokens: {total_tokens}") - logger.info(f"math_chain_multi: total api_call_count: {api_call_count}") - logger.info(f"math_chain_multi: callbacks during this run: {callbacks_during_run}") - - -def main(): - print("hello, world!") - 
load_dotenv("../.env") - if global_token_callback not in litellm.callbacks: - litellm.callbacks.append(global_token_callback) - - log_dir = "logs" - os.makedirs(log_dir, exist_ok=True) - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - log_file = os.path.join(log_dir, f"run_{timestamp}.log") - - root_logger = logging.getLogger() - root_logger.setLevel(logging.DEBUG) - file_handler = logging.FileHandler(log_file) - file_handler.setLevel(logging.DEBUG) - file_handler.setFormatter( - logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - ) - root_logger.addHandler(file_handler) - console_handler = logging.StreamHandler() - console_handler.setLevel(logging.INFO) - console_handler.setFormatter(logging.Formatter("%(levelname)s - %(message)s")) - root_logger.addHandler(console_handler) - - print(f"logging to file: {log_file}") - logger.info("logging initialized") - os.environ["LITELLM_LOG"] = "DEBUG" - - config_path = "/Users/denver/Documents/code/graph/graphdoc-mono/graphdoc/graphdoc/assets/configs/single_prompt_doc_generator_module_eval.yaml" - dspy_lm_from_yaml(config_path) - - math_chain_multi() - - -if __name__ == "__main__": - main() From 733794dc8fcadc9c64118281334307f843b9213e Mon Sep 17 00:00:00 2001 From: denver Date: Fri, 14 Mar 2025 14:19:17 -0500 Subject: [PATCH 14/14] style: remove outdated comments in doc gen eval --- graphdoc/graphdoc/eval/doc_generator_eval.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/graphdoc/graphdoc/eval/doc_generator_eval.py b/graphdoc/graphdoc/eval/doc_generator_eval.py index 12a7f0e..76b3a3a 100644 --- a/graphdoc/graphdoc/eval/doc_generator_eval.py +++ b/graphdoc/graphdoc/eval/doc_generator_eval.py @@ -29,7 +29,6 @@ def __init__( DocQualityPrompt, SinglePrompt, Any ], # we have type hints, but accept any type for flexibility evalset: Union[List[dspy.Example], Any], - # mlflow_tracking_uri: Union[str, Path], mlflow_helper: MlflowDataHelper, mlflow_experiment_name: str = "doc_generator_eval", 
generator_prediction_field: str = "documented_schema", @@ -47,11 +46,9 @@ def __init__( self.generator = generator self.evaluator = evaluator self.evalset = evalset - # self.mlflow_tracking_uri = mlflow_tracking_uri self.mlflow_helper = mlflow_helper self.generator_prediction_field = generator_prediction_field self.evaluator_prediction_field = evaluator_prediction_field - # self.mlflow_helper = MlflowDataHelper(mlflow_tracking_uri) self.mlflow_experiment_name = mlflow_experiment_name self.readable_value = readable_value @@ -110,7 +107,6 @@ def evaluate(self): """Batches the evaluation set and logs the results to mlflow.""" mlflow.set_experiment(self.mlflow_experiment_name) with mlflow.start_run(): - # evalset = [x.database_schema for x in self.evalset] evaluation_results = self.batch(self.evalset, num_threads=32) avg_overall_rating = sum( [x["overall_rating"] for x in evaluation_results]