Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions ragulate/pipelines/query_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@
import signal
import time
from typing import Any, Dict, List, Optional
import pandas as pd

from tqdm import tqdm
from trulens_eval import Tru, TruChain
from trulens_eval.feedback.provider import (
AzureOpenAI,
Bedrock,
Huggingface,
Langchain,
LiteLLM,
OpenAI,
)
from trulens_eval.feedback.provider import OpenAI
from trulens_eval.feedback.provider.base import LLMProvider
from trulens_eval.schema.feedback import FeedbackMode, FeedbackResultStatus

Expand Down Expand Up @@ -116,13 +115,24 @@ def start_evaluation(self):
self._tru.start_evaluator(disable_tqdm=True)
self._evaluation_running = True

def export_results(self, file_path: str = "results.json") -> None:
    """Export all recorded queries and their feedback results to a JSON file.

    Args:
        file_path: Destination file for the JSON export. Defaults to
            ``results.json`` to preserve the original behavior.
    """
    # NOTE(review): trulens_eval exposes `get_records_and_feedback()`
    # (singular), which already returns a pandas DataFrame together with
    # the list of feedback column names -- the previous code called a
    # non-existent plural method and then rebuilt a DataFrame from
    # per-record `__dict__` snapshots, which is unnecessary.
    records_df, _feedback_cols = self._tru.get_records_and_feedback()

    # One JSON object per record row.
    records_df.to_json(file_path, orient="records")

def stop_evaluation(self, loc: str):
if self._evaluation_running:
try:
logger.debug(f"Stopping evaluation from: {loc}")
self._tru.stop_evaluator()
self._evaluation_running = False
self._tru.delete_singleton()
self.export_results()
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably only want to export results when the run completes normally — not when someone presses Ctrl-C in the middle of a run.

except Exception as e:
logger.error(f"issue stopping evaluator: {e}")
finally:
Expand Down Expand Up @@ -158,12 +168,6 @@ def get_provider(self) -> LLMProvider:
return OpenAI(model_engine=model_name)
elif llm_provider == "azureopenai":
return AzureOpenAI(deployment_name=model_name)
elif llm_provider == "bedrock":
return Bedrock(model_id=model_name)
elif llm_provider == "litellm":
return LiteLLM(model_engine=model_name)
elif llm_provider == "Langchain":
return Langchain(model_engine=model_name)
elif llm_provider == "huggingface":
return Huggingface(name=model_name)
else:
Expand Down