Skip to content
This repository was archived by the owner on Feb 6, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion graphdoc/graphdoc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
setup_logging,
)
from graphdoc.eval import DocGeneratorEvaluator
from graphdoc.modules import DocGeneratorModule
from graphdoc.modules import DocGeneratorModule, TokenTracker
from graphdoc.prompts import (
BadDocGeneratorSignature,
DocGeneratorHelperSignature,
Expand All @@ -60,6 +60,7 @@

__all__ = [
"DocGeneratorModule",
"TokenTracker",
"DocGeneratorEvaluator",
"DocGeneratorTrainer",
"DocQualityTrainer",
Expand Down
63 changes: 60 additions & 3 deletions graphdoc/graphdoc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,50 @@
#######################


def lm_from_dict(lm_config: dict):
    """Construct a ``dspy.LM`` from a configuration dictionary.

    Every entry of the dictionary is forwarded unchanged as a keyword
    argument to the ``dspy.LM`` constructor.

    :param lm_config: Keyword arguments for ``dspy.LM``.
    :type lm_config: dict
    :return: The instantiated language model.

    """
    model = dspy.LM(**lm_config)
    return model


def lm_from_yaml(yaml_path: Union[str, Path]):
    """Load a language model from a YAML file.

    The file must contain a ``language_model`` section whose entries are
    passed as keyword arguments to ``dspy.LM`` (see :func:`lm_from_dict`).

    :param yaml_path: Path to the YAML configuration file.
    :type yaml_path: Union[str, Path]
    :return: The instantiated language model.
    :raises KeyError: If the configuration has no ``language_model`` section.

    """
    config = load_yaml_config(yaml_path)
    return lm_from_dict(config["language_model"])


def dspy_lm_from_dict(lm_config: dict):
    """Load a language model from a dictionary of parameters and set it as
    the global dspy language model via ``dspy.configure``.

    :param lm_config: Dictionary containing language model parameters,
        passed as keyword arguments to ``dspy.LM``.
    :type lm_config: dict
    :return: The language model that was configured, so callers that also
        need a direct handle on it do not have to rebuild it.

    """
    lm = lm_from_dict(lm_config)
    dspy.configure(lm=lm)
    # Returning the lm is backward compatible: previous callers ignored the
    # implicit None return value.
    return lm


def dspy_lm_from_yaml(yaml_path: Union[str, Path]):
    """Load a language model from a YAML file and set it as the global dspy
    language model.

    :param yaml_path: Path to the YAML configuration file; must contain a
        ``language_model`` section whose entries are passed to ``dspy.LM``.
    :type yaml_path: Union[str, Path]
    :raises KeyError: If the configuration has no ``language_model`` section.

    """
    config = load_yaml_config(yaml_path)
    dspy_lm_from_dict(config["language_model"])


def mlflow_data_helper_from_dict(mlflow_config: dict) -> MlflowDataHelper:
"""Load a MLflow data helper from a dictionary of parameters.

Expand Down Expand Up @@ -354,6 +398,9 @@ def single_prompt_from_yaml(yaml_path: Union[str, Path]) -> SinglePrompt:
:rtype: SinglePrompt

"""
# set the dspy language model
dspy_lm_from_yaml(yaml_path)

config = load_yaml_config(yaml_path)
mlflow_config = config.get("mlflow", None)
if config["prompt"]["prompt_metric"]:
Expand Down Expand Up @@ -471,6 +518,9 @@ def single_trainer_from_yaml(yaml_path: Union[str, Path]) -> SinglePromptTrainer
:rtype: SinglePromptTrainer

"""
# set the dspy language model
dspy_lm_from_yaml(yaml_path)

try:
config = load_yaml_config(yaml_path)
prompt = single_prompt_from_yaml(yaml_path)
Expand Down Expand Up @@ -564,6 +614,9 @@ def doc_generator_module_from_yaml(yaml_path: Union[str, Path]) -> DocGeneratorM
:rtype: DocGeneratorModule

"""
# set the dspy language model
dspy_lm_from_yaml(yaml_path)

config = load_yaml_config(yaml_path)["module"]
prompt = single_prompt_from_yaml(yaml_path)
return doc_generator_module_from_dict(config, prompt)
Expand Down Expand Up @@ -622,6 +675,9 @@ def doc_generator_eval_from_yaml(yaml_path: Union[str, Path]) -> DocGeneratorEva
:rtype: DocGeneratorEvaluator

""" # noqa: B950
# set the dspy language model
dspy_lm_from_yaml(yaml_path)

# load the generator
generator = doc_generator_module_from_yaml(yaml_path)
config = load_yaml_config(yaml_path)
Expand All @@ -630,9 +686,10 @@ def doc_generator_eval_from_yaml(yaml_path: Union[str, Path]) -> DocGeneratorEva
metric_config = config["prompt_metric"]
evaluator = single_prompt_from_dict(metric_config, metric_config["metric"])

# load the eval config
# load the mlflow data helper
mdh = mlflow_data_helper_from_yaml(yaml_path) # noqa: F841
mlflow_tracking_uri = config["eval"]["mlflow_tracking_uri"]

# load the eval config
mlflow_experiment_name = config["eval"]["mlflow_experiment_name"]
generator_prediction_field = config["eval"]["generator_prediction_field"]
evaluator_prediction_field = config["eval"]["evaluator_prediction_field"]
Expand All @@ -646,7 +703,7 @@ def doc_generator_eval_from_yaml(yaml_path: Union[str, Path]) -> DocGeneratorEva
generator=generator,
evaluator=evaluator,
evalset=evalset,
mlflow_tracking_uri=mlflow_tracking_uri,
mlflow_helper=mdh,
mlflow_experiment_name=mlflow_experiment_name,
generator_prediction_field=generator_prediction_field,
evaluator_prediction_field=evaluator_prediction_field,
Expand Down
15 changes: 9 additions & 6 deletions graphdoc/graphdoc/eval/doc_generator_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

# system packages
import logging
from pathlib import Path
from typing import Any, List, Union

# external packages
Expand All @@ -30,7 +29,7 @@ def __init__(
DocQualityPrompt, SinglePrompt, Any
], # we have type hints, but accept any type for flexibility
evalset: Union[List[dspy.Example], Any],
mlflow_tracking_uri: Union[str, Path],
mlflow_helper: MlflowDataHelper,
mlflow_experiment_name: str = "doc_generator_eval",
generator_prediction_field: str = "documented_schema",
evaluator_prediction_field: str = "rating",
Expand All @@ -47,18 +46,23 @@ def __init__(
self.generator = generator
self.evaluator = evaluator
self.evalset = evalset
self.mlflow_tracking_uri = mlflow_tracking_uri
self.mlflow_helper = mlflow_helper
self.generator_prediction_field = generator_prediction_field
self.evaluator_prediction_field = evaluator_prediction_field
self.mlflow_helper = MlflowDataHelper(mlflow_tracking_uri)
self.mlflow_experiment_name = mlflow_experiment_name
self.readable_value = readable_value

def forward(self, database_schema: str) -> dict[str, Any]:
"""Takes a database schema, documents it, and then evaluates each component and
the aggregate."""
# (we assume we are using DocGeneratorModule)
generator_result = self.generator.forward(database_schema) # type: ignore
generator_result = self.generator.document_full_schema( # type: ignore
database_schema=database_schema,
trace=True,
client=self.mlflow_helper.mlflow_client,
expirement_name=self.mlflow_experiment_name,
api_key="temp",
)
# TODO: let's decide if this is how we want to handle this in the future.
# Alternatively, we could return the documented schema from forward,
# not as a prediction object.
Expand Down Expand Up @@ -103,7 +107,6 @@ def evaluate(self):
"""Batches the evaluation set and logs the results to mlflow."""
mlflow.set_experiment(self.mlflow_experiment_name)
with mlflow.start_run():
# evalset = [x.database_schema for x in self.evalset]
evaluation_results = self.batch(self.evalset, num_threads=32)
avg_overall_rating = sum(
[x["overall_rating"] for x in evaluation_results]
Expand Down
2 changes: 2 additions & 0 deletions graphdoc/graphdoc/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
# SPDX-License-Identifier: Apache-2.0

from graphdoc.modules.doc_generator_module import DocGeneratorModule
from graphdoc.modules.token_tracker import TokenTracker

__all__ = [
"DocGeneratorModule",
"TokenTracker",
]
49 changes: 47 additions & 2 deletions graphdoc/graphdoc/modules/doc_generator_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# system packages
import logging
import queue
from typing import Any, Literal, Optional, Union

# external packages
Expand All @@ -12,6 +13,7 @@

# internal packages
from graphdoc.data import Parser
from graphdoc.modules.token_tracker import TokenTracker
from graphdoc.prompts import DocGeneratorPrompt, SinglePrompt

# logging
Expand All @@ -26,6 +28,7 @@ def __init__(
retry_limit: int = 1,
rating_threshold: int = 3,
fill_empty_descriptions: bool = True,
token_tracker: Optional[TokenTracker] = None,
) -> None:
"""Initialize the DocGeneratorModule. A module for generating documentation for
a given GraphQL schema. Schemas are decomposed and individually used to generate
Expand Down Expand Up @@ -59,6 +62,7 @@ def __init__(
# we should move to a dict like structure for passing in those parameters
self.fill_empty_descriptions = fill_empty_descriptions
self.par = Parser()
self.token_tracker = TokenTracker() if token_tracker is None else token_tracker

# ensure that the doc generator prompt metric is set to rating
if self.prompt.prompt_metric.prompt_metric != "rating":
Expand Down Expand Up @@ -250,11 +254,21 @@ def forward(self, database_schema: str) -> dspy.Prediction:
:rtype: dspy.Prediction

"""

def _update_active_tasks():
with self.token_tracker.callback_lock:
self.token_tracker.active_tasks -= 1
if self.token_tracker.active_tasks == 0:
self.token_tracker.all_tasks_done.set()

if self.retry:
database_schema = self._retry_by_rating(database_schema=database_schema)
_update_active_tasks()
return dspy.Prediction(documented_schema=database_schema)
else:
return self._predict(database_schema=database_schema)
prediction = self._predict(database_schema=database_schema)
_update_active_tasks()
return prediction

def document_full_schema(
self,
Expand Down Expand Up @@ -310,6 +324,10 @@ def document_full_schema(
)
log.info("created trace: " + str(root_trace))

# token tracker details
self.token_tracker.active_tasks = len(examples)
self.token_tracker.all_tasks_done.clear()

# batch generate the documentation
documented_examples = self.batch(examples, num_threads=32)
document_ast.definitions = tuple(
Expand All @@ -318,6 +336,26 @@ def document_full_schema(
# TODO: we should have better type handling, but we know this works
)

# token tracker details
self.token_tracker.all_tasks_done.wait()
callbacks_during_run = 0
while True:
try:
data = self.token_tracker.callback_queue.get(timeout=2)
with self.token_tracker.callback_lock:
self.token_tracker.api_call_count += 1
self.token_tracker.model_name = data.get("model", "unknown")
self.token_tracker.completion_tokens += data.get(
"completion_tokens", 0
)
self.token_tracker.prompt_tokens += data.get("prompt_tokens", 0)
self.token_tracker.total_tokens += data.get("total_tokens", 0)
callbacks_during_run += 1
self.token_tracker.callback_queue.task_done()
except queue.Empty:
log.info("Queue empty after timeout, assuming all callbacks processed")
break

# check that the generated schema matches the original schema
if self.par.schema_equality_check(parse(database_schema), document_ast):
log.info("Schema equality check passed, returning documented schema")
Expand All @@ -340,9 +378,16 @@ def document_full_schema(
trace=root_trace, # type: ignore
# TODO: we should have better type handling, but i believe we will get an
# error if root_trace has an issue during the start_trace call
outputs={"documented_schema": return_schema},
outputs={
"documented_schema": return_schema,
"token_tracker": self.token_tracker.stats(),
},
status=status,
)
log.info("ended trace: " + str(root_trace)) # type: ignore
# TODO: we should have better type handling, but we check at the top

# clear the token tracker
self.token_tracker.clear()

return dspy.Prediction(documented_schema=return_schema)
72 changes: 72 additions & 0 deletions graphdoc/graphdoc/modules/token_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright 2025-, Semiotic AI, Inc.
# SPDX-License-Identifier: Apache-2.0

import logging

# system packages
import queue
import threading

# external packages
import litellm

# internal packages

# logging
log = logging.getLogger(__name__)


class TokenTracker:
    """Aggregate LLM token usage reported through a litellm success callback.

    Usage data is not applied to the counters directly inside the callback
    (which may fire on arbitrary worker threads); instead each callback
    enqueues a small dict onto ``callback_queue``, and a consumer is expected
    to drain the queue and fold the numbers into the counters while holding
    ``callback_lock``.

    NOTE(review): ``__init__`` registers a *bound method* on the global
    ``litellm.callbacks`` list. The membership check only de-duplicates
    registrations from the same instance, so creating multiple TokenTracker
    instances appends multiple callbacks — confirm this is intended.
    """

    def __init__(self) -> None:
        # Name of the last model seen in a callback ("" until first update).
        self.model_name = ""
        # Number of callback payloads folded into the counters.
        self.api_call_count = 0
        # Running token totals, accumulated by the queue consumer.
        self.completion_tokens = 0
        self.prompt_tokens = 0
        self.total_tokens = 0
        # Number of in-flight tasks; decremented externally by the caller.
        self.active_tasks = 0
        # Guards counter updates performed by the queue consumer.
        self.callback_lock = threading.Lock()
        # Thread-safe handoff from the litellm callback to the consumer.
        self.callback_queue = queue.Queue()
        # Set when active_tasks reaches zero; consumers wait on this.
        self.all_tasks_done = threading.Event()

        # Register once per instance on litellm's global callback list.
        if self.global_token_callback not in litellm.callbacks:
            litellm.callbacks.append(self.global_token_callback)

    def clear(self) -> None:
        """Clear the token tracker.

        Resets all counters to their initial values.

        NOTE(review): does not drain ``callback_queue`` or reset
        ``all_tasks_done``, and does not unregister the litellm callback —
        confirm stale queued payloads cannot leak into the next run.
        """
        self.api_call_count = 0
        self.model_name = ""
        self.completion_tokens = 0
        self.prompt_tokens = 0
        self.total_tokens = 0
        self.active_tasks = 0

    def stats(self) -> dict:
        """Get the stats of the token tracker.

        Returns a snapshot dict of the model name, call count, and token
        totals. The read is not performed under ``callback_lock``, so call
        this only after all consumers have finished updating.
        """
        return {
            "model_name": self.model_name,
            "api_call_count": self.api_call_count,
            "completion_tokens": self.completion_tokens,
            "prompt_tokens": self.prompt_tokens,
            "total_tokens": self.total_tokens,
        }

    def global_token_callback(
        self, kwargs, response, start_time, end_time, **callback_kwargs
    ):
        """A global callback to track the number of tokens used.

        Intended to be used with the litellm ModelResponse object.

        Extracts model name and token usage from *response* (missing fields
        default to "unknown"/0) and enqueues them; counters are updated
        later by the queue consumer, so this stays cheap and lock-free on
        the callback thread.
        """
        data = {
            "model": response.get("model", "unknown"),
            "completion_tokens": response.get("usage", {}).get("completion_tokens", 0),
            "prompt_tokens": response.get("usage", {}).get("prompt_tokens", 0),
            "total_tokens": response.get("usage", {}).get("total_tokens", 0),
        }
        self.callback_queue.put(data)
        log.info(
            f"Callback triggered, queued data, thread: {threading.current_thread().name}"
        )
Loading