-
Notifications
You must be signed in to change notification settings - Fork 8
Support Unique Conversation ID #104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
97af907
6fb999e
63a861e
4c87bba
cb4fef7
eabfb6d
733aa28
42320f6
14c8ed0
d4debf2
feedd85
80fb5c2
2ee3051
5c5b6f4
7c03595
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -49,20 +49,30 @@ def __init__( | |
| self.max_total_words = max_total_words | ||
| self.max_personas = max_personas | ||
|
|
||
| self.AGENT_SYSTEM_PROMPT = self.agent_model_config.get( | ||
| "system_prompt", "You are a helpful AI assistant." | ||
| ) | ||
|
|
||
| async def run_single_conversation( | ||
| self, | ||
| persona_config: dict, | ||
| agent, | ||
| max_turns: int, | ||
| conversation_id: int, | ||
| conversation_index: int, | ||
jgieringer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| run_number: int, | ||
| **kargs: dict, | ||
| **kwargs: dict, | ||
| ) -> Dict[str, Any]: | ||
| """Run a single conversation asynchronously.""" | ||
| """Run a single simulated conversation (persona vs provider LLM). | ||
|
|
||
| Uses fresh LLM instances per call; safe for concurrent use. Logs turns, | ||
| writes transcript to self.folder_name, then cleans up logger and LLMs. | ||
|
|
||
| Args: | ||
| persona_config (dict): Must have "model", "prompt", "name". | ||
| max_turns (int): Max conversation turns for a conversation. | ||
| conversation_index (int): Index in the batch of conversations. | ||
| run_number (int): Run index for this prompt (e.g. 1 of runs_per_prompt). | ||
| **kwargs: Unused; reserved for future use. | ||
|
|
||
| Returns: | ||
| Dict[str, Any]: index, llm1_model, llm1_prompt, run_number, turns, | ||
| filename, log_file, duration, early_termination, conversation. | ||
| """ | ||
| model_name = persona_config["model"] | ||
| system_prompt = persona_config["prompt"] # This is now the full persona prompt | ||
| persona_name = persona_config["name"] | ||
|
|
@@ -83,7 +93,7 @@ async def run_single_conversation( | |
| logger = setup_conversation_logger(filename_base, run_id=self.run_id) | ||
| start_time = time.time() | ||
|
|
||
| # Create LLM1 instance with the persona prompt and configuration | ||
| # Create persona instance | ||
| persona = LLMFactory.create_llm( | ||
| model_name=model_name, | ||
| name=f"{model_short} {persona_name}", | ||
|
|
@@ -92,6 +102,23 @@ async def run_single_conversation( | |
| **self.persona_model_config, | ||
| ) | ||
|
|
||
| # Create new agent instance to reset conversation_id and metadata. | ||
| # Exclude selected kwargs to avoid duplicate args expected in create_llm. | ||
| agent_kwargs = { | ||
jgieringer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| k: v | ||
| for k, v in self.agent_model_config.items() | ||
| if k not in ("model", "name", "system_prompt") | ||
| } | ||
| agent = LLMFactory.create_llm( | ||
| model_name=self.agent_model_config["model"], | ||
| name=self.agent_model_config.get("name", "Provider"), | ||
| system_prompt=self.agent_model_config.get( | ||
| "system_prompt", "You are a helpful AI assistant." | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I would maybe have this as an optional arg with the default, since it seems a potentially consequential decision and it's buried here.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| ), | ||
| role=Role.PROVIDER, | ||
| **agent_kwargs, | ||
| ) | ||
|
|
||
| # Log conversation start | ||
| log_conversation_start( | ||
| logger=logger, | ||
|
|
@@ -148,7 +175,7 @@ async def run_single_conversation( | |
| simulator.save_conversation(f"{filename_base}.txt", self.folder_name) | ||
|
|
||
| result = { | ||
| "id": conversation_id, | ||
| "index": conversation_index, | ||
| "llm1_model": model_name, | ||
| "llm1_prompt": persona_name, | ||
| "run_number": run_number, | ||
|
|
@@ -164,11 +191,12 @@ async def run_single_conversation( | |
|
|
||
| # Cleanup LLM resources (e.g., close HTTP sessions for Azure) | ||
| # Always cleanup, even if conversation failed | ||
| try: | ||
| await persona.cleanup() | ||
| except Exception as e: | ||
| # Log but don't fail if cleanup fails | ||
| print(f"Warning: Failed to cleanup persona LLM: {e}") | ||
| for llm in (persona, agent): | ||
| try: | ||
| await llm.cleanup() | ||
| except Exception as e: | ||
| # Log but don't fail if cleanup fails | ||
| print(f"Warning: Failed to cleanup LLM: {e}") | ||
jgieringer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| return result | ||
|
|
||
|
|
@@ -179,37 +207,26 @@ async def run_conversations( | |
| # Load prompts from CSV based on persona names | ||
| personas = load_prompts_from_csv(persona_names, max_personas=self.max_personas) | ||
|
|
||
| # Load agent configuration (fixed, shared across all conversations) | ||
| agent = LLMFactory.create_llm( | ||
| model_name=self.agent_model_config["model"], | ||
| name=self.agent_model_config.pop("name"), | ||
| system_prompt=self.AGENT_SYSTEM_PROMPT, | ||
| role=Role.PROVIDER, | ||
| **self.agent_model_config, | ||
| ) | ||
|
|
||
| # Create tasks for all conversations (each prompt run multiple times) | ||
| tasks = [] | ||
| conversation_id = 1 | ||
| conversation_index = 1 | ||
|
|
||
| for persona in personas: | ||
| for run in range(1, self.runs_per_prompt + 1): | ||
| tasks.append( | ||
| # TODO: should we pass the persona object here? | ||
| self.run_single_conversation( | ||
| { | ||
| "model": self.persona_model_config["model"], | ||
| "prompt": persona["prompt"], | ||
| "name": persona["Name"], | ||
| "run": run, | ||
| }, | ||
| agent, | ||
| self.max_turns, | ||
| conversation_id, | ||
| conversation_index, | ||
| run, | ||
| ) | ||
| ) | ||
| conversation_id += 1 | ||
| conversation_index += 1 | ||
|
|
||
| # Run all conversations with concurrency limit | ||
| start_time = datetime.now() | ||
|
|
@@ -237,11 +254,4 @@ async def run_with_limit(task): | |
|
|
||
| print(f"\nCompleted {len(results)} conversations in {total_time:.2f} seconds") | ||
|
|
||
| # Cleanup agent LLM resources (e.g., close HTTP sessions for Azure) | ||
| try: | ||
| await agent.cleanup() | ||
| except Exception as e: | ||
| # Log but don't fail if cleanup fails | ||
| print(f"Warning: Failed to cleanup agent LLM: {e}") | ||
|
|
||
| return results | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -187,19 +187,19 @@ async def generate_response( | |
| # Extract token usage | ||
| if "token_usage" in metadata: | ||
| usage = metadata["token_usage"] | ||
| self.last_response_metadata["usage"] = { | ||
| self._last_response_metadata["usage"] = { | ||
| "input_tokens": usage.get("input_tokens", 0), | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is off-topic, but... do we save this metadata anywhere? It doesn't seem to be in the logging output for chat generation, or in the log output for judging... but if we have total token usage for each conversation and judging evaluation somewhere that we could write out, it would help everyone understand costs. (This is probably a separate ticket... but only if we really aren't storing it anywhere.)
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| "output_tokens": usage.get("output_tokens", 0), | ||
| "total_tokens": usage.get("total_tokens", 0), | ||
| } | ||
|
|
||
| # Extract finish reason | ||
| self.last_response_metadata["finish_reason"] = metadata.get( | ||
| self._last_response_metadata["finish_reason"] = metadata.get( | ||
| "finish_reason" | ||
| ) | ||
|
|
||
| # Store raw metadata | ||
| self.last_response_metadata["raw_metadata"] = dict(metadata) | ||
| self._last_response_metadata["raw_metadata"] = dict(metadata) | ||
|
|
||
| return response.text | ||
jgieringer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| except Exception as e: | ||
|
|
@@ -307,10 +307,6 @@ async def generate_structured_response( | |
| } | ||
| raise RuntimeError(f"Error generating structured response: {str(e)}") from e | ||
|
|
||
| def get_last_response_metadata(self) -> Dict[str, Any]: | ||
| """Get metadata from the last response.""" | ||
| return self.last_response_metadata.copy() | ||
|
|
||
| def set_system_prompt(self, system_prompt: str) -> None: | ||
| """Set or update the system prompt.""" | ||
| self.system_prompt = system_prompt | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.