diff --git a/docs/colab_notebooks/5-multistep-toolcalling/multistep-toolcalling.ipynb b/docs/colab_notebooks/5-multistep-toolcalling/multistep-toolcalling.ipynb new file mode 100644 index 000000000..cc64ed659 --- /dev/null +++ b/docs/colab_notebooks/5-multistep-toolcalling/multistep-toolcalling.ipynb @@ -0,0 +1,1794 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Building Multi-Step Tool-Calling Datasets with Data Designer\n", + "\n", + "Generate synthetic training data for agentic reinforcement learning using NVIDIA **Data Designer** to enhance multi-step tool-calling ability.\n", + "\n", + "## Prerequisites\n", + "\n", + "- **NVIDIA API Key** from [build.nvidia.com](https://build.nvidia.com) to access a remote LLM for generation. Alternatively, you may choose to use your own endpoint or deployment.\n", + "- **Python 3.11+**\n", + "- **Tool definition files** in the `tools/` directory (included in this repo)\n", + "- Packages: `data-designer`, `pydantic`, `pandas`\n", + "\n", + "## Objectives\n", + "\n", + "By the end of this notebook, you will:\n", + "\n", + "- Load known **tool schemas** as the seed for generating agent queries and simulated trajectories\n", + "- Use **Data Designer** to generate realistic multi-step user queries\n", + "- Simulate **agent trajectories** (step-by-step tool-call solutions)\n", + "- Apply **dual-level LLM judge filtering** to ensure data quality\n", + "- Export training data in **NeMo Gym format** for rollout collection and RLVR training\n", + "\n", + "> **Context Note:** The primary goal of this notebook is user query generation. The trajectory generation step in this notebook serves as a sanity check to ensure the generated query leads to a feasible path. In production RL training, rollout (oracle trajectory) traces are generated from the environment itself. 
You may find more information in [NeMo Gym Rollout Collection](https://docs.nvidia.com/nemo/gym/latest/get-started/rollout-collection.html) documentation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Architecture Overview\n", + "\n", + "```\n", + " ┌───────────────────────────────────────────────────────────────┐\n", + " │ DATA GENERATION PIPELINE │\n", + " ├───────────────────────────────────────────────────────────────┤\n", + " │ │\n", + " │ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │\n", + " │ │ Tool Schemas │───▶│ User Query │───▶│ Trajectory │ │\n", + " │ │ (Seed) │ │ Generation │ │ Simulation │ │\n", + " │ └──────────────┘ └──────────────┘ └──────────────┘ │\n", + " │ │ │\n", + " │ ▼ │\n", + " │ ┌──────────────┐ │\n", + " │ │ LLM Judge │ │\n", + " │ │ (Quality) │ │\n", + " │ └──────────────┘ │\n", + " │ │ │\n", + " │ ▼ │\n", + " │ ┌──────────────┐ │\n", + " │ │ NeMo Gym │ │\n", + " │ │ Format │ │\n", + " │ └──────────────┘ │\n", + " │ │\n", + " └───────────────────────────────────────────────────────────────┘\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 1: Install and Import Dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install -q data-designer pydantic pandas" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import random\n", + "from typing import List, Optional\n", + "from pydantic import BaseModel, Field\n", + "import pandas as pd\n", + "\n", + "# Data Designer imports\n", + "from data_designer.config import (\n", + " ChatCompletionInferenceParams,\n", + " DataDesignerConfigBuilder,\n", + " LLMStructuredColumnConfig,\n", + " LLMTextColumnConfig,\n", + " 
LocalFileSeedSource,\n", + " ModelConfig,\n", + " SamplingStrategy,\n", + " ModelProvider,\n", + ")\n", + "from data_designer.interface import DataDesigner" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Context: What is the Workplace Assistant Environment?\n", + "\n", + "[**Workplace Assistant**](https://docs.nvidia.com/nemo/gym/latest/tutorials/nemo-rl-grpo/about-workplace-assistant.html#) is a multi-step tool-using benchmark environment used in **NeMo Gym** for RL training. A model gets a natural language business request and must call tools in the right order with valid arguments (up to 6 steps).\n", + "\n", + "At a high level:\n", + "- The model reads a user request (for example, scheduling meetings or updating CRM records)\n", + "- The model decides which tools to call and with what parameters\n", + "- The environment verifies correctness using **state matching** (final database state), not exact step matching\n", + "\n", + "In this notebook, we focus on **data generation**: starting from known tool schemas, generating realistic user requests, and simulating feasible trajectories to produce NeMo Gym-compatible training data.\n", + "\n", + "> **Note:** The official NeMo Gym [Workplace Assistant environment](https://github.com/NVIDIA-NeMo/Gym/tree/main/resources_servers/workplace_assistant) is the training target. This notebook is an example synthetic data preparation stage that feeds that workflow." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2: Load Tool Definitions\n", + "\n", + "This notebook begins with **established tool schemas** and leverages them as foundational context for data generation. 
These schemas represent the (possibly domain-specific) tools on which you want to improve the model's performance.\n", + "\n", + "We use 27 tools across 6 tool groups:\n", + "- **Company Directory**: Look up employee email addresses\n", + "- **Email**: Send, search, reply, forward, delete emails\n", + "- **Calendar**: Create, search, update, delete events\n", + "- **Analytics**: Query website visitor data and create plots\n", + "- **Project Management**: Manage tasks across Kanban boards\n", + "- **CRM**: Manage customer records and sales pipeline\n", + "\n", + "These tools are designed to require **multi-step reasoning**. For example:\n", + "- \"Email John about the meeting\" requires first looking up John's email, then sending the email\n", + "- \"Reassign all of Sarah's leads to Mike\" requires looking up emails, searching customers, then updating each one\n", + "\n", + "> **Why this matters:** The schemas define valid arguments and values (for example, allowed board/list/status values). We use these constraints to generate realistic user queries and schema-compliant simulated trajectories."
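As a rough illustration of how such a constraint can be read out of a schema, here is a minimal sketch. The tool entry below is hypothetical (the real parameter layout lives in the `tools/` JSON files and may differ), and the `extract_allowed_values` helper is ours, not part of the repo:

```python
import re

# Hypothetical tool entry, modeled loosely on the tools/ JSON files.
# The "One of:" phrasing in the description encodes the allowed enum values.
example_tool = {
    "type": "function",
    "name": "project_management_update_task",
    "description": "Updates a task by ID.",
    "parameters": {
        "type": "object",
        "properties": {
            "task_id": {"type": "string", "description": "ID of the task to update"},
            "list_name": {
                "type": "string",
                "description": "One of: 'Backlog', 'In Progress', 'In Review', 'Completed'",
            },
        },
        "required": ["task_id"],
    },
}

def extract_allowed_values(param: dict) -> list:
    """Pull quoted options out of a 'One of:' style parameter description."""
    desc = param.get("description", "")
    if "One of:" not in desc:
        return []  # free-form parameter, no enum restriction
    return re.findall(r"'([^']+)'", desc.split("One of:", 1)[1])

allowed = extract_allowed_values(example_tool["parameters"]["properties"]["list_name"])
print(allowed)  # ['Backlog', 'In Progress', 'In Review', 'Completed']
```

Constraints extracted this way are exactly what the generation and judge prompts later in this notebook ask the LLM to respect.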
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded 27 tools across 6 databases\n", + "\n", + "Databases:\n", + " - company_directory: 1 tools\n", + " Employee directory for looking up email addresses by name.\n", + " - email: 6 tools\n", + " Email inbox and outbox for sending, receiving, and managing emails.\n", + " - calendar: 5 tools\n", + " Calendar for managing meetings and events.\n", + " - analytics: 6 tools\n", + " Website analytics data for tracking visitor behavior and engagement.\n", + " - project_management: 5 tools\n", + " Project management board for tracking tasks across teams.\n", + " - customer_relationship_manager: 4 tools\n", + " CRM for managing customer records and sales pipeline.\n" + ] + } + ], + "source": [ + "# Load tool definitions from separate JSON files (one per database)\n", + "import os\n", + "\n", + "TOOLS_DIR = 'tools'\n", + "\n", + "# Load environment config\n", + "with open(os.path.join(TOOLS_DIR, 'environment.json'), 'r') as f:\n", + " env_config = json.load(f)\n", + "\n", + "SYSTEM_PROMPT = env_config['system_prompt']\n", + "MULTI_STEP_PATTERNS = env_config['common_multi_step_patterns']\n", + "\n", + "# Load tools from each database file\n", + "DATABASE_FILES = [\n", + " 'company_directory.json',\n", + " 'email.json', \n", + " 'calendar.json',\n", + " 'analytics.json',\n", + " 'project_management.json',\n", + " 'customer_relationship_manager.json'\n", + "]\n", + "\n", + "TOOLS = []\n", + "DATABASES = {}\n", + "TOOL_CATEGORIES = {}\n", + "\n", + "for db_file in DATABASE_FILES:\n", + " with open(os.path.join(TOOLS_DIR, db_file), 'r') as f:\n", + " db_config = json.load(f)\n", + " \n", + " db_name = db_config['database']\n", + " DATABASES[db_name] = {\n", + " 'description': db_config['description'],\n", + " 'data_schema': db_config['data_schema']\n", + " }\n", + " \n", + " # Add tools and track category\n", + " db_tools 
= db_config['tools']\n", + " TOOLS.extend(db_tools)\n", + " TOOL_CATEGORIES[db_name] = [t['name'] for t in db_tools]\n", + "\n", + "print(f\"Loaded {len(TOOLS)} tools across {len(DATABASES)} databases\")\n", + "print(f\"\\nDatabases:\")\n", + "for db_name, db_info in DATABASES.items():\n", + " tool_count = len(TOOL_CATEGORIES[db_name])\n", + " print(f\" - {db_name}: {tool_count} tools\")\n", + " print(f\" {db_info['description']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display all loaded tools grouped by database. This summary shows each tool's name and description." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "### COMPANY_DIRECTORY (1 tools)\n", + "- **company_directory_find_email_address**: Finds all email addresses containing the given name (case-insensitive search).\n", + "\n", + "### EMAIL (6 tools)\n", + "- **email_get_email_information_by_id**: Retrieves specific details of an email by its ID.\n", + "- **email_search_emails**: Searches for emails matching the given query across subject, body, or sender fields. 
The function matches an email if all words in the query appear in any of these fields.\n", + "- **email_send_email**: Sends an email to the specified recipient.\n", + "- **email_delete_email**: Deletes an email by its ID.\n", + "- **email_forward_email**: Forwards an email to the specified recipient.\n", + "- **email_reply_email**: Replies to an email by its ID.\n", + "\n", + "### CALENDAR (5 tools)\n", + "- **calendar_get_event_information_by_id**: Returns the event for a given ID.\n", + "- **calendar_search_events**: Returns the events for a given query with pagination support.\n", + "- **calendar_create_event**: Creates a new event.\n", + "- **calendar_delete_event**: Deletes an event.\n", + "- **calendar_update_event**: Updates an event.\n", + "\n", + "### ANALYTICS (6 tools)\n", + "- **analytics_get_visitor_information_by_id**: Returns the analytics data for a given visitor ID.\n", + "- **analytics_create_plot**: Plots the analytics data for a given time range and value.\n", + "- **analytics_total_visits_count**: Returns the total number of visits within a specified time range.\n", + "- **analytics_engaged_users_count**: Returns the number of engaged users within a specified time range.\n", + "- **analytics_traffic_source_count**: Returns the number of visits from a specific traffic source within a specified time range.\n", + "- **analytics_get_average_session_duration**: Returns the average session duration within a specified time range.\n", + "\n", + "### PROJECT_MANAGEMENT (5 tools)\n", + "- **project_management_get_task_information_by_id**: Returns the task information for a given ID.\n", + "- **project_management_search_tasks**: Searches for tasks based on the given parameters.\n", + "- **project_management_create_task**: Creates a new task.\n", + "- **project_management_delete_task**: Deletes a task by ID.\n", + "- **project_management_update_task**: Updates a task by ID.\n", + "\n", + "### CUSTOMER_RELATIONSHIP_MANAGER (4 tools)\n", + "- 
**customer_relationship_manager_search_customers**: Searches for customers based on the given parameters with pagination support.\n", + "- **customer_relationship_manager_update_customer**: Updates a customer record by ID.\n", + "- **customer_relationship_manager_add_customer**: Adds a new customer record.\n", + "- **customer_relationship_manager_delete_customer**: Deletes a customer record by ID.\n" + ] + } + ], + "source": [ + "# Helper function to format tools for prompts\n", + "def format_tools_for_prompt(tools: List[dict], include_schemas: bool = False) -> str:\n", + " \"\"\"Format tool definitions into a readable string for LLM prompts.\"\"\"\n", + " lines = []\n", + " for tool in tools:\n", + " lines.append(f\"- **{tool['name']}**: {tool['description']}\")\n", + " if include_schemas:\n", + " params = tool['parameters']['properties']\n", + " if params:\n", + " lines.append(f\" Parameters: {list(params.keys())}\")\n", + " return \"\\n\".join(lines)\n", + "\n", + "# Display tool summary by category\n", + "for category, tool_names in TOOL_CATEGORIES.items():\n", + " print(f\"\\n### {category.upper()} ({len(tool_names)} tools)\")\n", + " category_tools = [t for t in TOOLS if t['name'] in tool_names]\n", + " print(format_tools_for_prompt(category_tools))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3: Define Output Schemas\n", + "\n", + "**Data Designer** uses **Pydantic** models to define structured output formats, ensuring the LLM generates data in a consistent, parseable format.\n", + "\n", + "We define five schemas:\n", + "1. **ToolCall** / **AgentStep** / **AgentTrajectory**: Represent a multi-step tool-calling solution\n", + "2. **UserQueryJudgeScores**: Quality scores for generated user queries\n", + "3. 
**TrajectoryJudgeScores**: Quality scores for generated trajectories" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "class ToolCall(BaseModel):\n", + " \"\"\"A single tool invocation.\"\"\"\n", + " name: str = Field(..., description=\"The name of the tool to call (e.g., 'email_send_email')\")\n", + " arguments: str = Field(..., description=\"JSON string of the tool arguments\")\n", + "\n", + "\n", + "class AgentStep(BaseModel):\n", + " \"\"\"A single step in the agent's reasoning trajectory.\"\"\"\n", + " step_number: int = Field(..., description=\"The step number (1-indexed)\")\n", + " thought: str = Field(\n", + " ..., \n", + " description=\"The agent's reasoning about what to do next and why. Should explain the purpose of the tool call.\"\n", + " )\n", + " tool_call: ToolCall = Field(..., description=\"The tool to call in this step\")\n", + " expected_result: str = Field(\n", + " ..., \n", + " description=\"What information or state change we expect from this tool call\"\n", + " )\n", + "\n", + "\n", + "class AgentTrajectory(BaseModel):\n", + " \"\"\"Complete trajectory for solving a multi-step task.\"\"\"\n", + " reasoning_trace: List[AgentStep] = Field(\n", + " ..., \n", + " description=\"The sequence of steps to solve the task. Should be 1-6 steps.\"\n", + " )\n", + " final_answer: str = Field(\n", + " ..., \n", + " description=\"A brief confirmation of what was accomplished\"\n", + " )\n", + "\n", + "\n", + "class UserQueryJudgeScores(BaseModel):\n", + " \"\"\"Quality scores for the generated user query.\"\"\"\n", + " feasibility: int = Field(\n", + " ..., ge=1, le=5, \n", + " description=\"Is the request achievable with the available tools? (1=impossible, 5=fully achievable)\"\n", + " )\n", + " schema_compliance: int = Field(\n", + " ..., ge=1, le=5, \n", + " description=\"Does the request use valid values as defined in tool schemas (e.g., valid board names, list names, statuses)? 
(1=uses invalid values, 5=all values valid)\"\n", + " )\n", + " naturalness: int = Field(\n", + " ..., ge=1, le=5, \n", + " description=\"Does the request sound like a natural user query? (1=robotic/artificial, 5=very natural)\"\n", + " )\n", + " is_valid: bool = Field(\n", + " ..., \n", + " description=\"True if the query is valid and should be kept, False if it should be discarded\"\n", + " )\n", + " issues: str = Field(\n", + " ..., \n", + " description=\"List any issues found (invalid enum values, impossible requests, etc.) or 'None' if valid\"\n", + " )\n", + "\n", + "\n", + "class TrajectoryJudgeScores(BaseModel):\n", + " \"\"\"Quality scores for the generated trajectory.\"\"\"\n", + " tool_validity: int = Field(\n", + " ..., ge=1, le=5, \n", + " description=\"Are all tool names valid and arguments schema-compliant? (1=invalid tools/args, 5=all valid)\"\n", + " )\n", + " argument_validity: int = Field(\n", + " ..., ge=1, le=5, \n", + " description=\"Do all arguments use valid values as specified in tool descriptions? (1=invalid values, 5=all valid)\"\n", + " )\n", + " completeness: int = Field(\n", + " ..., ge=1, le=5, \n", + " description=\"Does the trajectory fully solve the user request? (1=incomplete, 5=fully complete)\"\n", + " )\n", + " efficiency: int = Field(\n", + " ..., ge=1, le=5, \n", + " description=\"Is the trajectory optimal without unnecessary steps? (1=very inefficient, 5=optimal)\"\n", + " )\n", + " is_valid: bool = Field(\n", + " ..., \n", + " description=\"True if the trajectory is valid and executable, False if it has errors\"\n", + " )\n", + " issues: str = Field(\n", + " ..., \n", + " description=\"List any issues found (invalid enum values, wrong tool names, missing steps, etc.) or 'None' if valid\"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4: Define Generation Prompts\n", + "\n", + "The heart of synthetic data generation is the prompts. 
We define four prompts using **Jinja2 templates** (with `{{ variable }}` placeholders that Data Designer fills from seed columns):\n", + "\n", + "1. **User Query Generation**: Create realistic workplace requests\n", + "2. **Trajectory Simulation**: Generate the step-by-step tool-call solution\n", + "3. **User Query Judge**: Evaluate query feasibility and schema compliance\n", + "4. **Trajectory Judge**: Evaluate tool-call correctness and completeness\n", + "\n", + "### Key Principles\n", + "- **Specificity**: Tell the LLM exactly what format you want\n", + "- **Examples**: Show don't tell — include concrete examples by complexity level\n", + "- **Constraints**: Define what NOT to do (avoid trivial tasks, don't skip steps)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "User Query Generation Prompt loaded\n" + ] + } + ], + "source": [ + "# Prompt 1: Generate a realistic user query that may require one or more tool calls\n", + "USER_QUERY_GENERATION_PROMPT = \"\"\"\n", + "You are creating training data for a workplace assistant AI agent.\n", + "\n", + "**Your Task:** Generate a realistic user request that requires the agent to use one or more tools to complete.\n", + "\n", + "**Available Tools (with full schemas):**\n", + "{{ tools_json }}\n", + "\n", + "**Selected Tool Category:** {{ category }}\n", + "\n", + "**Multi-Step Pattern to Use:** {{ pattern }}\n", + "\n", + "**CRITICAL - Valid Values:**\n", + "Many tool parameters have RESTRICTED VALUES specified in their descriptions. You MUST only reference values that exist in the tool schemas. 
Pay close attention to phrases like \"One of:\" in parameter descriptions.\n", + "\n", + "Common restrictions to follow:\n", + "- `list_name`: Only use 'Backlog', 'In Progress', 'In Review', or 'Completed' (NOT 'Prospects', 'Todo', 'Pipeline', etc.)\n", + "- `board`: Only use 'Back end', 'Front end', or 'Design' (NOT 'Sales', 'Marketing', 'Engineering', etc.)\n", + "- `status`: Only use 'Qualified', 'Won', 'Lost', 'Lead', or 'Proposal' (NOT 'Active', 'Prospect', 'Closed', etc.)\n", + "- `product_interest`: Only use 'Software', 'Hardware', 'Services', 'Consulting', or 'Training'\n", + "\n", + "**Guidelines:**\n", + "1. The request should sound natural - like something a real employee would ask\n", + "2. It should require 1-6 tool calls to complete\n", + "3. Include specific details that make the task concrete (names, dates, subjects)\n", + "4. Don't mention tool names or technical details - speak like a normal user\n", + "5. The task MUST be achievable with the available tools using ONLY valid parameter values\n", + "6. When referencing boards, lists, statuses, etc., use EXACTLY the values allowed in the tool schemas\n", + "\n", + "**Examples by Complexity:**\n", + "\n", + "*Simple (1 step):*\n", + "- \"Reply to Carlos's last email about the prototype with 'Thanks, I'll review it tomorrow'\"\n", + "- \"Change the name of my 3pm meeting to 'Risk Management Forum'\"\n", + "- \"How many website visitors did we have last week?\"\n", + "\n", + "*Medium (2-3 steps):*\n", + "- \"Send an email to John about the quarterly review meeting tomorrow\"\n", + "- \"Schedule a 30-minute sync with Lisa tomorrow at 2pm\"\n", + "- \"Get the total visits and engaged users for November\"\n", + "\n", + "*Complex (4-6 steps):*\n", + "- \"Raj is taking over all of Akira's leads that are interested in software. 
Can you reassign them in the CRM?\"\n", + "- \"Forward the last email from marketing about the Q4 report to everyone on the design team\"\n", + "- \"Move all of Sarah's overdue tasks on the Back end board to the Backlog\"\n", + "\n", + "**Output:** Return ONLY the user request as a single string. No quotes, no explanation.\n", + "\"\"\"\n", + "\n", + "print(\"User Query Generation Prompt loaded\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Trajectory Simulation Prompt loaded\n" + ] + } + ], + "source": [ + "# Prompt 2: Simulate the agent's trajectory for solving the task\n", + "TRAJECTORY_SIMULATION_PROMPT = \"\"\"\n", + "You are simulating an expert workplace assistant agent solving a task step-by-step.\n", + "\n", + "**User Request:**\n", + "{{ user_query }}\n", + "\n", + "**System Context:**\n", + "{{ system_prompt }}\n", + "\n", + "**Available Tools:**\n", + "{{ tools_json }}\n", + "\n", + "**Your Task:** Generate a step-by-step trajectory showing how the agent would solve this request.\n", + "\n", + "**Guidelines:**\n", + "1. **Think Step-by-Step**: Each step should have a clear thought explaining WHY we're calling this tool\n", + "2. **Use Real Tool Names**: The tool_call.name must exactly match one of the available tools\n", + "3. **Valid JSON Arguments**: The tool_call.arguments must be valid JSON matching the tool's parameter schema\n", + "4. **Realistic IDs**: When referencing IDs discovered in previous steps, use placeholder format like \"00000001\"\n", + "5. **Complete the Task**: The trajectory must fully solve the user's request\n", + "6. **1-6 Steps**: Use the minimum number of steps needed. 
Simple tasks may need only 1 step.\n", + "\n", + "**Common Patterns:**\n", + "- Look up a person's email before sending them a message\n", + "- Search for records before updating/deleting them\n", + "- Get information from one database to use in another\n", + "- Some tasks can be completed in a single step (e.g., reply to an email, update an event)\n", + "\n", + "**Example Step:**\n", + "{% raw %}\n", + "```json\n", + "{\n", + " \"step_number\": 1,\n", + " \"thought\": \"The user wants to email Raj, but I need his email address first. I'll look it up in the company directory.\",\n", + " \"tool_call\": {\n", + " \"name\": \"company_directory_find_email_address\",\n", + " \"arguments\": \"{\\\"name\\\": \\\"Raj\\\"}\"\n", + " },\n", + " \"expected_result\": \"Raj's email address (likely raj.patel@atlas.com)\"\n", + "}\n", + "```\n", + "{% endraw %}\n", + "\n", + "Output the complete AgentTrajectory with all steps needed to solve the task.\n", + "\"\"\"\n", + "\n", + "print(\"Trajectory Simulation Prompt loaded\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "User Query Judge Prompt loaded\n", + "Trajectory Judge Prompt loaded\n" + ] + } + ], + "source": [ + "# Prompt 3a: Judge the quality of the generated USER QUERY\n", + "USER_QUERY_JUDGE_PROMPT = \"\"\"\n", + "You are a quality assurance judge evaluating a synthetically generated user query for training an AI workplace assistant.\n", + "\n", + "**Generated User Query:**\n", + "{{ user_query }}\n", + "\n", + "**Available Tools (with full schemas):**\n", + "{{ tools_json }}\n", + "\n", + "**Your Task:** Evaluate whether this user query is valid and achievable with the available tools.\n", + "\n", + "**CRITICAL - Check for Schema Compliance:**\n", + "Many tools have RESTRICTED VALUES for certain fields. The user query must only reference values that are valid according to the tool schemas. 
For example:\n", + "- If a tool says `list_name` must be one of 'Backlog', 'In Progress', 'In Review', 'Completed' - the query cannot ask for a \"Prospects\" list\n", + "- If a tool says `board` must be one of 'Back end', 'Front end', 'Design' - the query cannot ask for a \"Sales\" board \n", + "- If a tool says `status` must be one of 'Qualified', 'Won', 'Lost', 'Lead', 'Proposal' - the query cannot use other statuses\n", + "\n", + "**Evaluation Criteria:**\n", + "\n", + "1. **Feasibility (1-5)**: Can this request be fulfilled using the available tools?\n", + " - Score 1 if the request requires tools/capabilities that don't exist\n", + " - Score 5 if the request is fully achievable with available tools\n", + "\n", + "2. **Schema Compliance (1-5)**: Does the request use valid values?\n", + " - Score 1 if the query references invalid enum values (wrong board names, list names, statuses, etc.)\n", + " - Score 3 if the query is ambiguous but could map to valid values\n", + " - Score 5 if all referenced values exactly match valid options in tool schemas\n", + "\n", + "3. **Naturalness (1-5)**: Does this sound like a real user request?\n", + " - Score 1 if robotic or artificial sounding\n", + " - Score 5 if very natural and realistic\n", + "\n", + "**is_valid:** Set to False if feasibility < 3 OR schema_compliance < 3. These queries should be discarded.\n", + "\n", + "**issues:** List specific problems found. 
Examples:\n", + "- \"References 'Sales' board but valid boards are: 'Back end', 'Front end', 'Design'\"\n", + "- \"References 'Prospects' list but valid lists are: 'Backlog', 'In Progress', 'In Review', 'Completed'\"\n", + "- \"None\" if no issues found\n", + "\n", + "**Output:** Return UserQueryJudgeScores with all fields.\n", + "\"\"\"\n", + "\n", + "# Prompt 3b: Judge the quality of the generated TRAJECTORY\n", + "TRAJECTORY_JUDGE_PROMPT = \"\"\"\n", + "You are a quality assurance judge evaluating a generated trajectory (sequence of tool calls) for training an AI workplace assistant.\n", + "\n", + "**User Request:**\n", + "{{ user_query }}\n", + "\n", + "**Generated Trajectory:**\n", + "{{ trajectory }}\n", + "\n", + "**Available Tools (with full schemas):**\n", + "{{ tools_json }}\n", + "\n", + "**Your Task:** Evaluate whether this trajectory correctly solves the user request using valid tool calls.\n", + "\n", + "**CRITICAL - Check for Argument Validity:**\n", + "Tool arguments must use EXACTLY the values allowed by the tool schemas. For example:\n", + "- `list_name` must be one of: 'Backlog', 'In Progress', 'In Review', 'Completed' (NOT 'Prospects', 'Todo', etc.)\n", + "- `board` must be one of: 'Back end', 'Front end', 'Design' (NOT 'Sales', 'Marketing', etc.)\n", + "- `status` must be one of: 'Qualified', 'Won', 'Lost', 'Lead', 'Proposal' (NOT 'Active', 'Prospect', etc.)\n", + "- `product_interest` must be one of: 'Software', 'Hardware', 'Services', 'Consulting', 'Training'\n", + "\n", + "**Evaluation Criteria:**\n", + "\n", + "1. **Tool Validity (1-5)**: Are all tool names correct?\n", + " - Score 1 if any tool name doesn't match available tools\n", + " - Score 5 if all tool names exactly match\n", + "\n", + "2. 
**Argument Validity (1-5)**: Are all arguments schema-compliant?\n", + " - Score 1 if any argument uses invalid enum values or wrong types\n", + " - Score 3 if arguments are mostly valid but some are ambiguous\n", + " - Score 5 if all arguments perfectly match the schema requirements\n", + "\n", + "3. **Completeness (1-5)**: Does the trajectory fully solve the request?\n", + " - Score 1 if major parts of the request are unaddressed\n", + " - Score 5 if the trajectory completely fulfills the request\n", + "\n", + "4. **Efficiency (1-5)**: Is the trajectory optimal?\n", + " - Score 1 if there are many unnecessary steps\n", + " - Score 5 if the trajectory is optimal with no wasted steps\n", + "\n", + "**is_valid:** Set to False if tool_validity < 4 OR argument_validity < 4. These trajectories have errors and should be discarded.\n", + "\n", + "**issues:** List specific problems found. Examples:\n", + "- \"Step 2 uses list_name='Prospects' but valid values are: 'Backlog', 'In Progress', 'In Review', 'Completed'\"\n", + "- \"Step 1 calls 'email_send' but correct tool name is 'email_send_email'\"\n", + "- \"None\" if no issues found\n", + "\n", + "**Output:** Return TrajectoryJudgeScores with all fields.\n", + "\"\"\"\n", + "\n", + "print(\"User Query Judge Prompt loaded\")\n", + "print(\"Trajectory Judge Prompt loaded\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 5: Create Seed Data\n", + "\n", + "**Data Designer** works by expanding seed data through LLM generation. 
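Conceptually, that expansion fills each prompt template's `{{ ... }}` placeholders from the columns of a seed row. The snippet below is a minimal regex-based stand-in for that substitution (Data Designer's actual rendering uses Jinja2 internally; this is only to make the mechanism concrete):

```python
import re

def render_prompt(template: str, seed_row: dict) -> str:
    """Fill {{ name }} placeholders from a seed row's columns (toy Jinja2 stand-in)."""
    return re.sub(
        r"\{\{\s*(\w+)\s*\}\}",
        lambda m: str(seed_row[m.group(1)]),
        template,
    )

template = "**Selected Tool Category:** {{ category }}\n**Multi-Step Pattern to Use:** {{ pattern }}"
seed_row = {"category": "email", "pattern": "lookup_then_send_email"}
print(render_prompt(template, seed_row))
```

Each generated column therefore sees a fully concrete prompt, with the seed's category, pattern, and tool schemas already substituted in.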
Each seed row provides context variables that get substituted into the prompt templates:\n", + "\n", + "- `category`: Which tool database to focus on (ensures diversity across domains)\n", + "- `pattern`: Which multi-step pattern to use (e.g., lookup-then-send, search-then-update)\n", + "- `tools_json`: Full tool schemas for the LLM to reference\n", + "- `system_prompt`: The system context for the workplace assistant\n", + "\n", + "> **Pattern Engineering Note:** The multi-step patterns used as seeds in `create_seed_data()` are domain-informed. In practice, you can engineer these patterns from heuristics, inferred tool-call chains observed in production traffic, or other rule-based design choices. In this case, we use common patterns stored in `tools/environment.json`." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Created 50 seeds\n", + "\n", + "Sample seed:\n", + "{'seed_id': 0, 'category': 'company_directory', 'pattern': \"lookup_then_send_email: Look up a person's email address, then send them an email\", 'tools_description': \"- **company_directory_find_email_address**: Finds all email addresses containing the given name (case-insensitive search).\\n Parameters: ['name']\", 'tools_json': '[\\n {\\n \"type\": \"function\",\\n \"name\": \"company_directory_find_email_address\",\\n \"description\": \"Finds all email addresses containing the given name (case-insensitive search).\",\\n \"database\": \"company_directory\",\\n \"operation_type\": \"read\",\\n \"parameters\": {\\n \"type\": \"object\",\\n \"properties\": {\\n \"name\": {\\n \"type\": \"string\",\\n \"description\": \"Name or partial name to search for in email addresses\"\\n }\\n },\\n \"required\": [],\\n \"additionalProperties\": false\\n },\\n \"strict\": false\\n }\\n]', 'tools_summary': '- **company_directory_find_email_address**: Finds all email addresses containing the given 
name (case-insensitive search).\\n- **email_get_email_information_by_id**: Retrieves specific details of an email by its ID.\\n- **email_search_emails**: Searches for emails matching the given query across subject, body, or sender fields. The function matches an email if all words in the query appear in any of these fields.\\n- **email_send_email**: Sends an email to the specified recipient.\\n- **email_delete_email**: Deletes an email by its ID.\\n- **email_forward_email**: Forwards an email to the specified recipient.\\n- **email_reply_email**: Replies to an email by its ID.\\n- **calendar_get_event_information_by_id**: Returns the event for a given ID.\\n- **calendar_search_events**: Returns the events for a given query with pagination support.\\n- **calendar_create_event**: Creates a new event.\\n- **calendar_delete_event**: Deletes an event.\\n- **calendar_update_event**: Updates an event.\\n- **analytics_get_visitor_information_by_id**: Returns the analytics data for a given visitor ID.\\n- **analytics_create_plot**: Plots the analytics data for a given time range and value.\\n- **analytics_total_visits_count**: Returns the total number of visits within a specified time range.\\n- **analytics_engaged_users_count**: Returns the number of engaged users within a specified time range.\\n- **analytics_traffic_source_count**: Returns the number of visits from a specific traffic source within a specified time range.\\n- **analytics_get_average_session_duration**: Returns the average session duration within a specified time range.\\n- **project_management_get_task_information_by_id**: Returns the task information for a given ID.\\n- **project_management_search_tasks**: Searches for tasks based on the given parameters.\\n- **project_management_create_task**: Creates a new task.\\n- **project_management_delete_task**: Deletes a task by ID.\\n- **project_management_update_task**: Updates a task by ID.\\n- **customer_relationship_manager_search_customers**: Searches for 
customers based on the given parameters with pagination support.\\n- **customer_relationship_manager_update_customer**: Updates a customer record by ID.\\n- **customer_relationship_manager_add_customer**: Adds a new customer record.\\n- **customer_relationship_manager_delete_customer**: Deletes a customer record by ID.', 'system_prompt': \"Today's date is Thursday, 2026-01-29 and the current time is 23:59:00. Remember the current date and time when answering queries. Meetings must not start before 9am or end after 6pm.\"}\n" + ] + } + ], + "source": [ + "def create_seed_data(num_seeds: int = 100) -> pd.DataFrame:\n", + " \"\"\"\n", + " Create seed data for the Data Designer pipeline.\n", + " \n", + " Each seed contains:\n", + " - category: Which tool category to focus on\n", + " - pattern: Which multi-step pattern to use\n", + " - tools_description: Formatted tool descriptions\n", + " - tools_json: Full tool schemas as JSON\n", + " - system_prompt: The system context\n", + " \"\"\"\n", + " seeds = []\n", + " \n", + " categories = list(TOOL_CATEGORIES.keys())\n", + " patterns = [\n", + " f\"{p['pattern']}: {p['description']}\" for p in MULTI_STEP_PATTERNS\n", + " ]\n", + " \n", + " for i in range(num_seeds):\n", + " # Select category and pattern (ensuring diversity)\n", + " category = categories[i % len(categories)]\n", + " pattern = patterns[i % len(patterns)]\n", + " \n", + " # Get tools for this category (plus company_directory for lookups)\n", + " relevant_tool_names = TOOL_CATEGORIES[category] + TOOL_CATEGORIES.get('company_directory', [])\n", + " relevant_tools = [t for t in TOOLS if t['name'] in relevant_tool_names]\n", + " \n", + " seeds.append({\n", + " 'seed_id': i,\n", + " 'category': category,\n", + " 'pattern': pattern,\n", + " 'tools_description': format_tools_for_prompt(relevant_tools, include_schemas=True),\n", + " 'tools_json': json.dumps(relevant_tools, indent=2),\n", + " 'tools_summary': format_tools_for_prompt(TOOLS), # All tools for judge\n", + 
" 'system_prompt': SYSTEM_PROMPT,\n", + " })\n", + " \n", + " return pd.DataFrame(seeds)\n", + "\n", + "# Create seed data\n", + "seed_df = create_seed_data(num_seeds=50)\n", + "print(f\"Created {len(seed_df)} seeds\")\n", + "print(f\"\\nSample seed:\")\n", + "print(seed_df.iloc[0].to_dict())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Save the seed data as a Parquet file for Data Designer to consume." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Seeds saved to workplace_assistant_seeds.parquet\n" + ] + } + ], + "source": [ + "# Save seeds to parquet for Data Designer\n", + "seed_df.to_parquet('workplace_assistant_seeds.parquet', index=False)\n", + "print(\"Seeds saved to workplace_assistant_seeds.parquet\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 6: Configure the Data Designer Pipeline\n", + "\n", + "Now we wire everything together into a **Data Designer** workflow:\n", + "\n", + "1. **Load Seeds** — Provides category, pattern, and tools for each generation\n", + "2. **Generate User Query** — LLM creates a realistic workplace request\n", + "3. **Judge User Query** — LLM validates feasibility and schema compliance\n", + "4. **Simulate Trajectory** — LLM generates the step-by-step tool-call solution\n", + "5. **Judge Trajectory** — LLM validates tool names and argument correctness\n", + "\n", + "### Configuration\n", + "\n", + "First, set up the **NVIDIA Inference API** provider and model. The API key is read from the `NVIDIA_API_KEY` environment variable." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "\n", + "if \"NVIDIA_API_KEY\" not in os.environ or not os.environ[\"NVIDIA_API_KEY\"]:\n", + " os.environ[\"NVIDIA_API_KEY\"] = getpass.getpass(\"Enter your NVIDIA API key: \")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "# Define custom provider pointing to NVIDIA Inference API\n", + "NVIDIA_INFERENCE_URL = \"https://inference-api.nvidia.com/v1\"\n", + "\n", + "custom_providers = [\n", + " ModelProvider(\n", + " name=\"nvidia-inference\",\n", + " endpoint=NVIDIA_INFERENCE_URL,\n", + " provider_type=\"openai\",\n", + " api_key=os.environ.get(\"NVIDIA_API_KEY\", \"\"),\n", + " ),\n", + "]\n", + "\n", + "# Model name must match NVIDIA's model identifier\n", + "MODEL_ID = \"nvidia/openai/gpt-oss-120b\"\n", + "MODEL_ALIAS = \"gpt-oss-120b\"\n", + "\n", + "model_configs = [\n", + " ModelConfig(\n", + " alias=MODEL_ALIAS,\n", + " model=MODEL_ID,\n", + " provider=\"nvidia-inference\",\n", + " inference_parameters=ChatCompletionInferenceParams(\n", + " max_tokens=16384,\n", + " ),\n", + " )\n", + "]\n", + "\n", + "# Initialize DataDesigner and config builder\n", + "data_designer = DataDesigner(model_providers=custom_providers)\n", + "config_builder = DataDesignerConfigBuilder(model_configs=model_configs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Build the pipeline with four generation columns: user query, user query judge, trajectory, and trajectory judge." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Pipeline configured with 4 generation columns:\n", + " 1. user_query (text) - Generate realistic user request\n", + " 2. user_query_judge (structured) - Validate query feasibility and schema compliance\n", + " 3. 
trajectory (structured) - Generate step-by-step solution\n", + " 4. trajectory_judge (structured) - Validate tool calls and argument values\n" + ] + } + ], + "source": [ + "def build_workplace_assistant_pipeline():\n", + " \"\"\"\n", + " Build the complete Data Designer pipeline for generating \n", + " multi-step tool-calling training data.\n", + " \n", + " Pipeline stages:\n", + " 1. Generate user query\n", + " 2. Judge user query (filter invalid queries early)\n", + " 3. Generate trajectory \n", + " 4. Judge trajectory (filter invalid trajectories)\n", + " \"\"\"\n", + " \n", + " # Initialize the config builder\n", + " config_builder = DataDesignerConfigBuilder(model_configs=model_configs)\n", + " \n", + " # Load seed data\n", + " seed_ref = LocalFileSeedSource(path='workplace_assistant_seeds.parquet')\n", + " config_builder.with_seed_dataset(seed_ref, sampling_strategy=SamplingStrategy.SHUFFLE)\n", + " \n", + " # Column 1: Generate User Query\n", + " # This creates a realistic workplace request based on the category and pattern\n", + " config_builder.add_column(\n", + " LLMTextColumnConfig(\n", + " name=\"user_query\",\n", + " prompt=USER_QUERY_GENERATION_PROMPT,\n", + " model_alias=MODEL_ALIAS,\n", + " )\n", + " )\n", + " \n", + " # Column 2: Judge User Query\n", + " # Validates that the user query is feasible and uses valid enum values\n", + " config_builder.add_column(\n", + " LLMStructuredColumnConfig(\n", + " name=\"user_query_judge\",\n", + " prompt=USER_QUERY_JUDGE_PROMPT,\n", + " output_format=UserQueryJudgeScores,\n", + " model_alias=MODEL_ALIAS,\n", + " )\n", + " )\n", + " \n", + " # Column 3: Simulate Agent Trajectory\n", + " # This generates the step-by-step solution with tool calls\n", + " config_builder.add_column(\n", + " LLMStructuredColumnConfig(\n", + " name=\"trajectory\",\n", + " prompt=TRAJECTORY_SIMULATION_PROMPT,\n", + " output_format=AgentTrajectory,\n", + " model_alias=MODEL_ALIAS,\n", + " )\n", + " )\n", + " \n", + " # Column 4: Judge 
Trajectory\n", + " # Validates that the trajectory uses correct tool names and valid argument values\n", + " config_builder.add_column(\n", + " LLMStructuredColumnConfig(\n", + " name=\"trajectory_judge\",\n", + " prompt=TRAJECTORY_JUDGE_PROMPT,\n", + " output_format=TrajectoryJudgeScores,\n", + " model_alias=MODEL_ALIAS,\n", + " )\n", + " )\n", + " \n", + " return config_builder\n", + "\n", + "# Build the pipeline\n", + "pipeline = build_workplace_assistant_pipeline()\n", + "print(\"Pipeline configured with 4 generation columns:\")\n", + "print(\" 1. user_query (text) - Generate realistic user request\")\n", + "print(\" 2. user_query_judge (structured) - Validate query feasibility and schema compliance\")\n", + "print(\" 3. trajectory (structured) - Generate step-by-step solution\")\n", + "print(\" 4. trajectory_judge (structured) - Validate tool calls and argument values\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Validate the pipeline configuration to catch any issues before generation." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[15:45:07] [INFO] ✅ Validation passed\n" + ] + } + ], + "source": [ + "data_designer.validate(pipeline)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Run a quick preview with 2 records to verify the pipeline produces well-formed outputs before scaling up." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[15:45:09] [INFO] 🔭 Preview generation in progress\n", + "[15:45:09] [INFO] ✅ Validation passed\n", + "[15:45:09] [INFO] ⛓️ Sorting column configs into a Directed Acyclic Graph\n", + "[15:45:09] [INFO] 🩺 Running health checks for models...\n", + "[15:45:09] [INFO] |-- 👀 Checking 'nvidia/openai/gpt-oss-120b' in provider named 'nvidia-inference' for model alias 'gpt-oss-120b'...\n", + "[15:45:12] [INFO] |-- ✅ Passed!\n", + "[15:45:12] [INFO] 🌱 Sampling 2 records from seed dataset\n", + "[15:45:12] [INFO] |-- seed dataset size: 50 records\n", + "[15:45:12] [INFO] |-- sampling strategy: shuffle\n", + "[15:45:12] [INFO] 📝 llm-text model config for column 'user_query'\n", + "[15:45:12] [INFO] |-- model: 'nvidia/openai/gpt-oss-120b'\n", + "[15:45:12] [INFO] |-- model alias: 'gpt-oss-120b'\n", + "[15:45:12] [INFO] |-- model provider: 'nvidia-inference'\n", + "[15:45:12] [INFO] |-- inference parameters:\n", + "[15:45:12] [INFO] | |-- generation_type=chat-completion\n", + "[15:45:12] [INFO] | |-- max_parallel_requests=4\n", + "[15:45:12] [INFO] | |-- max_tokens=16384\n", + "[15:45:12] [INFO] ⚡️ Processing llm-text column 'user_query' with 4 concurrent workers\n", + "[15:45:12] [INFO] ⏱️ llm-text column 'user_query' will report progress after each record\n", + "[15:45:14] [INFO] |-- 😸 llm-text column 'user_query' progress: 1/2 (50%) complete, 1 ok, 0 failed, 0.46 rec/s, eta 2.2s\n", + "[15:45:15] [INFO] |-- 🦁 llm-text column 'user_query' progress: 2/2 (100%) complete, 2 ok, 0 failed, 0.85 rec/s, eta 0.0s\n", + "[15:45:15] [INFO] 🗂️ llm-structured model config for column 'user_query_judge'\n", + "[15:45:15] [INFO] |-- model: 'nvidia/openai/gpt-oss-120b'\n", + "[15:45:15] [INFO] |-- model alias: 'gpt-oss-120b'\n", + "[15:45:15] [INFO] |-- model provider: 'nvidia-inference'\n", + "[15:45:15] [INFO] |-- inference 
parameters:\n", + "[15:45:15] [INFO] | |-- generation_type=chat-completion\n", + "[15:45:15] [INFO] | |-- max_parallel_requests=4\n", + "[15:45:15] [INFO] | |-- max_tokens=16384\n", + "[15:45:15] [INFO] ⚡️ Processing llm-structured column 'user_query_judge' with 4 concurrent workers\n", + "[15:45:15] [INFO] ⏱️ llm-structured column 'user_query_judge' will report progress after each record\n", + "[15:45:17] [INFO] |-- 😸 llm-structured column 'user_query_judge' progress: 1/2 (50%) complete, 1 ok, 0 failed, 0.41 rec/s, eta 2.4s\n", + "[15:45:18] [INFO] |-- 🦁 llm-structured column 'user_query_judge' progress: 2/2 (100%) complete, 2 ok, 0 failed, 0.62 rec/s, eta 0.0s\n", + "[15:45:18] [INFO] 🗂️ llm-structured model config for column 'trajectory'\n", + "[15:45:18] [INFO] |-- model: 'nvidia/openai/gpt-oss-120b'\n", + "[15:45:18] [INFO] |-- model alias: 'gpt-oss-120b'\n", + "[15:45:18] [INFO] |-- model provider: 'nvidia-inference'\n", + "[15:45:18] [INFO] |-- inference parameters:\n", + "[15:45:18] [INFO] | |-- generation_type=chat-completion\n", + "[15:45:18] [INFO] | |-- max_parallel_requests=4\n", + "[15:45:18] [INFO] | |-- max_tokens=16384\n", + "[15:45:18] [INFO] ⚡️ Processing llm-structured column 'trajectory' with 4 concurrent workers\n", + "[15:45:18] [INFO] ⏱️ llm-structured column 'trajectory' will report progress after each record\n", + "[15:45:23] [INFO] |-- 🚗 llm-structured column 'trajectory' progress: 1/2 (50%) complete, 1 ok, 0 failed, 0.21 rec/s, eta 4.8s\n", + "[15:45:24] [INFO] |-- 🚀 llm-structured column 'trajectory' progress: 2/2 (100%) complete, 2 ok, 0 failed, 0.35 rec/s, eta 0.0s\n", + "[15:45:24] [INFO] 🗂️ llm-structured model config for column 'trajectory_judge'\n", + "[15:45:24] [INFO] |-- model: 'nvidia/openai/gpt-oss-120b'\n", + "[15:45:24] [INFO] |-- model alias: 'gpt-oss-120b'\n", + "[15:45:24] [INFO] |-- model provider: 'nvidia-inference'\n", + "[15:45:24] [INFO] |-- inference parameters:\n", + "[15:45:24] [INFO] | |-- 
generation_type=chat-completion\n", + "[15:45:24] [INFO] | |-- max_parallel_requests=4\n", + "[15:45:24] [INFO] | |-- max_tokens=16384\n", + "[15:45:24] [INFO] ⚡️ Processing llm-structured column 'trajectory_judge' with 4 concurrent workers\n", + "[15:45:24] [INFO] ⏱️ llm-structured column 'trajectory_judge' will report progress after each record\n", + "[15:45:26] [INFO] |-- 🐥 llm-structured column 'trajectory_judge' progress: 1/2 (50%) complete, 1 ok, 0 failed, 0.49 rec/s, eta 2.1s\n", + "[15:45:26] [INFO] |-- 🐔 llm-structured column 'trajectory_judge' progress: 2/2 (100%) complete, 2 ok, 0 failed, 0.72 rec/s, eta 0.0s\n", + "[15:45:27] [INFO] 📊 Model usage summary:\n", + "[15:45:27] [INFO] |-- model: nvidia/openai/gpt-oss-120b\n", + "[15:45:27] [INFO] |-- tokens: input=15944, output=4466, total=20410, tps=1417\n", + "[15:45:27] [INFO] |-- requests: success=8, failed=0, total=8, rpm=33\n", + "[15:45:27] [INFO] 📐 Measuring dataset column statistics:\n", + "[15:45:27] [INFO] |-- 📝 column: 'user_query'\n", + "[15:45:27] [INFO] |-- 🗂️ column: 'user_query_judge'\n", + "[15:45:27] [INFO] |-- 🗂️ column: 'trajectory'\n", + "[15:45:27] [INFO] |-- 🗂️ column: 'trajectory_judge'\n", + "[15:45:27] [INFO] 🏆 Preview complete!\n" + ] + } + ], + "source": [ + "preview = data_designer.preview(pipeline, num_records=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Inspect a sample generated user query from the preview." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
seed_idcategorypatterntools_descriptiontools_jsontools_summarysystem_promptuser_queryuser_query_judgetrajectorytrajectory_judge
017customer_relationship_managerget_task_info_then_update: Get specific task i...- **company_directory_find_email_address**: Fi...[\\n {\\n \"type\": \"function\",\\n \"name\": \"...- **company_directory_find_email_address**: Fi...Today's date is Thursday, 2026-01-29 and the c...Can you find the customer John Doe who is a Le...{'feasibility': 5, 'schema_compliance': 5, 'na...{'reasoning_trace': [{'step_number': 1, 'thoug...{'tool_validity': 5, 'argument_validity': 5, '...
17emailsearch_then_batch_delete_events: Search for ca...- **company_directory_find_email_address**: Fi...[\\n {\\n \"type\": \"function\",\\n \"name\": \"...- **company_directory_find_email_address**: Fi...Today's date is Thursday, 2026-01-29 and the c...I need to clean up my inbox—please find all em...{'feasibility': 5, 'schema_compliance': 5, 'na...{'reasoning_trace': [{'step_number': 1, 'thoug...{'tool_validity': 5, 'argument_validity': 5, '...
\n", + "
" + ], + "text/plain": [ + " seed_id category \\\n", + "0 17 customer_relationship_manager \n", + "1 7 email \n", + "\n", + " pattern \\\n", + "0 get_task_info_then_update: Get specific task i... \n", + "1 search_then_batch_delete_events: Search for ca... \n", + "\n", + " tools_description \\\n", + "0 - **company_directory_find_email_address**: Fi... \n", + "1 - **company_directory_find_email_address**: Fi... \n", + "\n", + " tools_json \\\n", + "0 [\\n {\\n \"type\": \"function\",\\n \"name\": \"... \n", + "1 [\\n {\\n \"type\": \"function\",\\n \"name\": \"... \n", + "\n", + " tools_summary \\\n", + "0 - **company_directory_find_email_address**: Fi... \n", + "1 - **company_directory_find_email_address**: Fi... \n", + "\n", + " system_prompt \\\n", + "0 Today's date is Thursday, 2026-01-29 and the c... \n", + "1 Today's date is Thursday, 2026-01-29 and the c... \n", + "\n", + " user_query \\\n", + "0 Can you find the customer John Doe who is a Le... \n", + "1 I need to clean up my inbox—please find all em... \n", + "\n", + " user_query_judge \\\n", + "0 {'feasibility': 5, 'schema_compliance': 5, 'na... \n", + "1 {'feasibility': 5, 'schema_compliance': 5, 'na... \n", + "\n", + " trajectory \\\n", + "0 {'reasoning_trace': [{'step_number': 1, 'thoug... \n", + "1 {'reasoning_trace': [{'step_number': 1, 'thoug... \n", + "\n", + " trajectory_judge \n", + "0 {'tool_validity': 5, 'argument_validity': 5, '... \n", + "1 {'tool_validity': 5, 'argument_validity': 5, '... 
" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "preview.dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "('Can you find the customer John Doe who is a Lead interested in Software and reassign it to sarah.lee@company.com, then set the follow‑up date to 2024-04-30?',\n", + " {'reasoning_trace': [{'step_number': 1,\n", + " 'thought': 'I need to locate the specific customer record for John Doe who is a Lead interested in Software. I will search the CRM using those criteria.',\n", + " 'tool_call': {'name': 'customer_relationship_manager_search_customers',\n", + " 'arguments': '{\"customer_name\": \"John Doe\", \"status\": \"Lead\", \"product_interest\": \"Software\", \"page\": 1, \"page_size\": 10}'},\n", + " 'expected_result': 'A list of matching customers, including the customer\\'s ID (e.g., \"00000001\").'},\n", + " {'step_number': 2,\n", + " 'thought': \"Now that I have the customer's ID, I will reassign the customer to sarah.lee@company.com by updating the 'assigned_to_email' field.\",\n", + " 'tool_call': {'name': 'customer_relationship_manager_update_customer',\n", + " 'arguments': '{\"customer_id\": \"00000001\", \"field\": \"assigned_to_email\", \"new_value\": \"sarah.lee@company.com\"}'},\n", + " 'expected_result': \"Confirmation that the 'assigned_to_email' field was updated for customer ID 00000001.\"},\n", + " {'step_number': 3,\n", + " 'thought': \"Finally, I will set the follow‑up date to 2024‑04‑30 by updating the 'follow_up_by' field for the same customer.\",\n", + " 'tool_call': {'name': 'customer_relationship_manager_update_customer',\n", + " 'arguments': '{\"customer_id\": \"00000001\", \"field\": \"follow_up_by\", \"new_value\": \"2024-04-30\"}'},\n", + " 'expected_result': \"Confirmation that the 'follow_up_by' field was updated for customer ID 00000001.\"}],\n", + " 'final_answer': 'John Doe 
has been reassigned to sarah.lee@company.com and the follow‑up date has been set to 2024‑04‑30.'})" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "preview.dataset.user_query[0], preview.dataset.trajectory[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 7: Set Up Quality Filtering (Generic)\n", + "\n", + "Before any downstream format conversion, we apply dual-level quality filtering to keep only high-quality examples.\n", + "\n", + "This stage is generic to **Data Designer** workflows and not specific to NeMo Gym.\n", + "\n", + "To keep this notebook clean, quality filtering helpers live in `utils/quality_filtering.py`." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "from utils import filter_high_quality, show_rejection_reasons" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 8: Generate and Filter the Dataset\n", + "\n", + "Run the full pipeline end-to-end: generate records and apply **dual-level quality filtering**.\n", + "\n", + "```\n", + "┌──────────────┐ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐\n", + "│ Generate │───▶│ Stage 1: │───▶│ Stage 2: │───▶│ Filtered │\n", + "│ Records │ │ Query Judge │ │ Traj Judge │ │ Dataset │\n", + "└──────────────┘ └──────────────┘ └──────────────┘ └──────────────┘\n", + "```\n", + "\n", + "**Utility location:** `utils/quality_filtering.py`\n", + "\n", + "**Quick usage:**\n", + "- Run `show_rejection_reasons(results_df, num_examples=3)` to inspect failures\n", + "- Run `filter_high_quality(results_df, verbose=True)` to apply default strict filtering\n", + "- Optionally tune thresholds with `FilterThresholds(...).to_kwargs()`\n", + "\n", + "**Why Dual-Level Filtering?**\n", + "- **Stage 1 (User Query)**: Catches queries that are intractable in this context (for example, requests that no combination of the available tools can satisfy).\n", + "- **Stage 2 (Trajectory)**: Catches tool argument
errors that slipped through, or trajectories that fail to answer the query." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[15:45:43] [INFO] 🎨 Creating Data Designer dataset\n", + "[15:45:43] [INFO] ✅ Validation passed\n", + "[15:45:43] [INFO] ⛓️ Sorting column configs into a Directed Acyclic Graph\n", + "[15:45:44] [INFO] 🩺 Running health checks for models...\n", + "[15:45:44] [INFO] |-- 👀 Checking 'nvidia/openai/gpt-oss-120b' in provider named 'nvidia-inference' for model alias 'gpt-oss-120b'...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generating 10 examples...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[15:45:44] [INFO] |-- ✅ Passed!\n", + "[15:45:44] [INFO] 📂 Dataset path '/Users/shashankv/Documents/Work/workplace-asst-sdg/DataDesigner/docs/colab_notebooks/5-multistep-toolcalling/artifacts/dataset' already exists. 
Dataset from this session\n", + "\t\t will be saved to '/Users/shashankv/Documents/Work/workplace-asst-sdg/DataDesigner/docs/colab_notebooks/5-multistep-toolcalling/artifacts/dataset_02-12-2026_154544' instead.\n", + "[15:45:44] [INFO] ⏳ Processing batch 1 of 1\n", + "[15:45:44] [INFO] 🌱 Sampling 10 records from seed dataset\n", + "[15:45:44] [INFO] |-- seed dataset size: 50 records\n", + "[15:45:44] [INFO] |-- sampling strategy: shuffle\n", + "[15:45:44] [INFO] 📝 llm-text model config for column 'user_query'\n", + "[15:45:44] [INFO] |-- model: 'nvidia/openai/gpt-oss-120b'\n", + "[15:45:44] [INFO] |-- model alias: 'gpt-oss-120b'\n", + "[15:45:44] [INFO] |-- model provider: 'nvidia-inference'\n", + "[15:45:44] [INFO] |-- inference parameters:\n", + "[15:45:44] [INFO] | |-- generation_type=chat-completion\n", + "[15:45:44] [INFO] | |-- max_parallel_requests=4\n", + "[15:45:44] [INFO] | |-- max_tokens=16384\n", + "[15:45:44] [INFO] ⚡️ Processing llm-text column 'user_query' with 4 concurrent workers\n", + "[15:45:44] [INFO] ⏱️ llm-text column 'user_query' will report progress after each record\n", + "[15:45:46] [INFO] |-- 🌧️ llm-text column 'user_query' progress: 1/10 (10%) complete, 1 ok, 0 failed, 0.52 rec/s, eta 17.4s\n", + "[15:45:46] [INFO] |-- 🌧️ llm-text column 'user_query' progress: 2/10 (20%) complete, 2 ok, 0 failed, 1.02 rec/s, eta 7.9s\n", + "[15:45:46] [INFO] |-- 🌦️ llm-text column 'user_query' progress: 3/10 (30%) complete, 3 ok, 0 failed, 1.46 rec/s, eta 4.8s\n", + "[15:45:47] [INFO] |-- 🌦️ llm-text column 'user_query' progress: 4/10 (40%) complete, 4 ok, 0 failed, 1.34 rec/s, eta 4.5s\n", + "[15:45:47] [INFO] |-- ⛅ llm-text column 'user_query' progress: 5/10 (50%) complete, 5 ok, 0 failed, 1.45 rec/s, eta 3.5s\n", + "[15:45:48] [INFO] |-- ⛅ llm-text column 'user_query' progress: 6/10 (60%) complete, 6 ok, 0 failed, 1.38 rec/s, eta 2.9s\n", + "[15:45:48] [INFO] |-- ⛅ llm-text column 'user_query' progress: 7/10 (70%) complete, 7 ok, 0 failed, 1.53 rec/s, 
eta 2.0s\n", + "[15:45:49] [INFO] |-- 🌤️ llm-text column 'user_query' progress: 8/10 (80%) complete, 8 ok, 0 failed, 1.56 rec/s, eta 1.3s\n", + "[15:45:49] [INFO] |-- 🌤️ llm-text column 'user_query' progress: 9/10 (90%) complete, 9 ok, 0 failed, 1.73 rec/s, eta 0.6s\n", + "[15:45:51] [INFO] |-- ☀️ llm-text column 'user_query' progress: 10/10 (100%) complete, 10 ok, 0 failed, 1.45 rec/s, eta 0.0s\n", + "[15:45:51] [INFO] 🗂️ llm-structured model config for column 'user_query_judge'\n", + "[15:45:51] [INFO] |-- model: 'nvidia/openai/gpt-oss-120b'\n", + "[15:45:51] [INFO] |-- model alias: 'gpt-oss-120b'\n", + "[15:45:51] [INFO] |-- model provider: 'nvidia-inference'\n", + "[15:45:51] [INFO] |-- inference parameters:\n", + "[15:45:51] [INFO] | |-- generation_type=chat-completion\n", + "[15:45:51] [INFO] | |-- max_parallel_requests=4\n", + "[15:45:51] [INFO] | |-- max_tokens=16384\n", + "[15:45:51] [INFO] ⚡️ Processing llm-structured column 'user_query_judge' with 4 concurrent workers\n", + "[15:45:51] [INFO] ⏱️ llm-structured column 'user_query_judge' will report progress after each record\n", + "[15:45:52] [INFO] |-- 🌧️ llm-structured column 'user_query_judge' progress: 1/10 (10%) complete, 1 ok, 0 failed, 0.61 rec/s, eta 14.7s\n", + "[15:45:52] [INFO] |-- 🌧️ llm-structured column 'user_query_judge' progress: 2/10 (20%) complete, 2 ok, 0 failed, 1.13 rec/s, eta 7.1s\n", + "[15:45:53] [INFO] |-- 🌦️ llm-structured column 'user_query_judge' progress: 3/10 (30%) complete, 3 ok, 0 failed, 1.56 rec/s, eta 4.5s\n", + "[15:45:54] [INFO] |-- 🌦️ llm-structured column 'user_query_judge' progress: 4/10 (40%) complete, 4 ok, 0 failed, 1.33 rec/s, eta 4.5s\n", + "[15:45:54] [INFO] |-- ⛅ llm-structured column 'user_query_judge' progress: 5/10 (50%) complete, 5 ok, 0 failed, 1.49 rec/s, eta 3.4s\n", + "[15:45:54] [INFO] |-- ⛅ llm-structured column 'user_query_judge' progress: 6/10 (60%) complete, 6 ok, 0 failed, 1.61 rec/s, eta 2.5s\n", + "[15:45:55] [INFO] |-- ⛅ llm-structured column 
'user_query_judge' progress: 7/10 (70%) complete, 7 ok, 0 failed, 1.70 rec/s, eta 1.8s\n", + "[15:45:56] [INFO] |-- 🌤️ llm-structured column 'user_query_judge' progress: 8/10 (80%) complete, 8 ok, 0 failed, 1.65 rec/s, eta 1.2s\n", + "[15:45:56] [INFO] |-- 🌤️ llm-structured column 'user_query_judge' progress: 9/10 (90%) complete, 9 ok, 0 failed, 1.58 rec/s, eta 0.6s\n", + "[15:46:02] [INFO] |-- ☀️ llm-structured column 'user_query_judge' progress: 10/10 (100%) complete, 10 ok, 0 failed, 0.90 rec/s, eta 0.0s\n", + "[15:46:02] [INFO] 🗂️ llm-structured model config for column 'trajectory'\n", + "[15:46:02] [INFO] |-- model: 'nvidia/openai/gpt-oss-120b'\n", + "[15:46:02] [INFO] |-- model alias: 'gpt-oss-120b'\n", + "[15:46:02] [INFO] |-- model provider: 'nvidia-inference'\n", + "[15:46:02] [INFO] |-- inference parameters:\n", + "[15:46:02] [INFO] | |-- generation_type=chat-completion\n", + "[15:46:02] [INFO] | |-- max_parallel_requests=4\n", + "[15:46:02] [INFO] | |-- max_tokens=16384\n", + "[15:46:02] [INFO] ⚡️ Processing llm-structured column 'trajectory' with 4 concurrent workers\n", + "[15:46:02] [INFO] ⏱️ llm-structured column 'trajectory' will report progress after each record\n", + "[15:46:09] [INFO] |-- 🐱 llm-structured column 'trajectory' progress: 1/10 (10%) complete, 1 ok, 0 failed, 0.14 rec/s, eta 62.4s\n", + "[15:46:11] [INFO] |-- 🐱 llm-structured column 'trajectory' progress: 2/10 (20%) complete, 2 ok, 0 failed, 0.22 rec/s, eta 35.8s\n", + "[15:46:13] [INFO] |-- 😺 llm-structured column 'trajectory' progress: 3/10 (30%) complete, 3 ok, 0 failed, 0.28 rec/s, eta 25.2s\n", + "[15:46:14] [INFO] |-- 😺 llm-structured column 'trajectory' progress: 4/10 (40%) complete, 4 ok, 0 failed, 0.34 rec/s, eta 17.8s\n", + "[15:46:15] [INFO] |-- 😸 llm-structured column 'trajectory' progress: 5/10 (50%) complete, 5 ok, 0 failed, 0.39 rec/s, eta 13.0s\n", + "[15:46:17] [INFO] |-- 😸 llm-structured column 'trajectory' progress: 6/10 (60%) complete, 6 ok, 0 failed, 0.40 rec/s, 
eta 9.9s\n", + "[15:46:18] [INFO] |-- 😸 llm-structured column 'trajectory' progress: 7/10 (70%) complete, 7 ok, 0 failed, 0.43 rec/s, eta 7.0s\n", + "[15:46:20] [INFO] |-- 😼 llm-structured column 'trajectory' progress: 8/10 (80%) complete, 8 ok, 0 failed, 0.44 rec/s, eta 4.5s\n", + "[15:46:22] [INFO] |-- 😼 llm-structured column 'trajectory' progress: 9/10 (90%) complete, 9 ok, 0 failed, 0.44 rec/s, eta 2.3s\n", + "[15:46:24] [INFO] |-- 🦁 llm-structured column 'trajectory' progress: 10/10 (100%) complete, 10 ok, 0 failed, 0.46 rec/s, eta 0.0s\n", + "[15:46:24] [INFO] 🗂️ llm-structured model config for column 'trajectory_judge'\n", + "[15:46:24] [INFO] |-- model: 'nvidia/openai/gpt-oss-120b'\n", + "[15:46:24] [INFO] |-- model alias: 'gpt-oss-120b'\n", + "[15:46:24] [INFO] |-- model provider: 'nvidia-inference'\n", + "[15:46:24] [INFO] |-- inference parameters:\n", + "[15:46:24] [INFO] | |-- generation_type=chat-completion\n", + "[15:46:24] [INFO] | |-- max_parallel_requests=4\n", + "[15:46:24] [INFO] | |-- max_tokens=16384\n", + "[15:46:24] [INFO] ⚡️ Processing llm-structured column 'trajectory_judge' with 4 concurrent workers\n", + "[15:46:24] [INFO] ⏱️ llm-structured column 'trajectory_judge' will report progress after each record\n", + "[15:46:26] [INFO] |-- 🚶 llm-structured column 'trajectory_judge' progress: 1/10 (10%) complete, 1 ok, 0 failed, 0.40 rec/s, eta 22.7s\n", + "[15:46:27] [INFO] |-- 🚶 llm-structured column 'trajectory_judge' progress: 2/10 (20%) complete, 2 ok, 0 failed, 0.63 rec/s, eta 12.8s\n", + "[15:46:27] [INFO] |-- 🐴 llm-structured column 'trajectory_judge' progress: 3/10 (30%) complete, 3 ok, 0 failed, 0.84 rec/s, eta 8.3s\n", + "[15:46:28] [INFO] |-- 🐴 llm-structured column 'trajectory_judge' progress: 4/10 (40%) complete, 4 ok, 0 failed, 0.86 rec/s, eta 7.0s\n", + "[15:46:29] [INFO] |-- 🚗 llm-structured column 'trajectory_judge' progress: 5/10 (50%) complete, 5 ok, 0 failed, 0.93 rec/s, eta 5.4s\n", + "[15:46:30] [INFO] |-- 🚗 llm-structured 
column 'trajectory_judge' progress: 6/10 (60%) complete, 6 ok, 0 failed, 0.99 rec/s, eta 4.0s\n", + "[15:46:30] [INFO] |-- 🚗 llm-structured column 'trajectory_judge' progress: 7/10 (70%) complete, 7 ok, 0 failed, 1.16 rec/s, eta 2.6s\n", + "[15:46:32] [INFO] |-- ✈️ llm-structured column 'trajectory_judge' progress: 8/10 (80%) complete, 8 ok, 0 failed, 1.00 rec/s, eta 2.0s\n", + "[15:46:33] [INFO] |-- ✈️ llm-structured column 'trajectory_judge' progress: 9/10 (90%) complete, 9 ok, 0 failed, 0.99 rec/s, eta 1.0s\n", + "[15:46:34] [INFO] |-- 🚀 llm-structured column 'trajectory_judge' progress: 10/10 (100%) complete, 10 ok, 0 failed, 1.00 rec/s, eta 0.0s\n", + "[15:46:34] [INFO] 📊 Model usage summary:\n", + "[15:46:34] [INFO] |-- model: nvidia/openai/gpt-oss-120b\n", + "[15:46:34] [INFO] |-- tokens: input=71735, output=23947, total=95682, tps=1908\n", + "[15:46:34] [INFO] |-- requests: success=40, failed=0, total=40, rpm=47\n", + "[15:46:34] [INFO] 📐 Measuring dataset column statistics:\n", + "[15:46:34] [INFO] |-- 📝 column: 'user_query'\n", + "[15:46:34] [INFO] |-- 🗂️ column: 'user_query_judge'\n", + "[15:46:34] [INFO] |-- 🗂️ column: 'trajectory'\n", + "[15:46:34] [INFO] |-- 🗂️ column: 'trajectory_judge'\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Generated 10 records\n", + "\n", + "Columns: ['seed_id', 'category', 'pattern', 'tools_description', 'tools_json', 'tools_summary', 'system_prompt', 'user_query', 'user_query_judge', 'trajectory', 'trajectory_judge']\n" + ] + } + ], + "source": [ + "print(\"Generating 10 examples...\")\n", + "results = data_designer.create(pipeline, num_records=10)\n", + "\n", + "results_df = results.load_dataset()\n", + "print(f\"\\nGenerated {len(results_df)} records\")\n", + "print(\"\\nColumns:\", list(results_df.columns))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Inspect rejection reasons at both judge levels to understand what kinds of errors the pipeline 
catches." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "=== User Query Issues (2/10 rejected) ===\n", + " [1] I’d like a list of all the calendar events I have in the next two weeks that involve John Doe or jan...\n", + " Issues: calendar_create_event only accepts a single participant_email, so a meeting involving both John Doe and jane.smith@example.com cannot be created as a single event with the available tools.\n", + " [2] Please find every email I received from Sarah Johnson from March 1‑15, 2024 that mentions the Q1 bud...\n", + " Issues: email_forward_email does not support adding a custom note to the forwarded message, so the request to forward each email with the added note cannot be fully satisfied.\n", + "\n", + "=== Trajectory Issues (1/10 rejected) ===\n", + " [1] I’d like a list of all the calendar events I have in the next two weeks that involve John Doe or jan...\n", + " Issues: Step 2 uses a past time range (2026-01-29 to 2026-02-12) instead of the next two weeks from today. Step 3 provides participant_email as a comma‑separated list, but the schema expects a single email string. Step 3 schedules the meeting on 2026-02-06, which is not the requested next Thursday at 3 PM.\n" + ] + } + ], + "source": [ + "show_rejection_reasons(results_df, num_examples=3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Apply dual-level filtering with strict schema compliance requirements. Records must pass **both** the user query judge and the trajectory judge to be kept." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "=== Quality Filtering Results ===\n", + "Total records: 10\n", + "\n", + "Stage 1 (User Query): 8/10 passed (80%)\n", + " is_valid: 8 | feasibility>=3: 8 | schema>=4: 10 | naturalness>=3: 10\n", + "\n", + "Stage 2 (Trajectory): 6/10 passed (60%)\n", + " is_valid: 9 | tool_validity>=4: 10 | arg_validity>=4: 9 | completeness>=3: 6 | efficiency>=3: 10\n", + "\n", + "Final: 6/10 passed (60%)\n" + ] + } + ], + "source": [ + "filtered_df = filter_high_quality(\n", + " results_df,\n", + " min_query_feasibility=3,\n", + " min_query_schema_compliance=4,\n", + " min_query_naturalness=3,\n", + " min_trajectory_tool_validity=4,\n", + " min_trajectory_argument_validity=4,\n", + " min_trajectory_completeness=3,\n", + " min_trajectory_efficiency=3,\n", + " verbose=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 9 (Optional): Convert to NeMo Gym Format and Save\n", + "\n", + "If you plan to use this data with **NeMo Gym**, convert filtered records into NeMo Gym JSONL format and save them.\n", + "\n", + "This conversion is NeMo Gym-specific and optional for generic Data Designer workflows." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Saved 6 examples to workplace_assistant_train-gpt-oss.jsonl\n", + "\n", + "Sample generated data (passed both quality stages):\n" + ] + }, + { + "data": { + "text/html": [ + "
<div>\n", + "<style scoped>\n", + "    .dataframe tbody tr th:only-of-type {\n", + "        vertical-align: middle;\n", + "    }\n", + "\n", + "    .dataframe tbody tr th {\n", + "        vertical-align: top;\n", + "    }\n", + "\n", + "    .dataframe thead th {\n", + "        text-align: right;\n", + "    }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + "  <thead>\n", + "    <tr style=\"text-align: right;\">\n", + "      <th></th>\n", + "      <th>seed_id</th>\n", + "      <th>category</th>\n", + "      <th>pattern</th>\n", + "      <th>tools_description</th>\n", + "      <th>tools_json</th>\n", + "      <th>tools_summary</th>\n", + "      <th>system_prompt</th>\n", + "      <th>user_query</th>\n", + "      <th>user_query_judge</th>\n", + "      <th>trajectory</th>\n", + "      <th>trajectory_judge</th>\n", + "    </tr>\n", + "  </thead>\n", + "  <tbody>\n",
+ "    <tr>\n", + "      <th>0</th>\n", + "      <td>35</td>\n", + "      <td>customer_relationship_manager</td>\n", + "      <td>search_then_batch_delete_events: Search for ca...</td>\n", + "      <td>- **company_directory_find_email_address**: Fi...</td>\n", + "      <td>[\n", + "  {\n", + "    \"type\": \"function\",\n", + "    \"name\": \"com...</td>\n", + "      <td>- **company_directory_find_email_address**: Fi...</td>\n", + "      <td>Today's date is Thursday, 2026-01-29 and the c...</td>\n", + "      <td>Please delete all customers assigned to me (al...</td>\n", + "      <td>{'feasibility': 5, 'is_valid': True, 'issues':...</td>\n", + "      <td>{'final_answer': 'All customers assigned to al...</td>\n", + "      <td>{'argument_validity': 5, 'completeness': 3, 'e...</td>\n", + "    </tr>\n",
+ "    <tr>\n", + "      <th>1</th>\n", + "      <td>45</td>\n", + "      <td>analytics</td>\n", + "      <td>get_task_info_then_update: Get specific task i...</td>\n", + "      <td>- **company_directory_find_email_address**: Fi...</td>\n", + "      <td>[\n", + "  {\n", + "    \"type\": \"function\",\n", + "    \"name\": \"com...</td>\n", + "      <td>- **company_directory_find_email_address**: Fi...</td>\n", + "      <td>Today's date is Thursday, 2026-01-29 and the c...</td>\n", + "      <td>Can you tell me how many website visits we had...</td>\n", + "      <td>{'feasibility': 5, 'is_valid': True, 'issues':...</td>\n", + "      <td>{'final_answer': 'I have retrieved the total n...</td>\n", + "      <td>{'argument_validity': 5, 'completeness': 3, 'e...</td>\n", + "    </tr>\n",
+ "    <tr>\n", + "      <th>2</th>\n", + "      <td>46</td>\n", + "      <td>project_management</td>\n", + "      <td>lookup_then_add_customer: Look up a sales rep'...</td>\n", + "      <td>- **company_directory_find_email_address**: Fi...</td>\n", + "      <td>[\n", + "  {\n", + "    \"type\": \"function\",\n", + "    \"name\": \"com...</td>\n", + "      <td>- **company_directory_find_email_address**: Fi...</td>\n", + "      <td>Today's date is Thursday, 2026-01-29 and the c...</td>\n", + "      <td>I need to add a follow‑up task for our new cli...</td>\n", + "      <td>{'feasibility': 5, 'is_valid': True, 'issues':...</td>\n", + "      <td>{'final_answer': 'The task \"Acme Corp – initia...</td>\n", + "      <td>{'argument_validity': 5, 'completeness': 5, 'e...</td>\n", + "    </tr>\n",
+ "    <tr>\n", + "      <th>3</th>\n", + "      <td>48</td>\n", + "      <td>company_directory</td>\n", + "      <td>search_then_batch_delete_customers: Search for...</td>\n", + "      <td>- **company_directory_find_email_address**: Fi...</td>\n", + "      <td>[\n", + "  {\n", + "    \"type\": \"function\",\n", + "    \"name\": \"com...</td>\n", + "      <td>- **company_directory_find_email_address**: Fi...</td>\n", + "      <td>Today's date is Thursday, 2026-01-29 and the c...</td>\n", + "      <td>Could you pull the email addresses for everyon...</td>\n", + "      <td>{'feasibility': 5, 'is_valid': True, 'issues':...</td>\n", + "      <td>{'final_answer': 'Retrieved email addresses fo...</td>\n", + "      <td>{'argument_validity': 5, 'completeness': 5, 'e...</td>\n", + "    </tr>\n",
+ "    <tr>\n", + "      <th>4</th>\n", + "      <td>26</td>\n", + "      <td>calendar</td>\n", + "      <td>crm_to_email: Search CRM, then send emails to ...</td>\n", + "      <td>- **company_directory_find_email_address**: Fi...</td>\n", + "      <td>[\n", + "  {\n", + "    \"type\": \"function\",\n", + "    \"name\": \"com...</td>\n", + "      <td>- **company_directory_find_email_address**: Fi...</td>\n", + "      <td>Today's date is Thursday, 2026-01-29 and the c...</td>\n", + "      <td>Please find the meeting I have with Michael Le...</td>\n", + "      <td>{'feasibility': 5, 'is_valid': True, 'issues':...</td>\n", + "      <td>{'final_answer': 'The meeting with Michael Lee...</td>\n", + "      <td>{'argument_validity': 5, 'completeness': 5, 'e...</td>\n", + "    </tr>\n",
+ "  </tbody>\n", + "</table>\n", + "</div>
" + ], + "text/plain": [ + " seed_id category \\\n", + "0 35 customer_relationship_manager \n", + "1 45 analytics \n", + "2 46 project_management \n", + "3 48 company_directory \n", + "4 26 calendar \n", + "\n", + " pattern \\\n", + "0 search_then_batch_delete_events: Search for ca... \n", + "1 get_task_info_then_update: Get specific task i... \n", + "2 lookup_then_add_customer: Look up a sales rep'... \n", + "3 search_then_batch_delete_customers: Search for... \n", + "4 crm_to_email: Search CRM, then send emails to ... \n", + "\n", + " tools_description \\\n", + "0 - **company_directory_find_email_address**: Fi... \n", + "1 - **company_directory_find_email_address**: Fi... \n", + "2 - **company_directory_find_email_address**: Fi... \n", + "3 - **company_directory_find_email_address**: Fi... \n", + "4 - **company_directory_find_email_address**: Fi... \n", + "\n", + " tools_json \\\n", + "0 [\n", + " {\n", + " \"type\": \"function\",\n", + " \"name\": \"com... \n", + "1 [\n", + " {\n", + " \"type\": \"function\",\n", + " \"name\": \"com... \n", + "2 [\n", + " {\n", + " \"type\": \"function\",\n", + " \"name\": \"com... \n", + "3 [\n", + " {\n", + " \"type\": \"function\",\n", + " \"name\": \"com... \n", + "4 [\n", + " {\n", + " \"type\": \"function\",\n", + " \"name\": \"com... \n", + "\n", + " tools_summary \\\n", + "0 - **company_directory_find_email_address**: Fi... \n", + "1 - **company_directory_find_email_address**: Fi... \n", + "2 - **company_directory_find_email_address**: Fi... \n", + "3 - **company_directory_find_email_address**: Fi... \n", + "4 - **company_directory_find_email_address**: Fi... \n", + "\n", + " system_prompt \\\n", + "0 Today's date is Thursday, 2026-01-29 and the c... \n", + "1 Today's date is Thursday, 2026-01-29 and the c... \n", + "2 Today's date is Thursday, 2026-01-29 and the c... \n", + "3 Today's date is Thursday, 2026-01-29 and the c... \n", + "4 Today's date is Thursday, 2026-01-29 and the c... 
\n", + "\n", + " user_query \\\n", + "0 Please delete all customers assigned to me (al... \n", + "1 Can you tell me how many website visits we had... \n", + "2 I need to add a follow‑up task for our new cli... \n", + "3 Could you pull the email addresses for everyon... \n", + "4 Please find the meeting I have with Michael Le... \n", + "\n", + " user_query_judge \\\n", + "0 {'feasibility': 5, 'is_valid': True, 'issues':... \n", + "1 {'feasibility': 5, 'is_valid': True, 'issues':... \n", + "2 {'feasibility': 5, 'is_valid': True, 'issues':... \n", + "3 {'feasibility': 5, 'is_valid': True, 'issues':... \n", + "4 {'feasibility': 5, 'is_valid': True, 'issues':... \n", + "\n", + " trajectory \\\n", + "0 {'final_answer': 'All customers assigned to al... \n", + "1 {'final_answer': 'I have retrieved the total n... \n", + "2 {'final_answer': 'The task \"Acme Corp – initia... \n", + "3 {'final_answer': 'Retrieved email addresses fo... \n", + "4 {'final_answer': 'The meeting with Michael Lee... \n", + "\n", + " trajectory_judge \n", + "0 {'argument_validity': 5, 'completeness': 3, 'e... \n", + "1 {'argument_validity': 5, 'completeness': 3, 'e... \n", + "2 {'argument_validity': 5, 'completeness': 5, 'e... \n", + "3 {'argument_validity': 5, 'completeness': 5, 'e... \n", + "4 {'argument_validity': 5, 'completeness': 5, 'e... 
" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from functools import partial\n", + "from utils import convert_to_nemo_gym_format, save_for_nemo_gym\n", + "\n", + "convert_fn = partial(convert_to_nemo_gym_format, tools=TOOLS, system_prompt=SYSTEM_PROMPT)\n", + "\n", + "output_path = \"workplace_assistant_train-gpt-oss.jsonl\"\n", + "save_for_nemo_gym(filtered_df, output_path, convert_fn=convert_fn)\n", + "\n", + "print(\"\\nSample generated data (passed both quality stages):\")\n", + "filtered_df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "This notebook demonstrated how to build a complete synthetic data generation pipeline for multi-step tool-calling tasks using **Data Designer**. The pipeline generates user queries, simulates agent trajectories, and applies dual-level LLM judge filtering to produce high-quality training data.\n", + "\n", + "## Next Steps\n", + "\n", + "- **Scale up generation**: Increase `num_seeds` and `num_records` to produce larger training sets (1,000+ examples)\n", + "- **Customize for your domain**: Replace the Workplace Assistant tools with your own tool definitions\n", + "- **Add more multi-step patterns**: Define new patterns in `environment.json` to increase task diversity\n", + "- **Tune judge thresholds**: Inspect rejected examples with `show_rejection_reasons()` and adjust filtering thresholds\n", + "- **Train with NeMo Gym and NeMo RL**: Use the exported JSONL file for GRPO training with NeMo RL, using a NeMo Gym environment." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/colab_notebooks/5-multistep-toolcalling/tools/analytics.json b/docs/colab_notebooks/5-multistep-toolcalling/tools/analytics.json new file mode 100644 index 000000000..071f9b737 --- /dev/null +++ b/docs/colab_notebooks/5-multistep-toolcalling/tools/analytics.json @@ -0,0 +1,125 @@ +{ + "database": "analytics", + "description": "Website analytics data for tracking visitor behavior and engagement.", + "data_schema": { + "columns": ["date_of_visit", "visitor_id", "page_views", "session_duration_seconds", "traffic_source", "user_engaged"], + "enums": { + "traffic_source": ["direct", "referral", "search engine", "social media"] + } + }, + "tools": [ + { + "type": "function", + "name": "analytics_get_visitor_information_by_id", + "description": "Returns the analytics data for a given visitor ID.", + "database": "analytics", + "operation_type": "read", + "parameters": { + "type": "object", + "properties": { + "visitor_id": {"type": "string", "description": "ID of the visitor"} + }, + "required": ["visitor_id"], + "additionalProperties": false + }, + "strict": false + }, + { + "type": "function", + "name": "analytics_create_plot", + "description": "Plots the analytics data for a given time range and value.", + "database": "analytics", + "operation_type": "create", + "parameters": { + "type": "object", + "properties": { + "time_min": {"type": "string", "description": "Start date of the time range. Date format is YYYY-MM-DD"}, + "time_max": {"type": "string", "description": "End date of the time range. 
Date format is YYYY-MM-DD"}, + "value_to_plot": { + "type": "string", + "description": "Value to plot. Available values are: 'total_visits', 'session_duration_seconds', 'user_engaged', 'visits_direct', 'visits_referral', 'visits_search_engine', 'visits_social_media'" + }, + "plot_type": { + "type": "string", + "description": "Type of plot. Can be 'bar', 'line', 'scatter' or 'histogram'" + } + }, + "required": ["time_min", "time_max", "value_to_plot", "plot_type"], + "additionalProperties": false + }, + "strict": false + }, + { + "type": "function", + "name": "analytics_total_visits_count", + "description": "Returns the total number of visits within a specified time range.", + "database": "analytics", + "operation_type": "read", + "parameters": { + "type": "object", + "properties": { + "time_min": {"type": "string", "description": "Start date of the time range. Date format is YYYY-MM-DD"}, + "time_max": {"type": "string", "description": "End date of the time range. Date format is YYYY-MM-DD"} + }, + "required": [], + "additionalProperties": false + }, + "strict": false + }, + { + "type": "function", + "name": "analytics_engaged_users_count", + "description": "Returns the number of engaged users within a specified time range.", + "database": "analytics", + "operation_type": "read", + "parameters": { + "type": "object", + "properties": { + "time_min": {"type": "string", "description": "Start date of the time range. Date format is YYYY-MM-DD"}, + "time_max": {"type": "string", "description": "End date of the time range. 
Date format is YYYY-MM-DD"} + }, + "required": [], + "additionalProperties": false + }, + "strict": false + }, + { + "type": "function", + "name": "analytics_traffic_source_count", + "description": "Returns the number of visits from a specific traffic source within a specified time range.", + "database": "analytics", + "operation_type": "read", + "parameters": { + "type": "object", + "properties": { + "time_min": {"type": "string", "description": "Start date of the time range. Date format is YYYY-MM-DD"}, + "time_max": {"type": "string", "description": "End date of the time range. Date format is YYYY-MM-DD"}, + "traffic_source": { + "type": "string", + "description": "Traffic source to filter the visits. Available values are: 'direct', 'referral', 'search engine', 'social media'" + } + }, + "required": [], + "additionalProperties": false + }, + "strict": false + }, + { + "type": "function", + "name": "analytics_get_average_session_duration", + "description": "Returns the average session duration within a specified time range.", + "database": "analytics", + "operation_type": "read", + "parameters": { + "type": "object", + "properties": { + "time_min": {"type": "string", "description": "Start date of the time range. Date format is YYYY-MM-DD"}, + "time_max": {"type": "string", "description": "End date of the time range. 
Date format is YYYY-MM-DD"} + }, + "required": [], + "additionalProperties": false + }, + "strict": false + } + ] +} diff --git a/docs/colab_notebooks/5-multistep-toolcalling/tools/calendar.json b/docs/colab_notebooks/5-multistep-toolcalling/tools/calendar.json new file mode 100644 index 000000000..afc15a636 --- /dev/null +++ b/docs/colab_notebooks/5-multistep-toolcalling/tools/calendar.json @@ -0,0 +1,117 @@ +{ + "database": "calendar", + "description": "Calendar for managing meetings and events.", + "data_schema": { + "columns": ["event_id", "event_name", "participant_email", "event_start", "duration"] + }, + "tools": [ + { + "type": "function", + "name": "calendar_get_event_information_by_id", + "description": "Returns the event for a given ID.", + "database": "calendar", + "operation_type": "read", + "parameters": { + "type": "object", + "properties": { + "event_id": {"type": "string", "description": "8-digit ID of the event"}, + "field": { + "type": "string", + "description": "Field to return. Available fields are: 'event_id', 'event_name', 'participant_email', 'event_start', 'duration'" + } + }, + "required": ["event_id", "field"], + "additionalProperties": false + }, + "strict": false + }, + { + "type": "function", + "name": "calendar_search_events", + "description": "Returns the events for a given query with pagination support.", + "database": "calendar", + "operation_type": "read", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Query to search for. Terms will be matched in the event_name and participant_email fields" + }, + "time_min": { + "type": "string", + "description": "Lower bound (inclusive) for an event's end time to filter by. Format: YYYY-MM-DD HH:MM:SS" + }, + "time_max": { + "type": "string", + "description": "Upper bound (inclusive) for an event's start time to filter by. 
Format: YYYY-MM-DD HH:MM:SS" + }, + "page": {"type": "integer", "description": "Page number of results to return"}, + "page_size": {"type": "integer", "description": "Number of events per page"} + }, + "required": [], + "additionalProperties": false + }, + "strict": false + }, + { + "type": "function", + "name": "calendar_create_event", + "description": "Creates a new event.", + "database": "calendar", + "operation_type": "create", + "parameters": { + "type": "object", + "properties": { + "event_name": {"type": "string", "description": "Name of the event"}, + "participant_email": {"type": "string", "description": "Email of the participant"}, + "event_start": { + "type": "string", + "description": "Start time of the event. Format: YYYY-MM-DD HH:MM:SS" + }, + "duration": {"type": "string", "description": "Duration of the event in minutes"} + }, + "required": ["event_name", "participant_email", "event_start", "duration"], + "additionalProperties": false + }, + "strict": false + }, + { + "type": "function", + "name": "calendar_delete_event", + "description": "Deletes an event.", + "database": "calendar", + "operation_type": "delete", + "parameters": { + "type": "object", + "properties": { + "event_id": {"type": "string", "description": "8-digit ID of the event"} + }, + "required": ["event_id"], + "additionalProperties": false + }, + "strict": false + }, + { + "type": "function", + "name": "calendar_update_event", + "description": "Updates an event.", + "database": "calendar", + "operation_type": "update", + "parameters": { + "type": "object", + "properties": { + "event_id": {"type": "string", "description": "8-digit ID of the event"}, + "field": { + "type": "string", + "description": "Field to update. 
Available fields are: 'event_name', 'participant_email', 'event_start', 'duration'" + }, + "new_value": {"type": "string", "description": "New value for the field"} + }, + "required": ["event_id", "field", "new_value"], + "additionalProperties": false + }, + "strict": false + } + ] +} diff --git a/docs/colab_notebooks/5-multistep-toolcalling/tools/company_directory.json b/docs/colab_notebooks/5-multistep-toolcalling/tools/company_directory.json new file mode 100644 index 000000000..dfcfd21bc --- /dev/null +++ b/docs/colab_notebooks/5-multistep-toolcalling/tools/company_directory.json @@ -0,0 +1,28 @@ +{ + "database": "company_directory", + "description": "Employee directory for looking up email addresses by name.", + "data_schema": { + "columns": ["email_address"] + }, + "tools": [ + { + "type": "function", + "name": "company_directory_find_email_address", + "description": "Finds all email addresses containing the given name (case-insensitive search).", + "database": "company_directory", + "operation_type": "read", + "parameters": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Name or partial name to search for in email addresses" + } + }, + "required": [], + "additionalProperties": false + }, + "strict": false + } + ] +} diff --git a/docs/colab_notebooks/5-multistep-toolcalling/tools/customer_relationship_manager.json b/docs/colab_notebooks/5-multistep-toolcalling/tools/customer_relationship_manager.json new file mode 100644 index 000000000..982be3161 --- /dev/null +++ b/docs/colab_notebooks/5-multistep-toolcalling/tools/customer_relationship_manager.json @@ -0,0 +1,106 @@ +{ + "database": "customer_relationship_manager", + "description": "CRM for managing customer records and sales pipeline.", + "data_schema": { + "columns": ["customer_id", "customer_name", "customer_email", "customer_phone", "last_contact_date", "product_interest", "status", "assigned_to_email", "notes", "follow_up_by"], + "enums": { + "status": 
["Qualified", "Won", "Lost", "Lead", "Proposal"], + "product_interest": ["Software", "Hardware", "Services", "Consulting", "Training"] + } + }, + "tools": [ + { + "type": "function", + "name": "customer_relationship_manager_search_customers", + "description": "Searches for customers based on the given parameters with pagination support.", + "database": "customer_relationship_manager", + "operation_type": "read", + "parameters": { + "type": "object", + "properties": { + "customer_name": {"type": "string", "description": "Name of the customer"}, + "customer_email": {"type": "string", "description": "Email address of the customer"}, + "product_interest": {"type": "string", "description": "Product interest of the customer"}, + "status": {"type": "string", "description": "Current status of the customer"}, + "assigned_to_email": {"type": "string", "description": "Email address of the person assigned to the customer"}, + "last_contact_date_min": {"type": "string", "description": "Minimum last contact date. Format: YYYY-MM-DD"}, + "last_contact_date_max": {"type": "string", "description": "Maximum last contact date. Format: YYYY-MM-DD"}, + "follow_up_by_min": {"type": "string", "description": "Minimum follow up date. Format: YYYY-MM-DD"}, + "follow_up_by_max": {"type": "string", "description": "Maximum follow up date. 
Format: YYYY-MM-DD"}, + "page": {"type": "integer", "description": "Page number of results to return"}, + "page_size": {"type": "integer", "description": "Number of customers per page"} + }, + "required": [], + "additionalProperties": false + }, + "strict": false + }, + { + "type": "function", + "name": "customer_relationship_manager_update_customer", + "description": "Updates a customer record by ID.", + "database": "customer_relationship_manager", + "operation_type": "update", + "parameters": { + "type": "object", + "properties": { + "customer_id": {"type": "string", "description": "ID of the customer"}, + "field": { + "type": "string", + "description": "Field to update. Available fields are: 'customer_name', 'assigned_to_email', 'customer_email', 'customer_phone', 'last_contact_date', 'product_interest', 'status', 'notes', 'follow_up_by'" + }, + "new_value": {"type": "string", "description": "New value for the field"} + }, + "required": ["customer_id", "field", "new_value"], + "additionalProperties": false + }, + "strict": false + }, + { + "type": "function", + "name": "customer_relationship_manager_add_customer", + "description": "Adds a new customer record.", + "database": "customer_relationship_manager", + "operation_type": "create", + "parameters": { + "type": "object", + "properties": { + "customer_name": {"type": "string", "description": "Name of the customer"}, + "assigned_to_email": {"type": "string", "description": "Email address of the person assigned to the customer"}, + "status": { + "type": "string", + "description": "Current status of the customer. One of: 'Qualified', 'Won', 'Lost', 'Lead', 'Proposal'" + }, + "customer_email": {"type": "string", "description": "Email address of the customer"}, + "customer_phone": {"type": "string", "description": "Phone number of the customer"}, + "last_contact_date": {"type": "string", "description": "The last date the customer was contacted. 
Format: YYYY-MM-DD"}, + "product_interest": { + "type": "string", + "description": "Product interest of the customer. One of: 'Software', 'Hardware', 'Services', 'Consulting', 'Training'" + }, + "notes": {"type": "string", "description": "Notes about the customer"}, + "follow_up_by": {"type": "string", "description": "Date for the next follow up. Format: YYYY-MM-DD"} + }, + "required": ["customer_name", "assigned_to_email", "status"], + "additionalProperties": false + }, + "strict": false + }, + { + "type": "function", + "name": "customer_relationship_manager_delete_customer", + "description": "Deletes a customer record by ID.", + "database": "customer_relationship_manager", + "operation_type": "delete", + "parameters": { + "type": "object", + "properties": { + "customer_id": {"type": "string", "description": "ID of the customer"} + }, + "required": ["customer_id"], + "additionalProperties": false + }, + "strict": false + } + ] +} diff --git a/docs/colab_notebooks/5-multistep-toolcalling/tools/email.json b/docs/colab_notebooks/5-multistep-toolcalling/tools/email.json new file mode 100644 index 000000000..b2cfa8f6e --- /dev/null +++ b/docs/colab_notebooks/5-multistep-toolcalling/tools/email.json @@ -0,0 +1,126 @@ +{ + "database": "email", + "description": "Email inbox and outbox for sending, receiving, and managing emails.", + "data_schema": { + "columns": ["email_id", "inbox/outbox", "sender/recipient", "subject", "sent_datetime", "body"] + }, + "tools": [ + { + "type": "function", + "name": "email_get_email_information_by_id", + "description": "Retrieves specific details of an email by its ID.", + "database": "email", + "operation_type": "read", + "parameters": { + "type": "object", + "properties": { + "email_id": {"type": "string", "description": "Unique ID of the email"}, + "field": { + "type": "string", + "description": "Specific field to return. 
Available fields: 'email_id', 'inbox/outbox', 'sender/recipient', 'subject', 'sent_datetime', 'body'" + } + }, + "required": ["email_id", "field"], + "additionalProperties": false + }, + "strict": false + }, + { + "type": "function", + "name": "email_search_emails", + "description": "Searches for emails matching the given query across subject, body, or sender fields. The function matches an email if all words in the query appear in any of these fields.", + "database": "email", + "operation_type": "read", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query, matching terms in subject, body, or sender/recipient fields" + }, + "date_min": { + "type": "string", + "description": "Lower date limit for the email's sent date (inclusive). Format: YYYY-MM-DD" + }, + "date_max": { + "type": "string", + "description": "Upper date limit for the email's sent date (inclusive). Format: YYYY-MM-DD" + }, + "page": {"type": "integer", "description": "Page number of results to return"}, + "page_size": {"type": "integer", "description": "Number of emails per page"} + }, + "required": [], + "additionalProperties": false + }, + "strict": false + }, + { + "type": "function", + "name": "email_send_email", + "description": "Sends an email to the specified recipient.", + "database": "email", + "operation_type": "create", + "parameters": { + "type": "object", + "properties": { + "recipient": {"type": "string", "description": "Email address of the recipient"}, + "subject": {"type": "string", "description": "Subject line of the email"}, + "body": {"type": "string", "description": "Body content of the email"} + }, + "required": ["recipient", "subject", "body"], + "additionalProperties": false + }, + "strict": false + }, + { + "type": "function", + "name": "email_delete_email", + "description": "Deletes an email by its ID.", + "database": "email", + "operation_type": "delete", + "parameters": { + "type": "object", + "properties": 
{ + "email_id": {"type": "string", "description": "Unique ID of the email to be deleted"} + }, + "required": ["email_id"], + "additionalProperties": false + }, + "strict": false + }, + { + "type": "function", + "name": "email_forward_email", + "description": "Forwards an email to the specified recipient.", + "database": "email", + "operation_type": "create", + "parameters": { + "type": "object", + "properties": { + "email_id": {"type": "string", "description": "Unique ID of the email to be forwarded"}, + "recipient": {"type": "string", "description": "Email address of the recipient"} + }, + "required": ["email_id", "recipient"], + "additionalProperties": false + }, + "strict": false + }, + { + "type": "function", + "name": "email_reply_email", + "description": "Replies to an email by its ID.", + "database": "email", + "operation_type": "create", + "parameters": { + "type": "object", + "properties": { + "email_id": {"type": "string", "description": "Unique ID of the email to be replied to"}, + "body": {"type": "string", "description": "Body content of the email"} + }, + "required": ["email_id", "body"], + "additionalProperties": false + }, + "strict": false + } + ] +} diff --git a/docs/colab_notebooks/5-multistep-toolcalling/tools/environment.json b/docs/colab_notebooks/5-multistep-toolcalling/tools/environment.json new file mode 100644 index 000000000..b7d4e7433 --- /dev/null +++ b/docs/colab_notebooks/5-multistep-toolcalling/tools/environment.json @@ -0,0 +1,277 @@ +{ + "environment_name": "workplace_assistant", + "description": "A simulated workplace environment with 5 databases (email, calendar, analytics, project_management, customer_relationship_manager) plus a company directory utility, totaling 27 tools for managing business activities.", + "system_prompt": "Today's date is Thursday, 2026-01-29 and the current time is 23:59:00. Remember the current date and time when answering queries. 
Meetings must not start before 9am or end after 6pm.", + "databases": [ + "company_directory", + "email", + "calendar", + "analytics", + "project_management", + "customer_relationship_manager" + ], + "tool_count": { + "company_directory": 1, + "email": 6, + "calendar": 5, + "analytics": 6, + "project_management": 5, + "customer_relationship_manager": 4, + "total": 27 + }, + "common_multi_step_patterns": [ + { + "id": "email_lookup_send", + "pattern": "lookup_then_send_email", + "description": "Look up a person's email address, then send them an email", + "tools_used": ["company_directory_find_email_address", "email_send_email"], + "example_request": "Send an email to John about the quarterly review meeting tomorrow", + "expected_steps": 2 + }, + { + "id": "email_lookup_forward", + "pattern": "lookup_then_forward_email", + "description": "Search for an email, then forward it to someone after looking up their address", + "tools_used": ["email_search_emails", "company_directory_find_email_address", "email_forward_email"], + "example_request": "Forward the latest email about the budget report to Sarah", + "expected_steps": 3 + }, + { + "id": "email_search_reply", + "pattern": "search_then_reply_email", + "description": "Search for a specific email, then reply to it", + "tools_used": ["email_search_emails", "email_reply_email"], + "example_request": "Reply to Carlos's last email about the prototype with 'Thanks, I'll review it tomorrow'", + "expected_steps": 2 + }, + { + "id": "email_search_delete_batch", + "pattern": "search_then_batch_delete_emails", + "description": "Search for emails matching criteria, then delete multiple emails", + "tools_used": ["email_search_emails", "email_delete_email"], + "example_request": "Delete all emails from last week about the cancelled project", + "expected_steps": "2-6" + }, + { + "id": "email_get_info_then_act", + "pattern": "get_email_info_then_forward", + "description": "Get specific information from an email, then forward it 
based on content", + "tools_used": ["email_get_email_information_by_id", "company_directory_find_email_address", "email_forward_email"], + "example_request": "Check who sent the email about the deadline and forward it to the project lead", + "expected_steps": 3 + }, + { + "id": "calendar_lookup_create", + "pattern": "lookup_then_create_event", + "description": "Look up a person's email, then create a calendar event with them", + "tools_used": ["company_directory_find_email_address", "calendar_create_event"], + "example_request": "Schedule a 30-minute meeting with Lisa tomorrow at 2pm called 'Project Sync'", + "expected_steps": 2 + }, + { + "id": "calendar_search_update_batch", + "pattern": "search_then_batch_update_events", + "description": "Search for calendar events, then update multiple events", + "tools_used": ["calendar_search_events", "calendar_update_event"], + "example_request": "Reschedule all of Mike's meetings on Friday to start 1 hour later", + "expected_steps": "2-6" + }, + { + "id": "calendar_search_delete_batch", + "pattern": "search_then_batch_delete_events", + "description": "Search for calendar events, then delete multiple events", + "tools_used": ["calendar_search_events", "calendar_delete_event"], + "example_request": "Cancel all meetings with the vendor next week", + "expected_steps": "2-6" + }, + { + "id": "calendar_get_info_update", + "pattern": "get_event_info_then_update", + "description": "Get specific event information, then update based on that info", + "tools_used": ["calendar_get_event_information_by_id", "calendar_update_event"], + "example_request": "Change the name of the last event on December 1st to 'Risk Management Forum'", + "expected_steps": 2 + }, + { + "id": "analytics_multi_query", + "pattern": "multiple_analytics_queries", + "description": "Query multiple analytics metrics for comparison or reporting", + "tools_used": ["analytics_total_visits_count", "analytics_engaged_users_count", "analytics_traffic_source_count"], + 
"example_request": "Get me the total visits, engaged users, and social media traffic for last month", + "expected_steps": 3 + }, + { + "id": "analytics_query_then_plot", + "pattern": "query_analytics_then_create_plot", + "description": "Query analytics data, then create a visualization", + "tools_used": ["analytics_total_visits_count", "analytics_create_plot"], + "example_request": "Show me total visits for November and create a bar chart of daily visits", + "expected_steps": 2 + }, + { + "id": "analytics_visitor_deep_dive", + "pattern": "get_visitor_info_and_session_stats", + "description": "Look up specific visitor information and get session statistics", + "tools_used": ["analytics_get_visitor_information_by_id", "analytics_get_average_session_duration"], + "example_request": "Get details on visitor V12345 and compare their session to the average", + "expected_steps": 2 + }, + { + "id": "analytics_traffic_comparison", + "pattern": "compare_traffic_sources", + "description": "Compare multiple traffic sources over a time period", + "tools_used": ["analytics_traffic_source_count"], + "example_request": "Compare direct vs referral vs search engine traffic for the last 2 weeks", + "expected_steps": 3 + }, + { + "id": "pm_lookup_create_task", + "pattern": "lookup_then_create_task", + "description": "Look up a person's email, then create a task assigned to them", + "tools_used": ["company_directory_find_email_address", "project_management_create_task"], + "example_request": "Create a new Backend task 'Fix login bug' in the Backlog assigned to Alex, due next Friday", + "expected_steps": 2 + }, + { + "id": "pm_search_update_batch", + "pattern": "search_then_batch_update_tasks", + "description": "Search for tasks, then update multiple tasks", + "tools_used": ["company_directory_find_email_address", "project_management_search_tasks", "project_management_update_task"], + "example_request": "Move all of Sarah's In Progress tasks on the Backend board to In Review", + 
"expected_steps": "3-6" + }, + { + "id": "pm_search_delete_batch", + "pattern": "search_then_batch_delete_tasks", + "description": "Search for tasks, then delete multiple tasks", + "tools_used": ["project_management_search_tasks", "project_management_delete_task"], + "example_request": "Delete all completed tasks on the Design board from last month", + "expected_steps": "2-6" + }, + { + "id": "pm_reassign_tasks", + "pattern": "lookup_and_reassign_tasks", + "description": "Look up two people's emails, search for tasks, then reassign them", + "tools_used": ["company_directory_find_email_address", "project_management_search_tasks", "project_management_update_task"], + "example_request": "Reassign all of Tom's Frontend tasks to Jennifer", + "expected_steps": "4-6" + }, + { + "id": "pm_get_info_update", + "pattern": "get_task_info_then_update", + "description": "Get specific task information, then update it", + "tools_used": ["project_management_get_task_information_by_id", "project_management_update_task"], + "example_request": "Check who's assigned to task 00000123 and change the due date to next week", + "expected_steps": 2 + }, + { + "id": "crm_lookup_add", + "pattern": "lookup_then_add_customer", + "description": "Look up a sales rep's email, then add a new customer assigned to them", + "tools_used": ["company_directory_find_email_address", "customer_relationship_manager_add_customer"], + "example_request": "Add a new lead 'Acme Corp' interested in Software, assigned to Maria", + "expected_steps": 2 + }, + { + "id": "crm_search_update_batch", + "pattern": "search_then_batch_update_customers", + "description": "Search for customers, then update multiple customer records", + "tools_used": ["company_directory_find_email_address", "customer_relationship_manager_search_customers", "customer_relationship_manager_update_customer"], + "example_request": "Raj is taking over all of Akira's software leads. 
Reassign them in the CRM.", + "expected_steps": "4-6" + }, + { + "id": "crm_search_delete_batch", + "pattern": "search_then_batch_delete_customers", + "description": "Search for customers, then delete multiple customer records", + "tools_used": ["company_directory_find_email_address", "customer_relationship_manager_search_customers", "customer_relationship_manager_delete_customer"], + "example_request": "Delete all of Lena's won customers interested in services from the CRM", + "expected_steps": "3-6" + }, + { + "id": "crm_status_update", + "pattern": "search_then_update_status", + "description": "Search for customers by criteria including date, then update their status", + "tools_used": ["customer_relationship_manager_search_customers", "customer_relationship_manager_update_customer"], + "example_request": "Move all hardware proposals that haven't been contacted in 4 weeks to Lost", + "expected_steps": "2-6" + }, + { + "id": "cross_email_calendar", + "pattern": "email_to_calendar", + "description": "Search emails for meeting details, then create calendar events", + "tools_used": ["email_search_emails", "email_get_email_information_by_id", "company_directory_find_email_address", "calendar_create_event"], + "example_request": "Find the email about the client meeting and schedule it on the calendar with attendees", + "expected_steps": "3-5" + }, + { + "id": "cross_calendar_email", + "pattern": "calendar_to_email", + "description": "Search calendar for events, then send email notifications", + "tools_used": ["calendar_search_events", "calendar_get_event_information_by_id", "email_send_email"], + "example_request": "Find tomorrow's meetings and send a reminder email to all participants", + "expected_steps": "3-6" + }, + { + "id": "cross_pm_email", + "pattern": "tasks_to_email", + "description": "Search tasks, then send email updates about them", + "tools_used": ["project_management_search_tasks", "company_directory_find_email_address", "email_send_email"], + 
"example_request": "Email the team lead about all overdue Backend tasks", + "expected_steps": "3-4" + }, + { + "id": "cross_crm_calendar", + "pattern": "crm_to_calendar", + "description": "Search CRM for customers needing follow-up, then schedule meetings", + "tools_used": ["customer_relationship_manager_search_customers", "company_directory_find_email_address", "calendar_create_event"], + "example_request": "Schedule follow-up meetings for all qualified leads with follow-up dates this week", + "expected_steps": "3-6" + }, + { + "id": "cross_crm_email", + "pattern": "crm_to_email", + "description": "Search CRM, then send emails to customers or sales reps", + "tools_used": ["customer_relationship_manager_search_customers", "email_send_email"], + "example_request": "Send a thank-you email to all customers marked as Won this month", + "expected_steps": "2-6" + }, + { + "id": "cross_analytics_email", + "pattern": "analytics_to_email", + "description": "Query analytics data, then send a report via email", + "tools_used": ["analytics_total_visits_count", "analytics_engaged_users_count", "company_directory_find_email_address", "email_send_email"], + "example_request": "Get this week's website stats and email the summary to the marketing manager", + "expected_steps": "4-5" + } + ], + "tool_coverage_matrix": { + "company_directory_find_email_address": ["email_lookup_send", "email_lookup_forward", "calendar_lookup_create", "pm_lookup_create_task", "pm_search_update_batch", "pm_reassign_tasks", "crm_lookup_add", "crm_search_update_batch", "crm_search_delete_batch", "cross_email_calendar", "cross_pm_email", "cross_crm_calendar", "cross_analytics_email"], + "email_get_email_information_by_id": ["email_get_info_then_act", "cross_email_calendar"], + "email_search_emails": ["email_lookup_forward", "email_search_reply", "email_search_delete_batch", "cross_email_calendar"], + "email_send_email": ["email_lookup_send", "cross_calendar_email", "cross_pm_email", "cross_crm_email", 
"cross_analytics_email"], + "email_delete_email": ["email_search_delete_batch"], + "email_forward_email": ["email_lookup_forward", "email_get_info_then_act"], + "email_reply_email": ["email_search_reply"], + "calendar_get_event_information_by_id": ["calendar_get_info_update", "cross_calendar_email"], + "calendar_search_events": ["calendar_search_update_batch", "calendar_search_delete_batch", "cross_calendar_email"], + "calendar_create_event": ["calendar_lookup_create", "cross_email_calendar", "cross_crm_calendar"], + "calendar_delete_event": ["calendar_search_delete_batch"], + "calendar_update_event": ["calendar_search_update_batch", "calendar_get_info_update"], + "analytics_get_visitor_information_by_id": ["analytics_visitor_deep_dive"], + "analytics_create_plot": ["analytics_query_then_plot"], + "analytics_total_visits_count": ["analytics_multi_query", "analytics_query_then_plot", "cross_analytics_email"], + "analytics_engaged_users_count": ["analytics_multi_query", "cross_analytics_email"], + "analytics_traffic_source_count": ["analytics_multi_query", "analytics_traffic_comparison"], + "analytics_get_average_session_duration": ["analytics_visitor_deep_dive"], + "project_management_get_task_information_by_id": ["pm_get_info_update"], + "project_management_search_tasks": ["pm_search_update_batch", "pm_search_delete_batch", "pm_reassign_tasks", "cross_pm_email"], + "project_management_create_task": ["pm_lookup_create_task"], + "project_management_delete_task": ["pm_search_delete_batch"], + "project_management_update_task": ["pm_search_update_batch", "pm_reassign_tasks", "pm_get_info_update"], + "customer_relationship_manager_search_customers": ["crm_search_update_batch", "crm_search_delete_batch", "crm_status_update", "cross_crm_calendar", "cross_crm_email"], + "customer_relationship_manager_update_customer": ["crm_search_update_batch", "crm_status_update"], + "customer_relationship_manager_add_customer": ["crm_lookup_add"], + 
"customer_relationship_manager_delete_customer": ["crm_search_delete_batch"] + } +} diff --git a/docs/colab_notebooks/5-multistep-toolcalling/tools/project_management.json b/docs/colab_notebooks/5-multistep-toolcalling/tools/project_management.json new file mode 100644 index 000000000..da38bcc3f --- /dev/null +++ b/docs/colab_notebooks/5-multistep-toolcalling/tools/project_management.json @@ -0,0 +1,116 @@ +{ + "database": "project_management", + "description": "Project management board for tracking tasks across teams.", + "data_schema": { + "columns": ["task_id", "task_name", "assigned_to_email", "list_name", "due_date", "board"], + "enums": { + "list_name": ["Backlog", "In Progress", "In Review", "Completed"], + "board": ["Back end", "Front end", "Design"] + } + }, + "tools": [ + { + "type": "function", + "name": "project_management_get_task_information_by_id", + "description": "Returns the task information for a given ID.", + "database": "project_management", + "operation_type": "read", + "parameters": { + "type": "object", + "properties": { + "task_id": {"type": "string", "description": "8-digit ID of the task"}, + "field": { + "type": "string", + "description": "Field to return. 
Available fields are: 'task_id', 'task_name', 'assigned_to_email', 'list_name', 'due_date', 'board'" + } + }, + "required": ["task_id", "field"], + "additionalProperties": false + }, + "strict": false + }, + { + "type": "function", + "name": "project_management_search_tasks", + "description": "Searches for tasks based on the given parameters.", + "database": "project_management", + "operation_type": "read", + "parameters": { + "type": "object", + "properties": { + "task_name": {"type": "string", "description": "Name of the task"}, + "assigned_to_email": {"type": "string", "description": "Email address of the person assigned to the task"}, + "list_name": {"type": "string", "description": "Name of the list the task belongs to"}, + "due_date": {"type": "string", "description": "Due date of the task in YYYY-MM-DD format"}, + "board": {"type": "string", "description": "Name of the board the task belongs to"} + }, + "required": [], + "additionalProperties": false + }, + "strict": false + }, + { + "type": "function", + "name": "project_management_create_task", + "description": "Creates a new task.", + "database": "project_management", + "operation_type": "create", + "parameters": { + "type": "object", + "properties": { + "task_name": {"type": "string", "description": "Name of the task"}, + "assigned_to_email": {"type": "string", "description": "Email address of the person assigned to the task"}, + "list_name": { + "type": "string", + "description": "Name of the list the task belongs to. One of: 'Backlog', 'In Progress', 'In Review', 'Completed'" + }, + "due_date": {"type": "string", "description": "Due date of the task in YYYY-MM-DD format"}, + "board": { + "type": "string", + "description": "Name of the board the task belongs to.
One of: 'Back end', 'Front end', 'Design'" + } + }, + "required": ["task_name", "assigned_to_email", "list_name", "due_date", "board"], + "additionalProperties": false + }, + "strict": false + }, + { + "type": "function", + "name": "project_management_delete_task", + "description": "Deletes a task by ID.", + "database": "project_management", + "operation_type": "delete", + "parameters": { + "type": "object", + "properties": { + "task_id": {"type": "string", "description": "8-digit ID of the task"} + }, + "required": ["task_id"], + "additionalProperties": false + }, + "strict": false + }, + { + "type": "function", + "name": "project_management_update_task", + "description": "Updates a task by ID.", + "database": "project_management", + "operation_type": "update", + "parameters": { + "type": "object", + "properties": { + "task_id": {"type": "string", "description": "8-digit ID of the task"}, + "field": { + "type": "string", + "description": "Field to update. Available fields are: 'task_name', 'assigned_to_email', 'list_name', 'due_date', 'board'" + }, + "new_value": {"type": "string", "description": "New value for the field"} + }, + "required": ["task_id", "field", "new_value"], + "additionalProperties": false + }, + "strict": false + } + ] +} diff --git a/docs/colab_notebooks/5-multistep-toolcalling/utils/__init__.py b/docs/colab_notebooks/5-multistep-toolcalling/utils/__init__.py new file mode 100644 index 000000000..8a76f5133 --- /dev/null +++ b/docs/colab_notebooks/5-multistep-toolcalling/utils/__init__.py @@ -0,0 +1,9 @@ +from .quality_filtering import filter_high_quality, show_rejection_reasons +from .convert_to_nemo_gym_format import convert_to_nemo_gym_format, save_for_nemo_gym + +__all__ = [ + "convert_to_nemo_gym_format", + "filter_high_quality", + "save_for_nemo_gym", + "show_rejection_reasons", +] diff --git a/docs/colab_notebooks/5-multistep-toolcalling/utils/convert_to_nemo_gym_format.py
b/docs/colab_notebooks/5-multistep-toolcalling/utils/convert_to_nemo_gym_format.py new file mode 100644 index 000000000..1585d2c68 --- /dev/null +++ b/docs/colab_notebooks/5-multistep-toolcalling/utils/convert_to_nemo_gym_format.py @@ -0,0 +1,75 @@ +"""Utilities for converting generated records to NeMo Gym JSONL format.""" + +from __future__ import annotations + +import json +from typing import Any, Callable + +import pandas as pd + + +def convert_to_nemo_gym_format( + row: dict[str, Any], + idx: int, + tools: list[dict[str, Any]], + system_prompt: str, + environment_name: str = "workplace_assistant", +) -> dict[str, Any]: + """Convert a generated row to NeMo Gym rollout format.""" + trajectory = row.get("trajectory", {}) + if isinstance(trajectory, str): + trajectory = json.loads(trajectory) + + ground_truth = [] + for step in trajectory.get("reasoning_trace", []): + tool_call = step.get("tool_call", {}) + ground_truth.append( + { + "name": tool_call.get("name", ""), + "arguments": tool_call.get("arguments", "{}"), + } + ) + + cleaned_tools = [ + { + "type": tool.get("type"), + "name": tool.get("name"), + "description": tool.get("description"), + "parameters": tool.get("parameters"), + "strict": tool.get("strict"), + } + for tool in tools + ] + + return { + "id": idx, + "responses_create_params": { + "input": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": row.get("user_query", "")}, + ], + "tools": cleaned_tools, + "parallel_tool_calls": False, + "temperature": 1.0, + }, + "ground_truth": ground_truth, + "category": f"workplace_assistant_{row.get('category', 'general')}", + "environment_name": environment_name, + "user_query_judge": row.get("user_query_judge", {}), + "trajectory_judge": row.get("trajectory_judge", {}), + "pattern": row.get("pattern", ""), + } + + +def save_for_nemo_gym( + df: pd.DataFrame, + output_path: str, + convert_fn: Callable[[dict[str, Any], int], dict[str, Any]], +) -> None: + """Save records as JSONL for
NeMo Gym.""" + with open(output_path, "w") as f: + for idx, row in df.iterrows(): + record = convert_fn(row.to_dict(), idx) + f.write(json.dumps(record) + "\n") + + print(f"Saved {len(df)} examples to {output_path}") diff --git a/docs/colab_notebooks/5-multistep-toolcalling/utils/quality_filtering.py b/docs/colab_notebooks/5-multistep-toolcalling/utils/quality_filtering.py new file mode 100644 index 000000000..f33d43c25 --- /dev/null +++ b/docs/colab_notebooks/5-multistep-toolcalling/utils/quality_filtering.py @@ -0,0 +1,86 @@ +"""Utilities for dual-level quality filtering of generated datasets.""" + +from __future__ import annotations + +import json +from typing import Any + +import pandas as pd + + +def _parse_scores(scores: Any) -> dict[str, Any]: + """Normalize judge outputs to dictionaries.""" + if isinstance(scores, str): + return json.loads(scores) + return scores or {} + + +def filter_high_quality( + df: pd.DataFrame, + min_query_feasibility: int = 3, + min_query_schema_compliance: int = 4, + min_query_naturalness: int = 3, + min_trajectory_tool_validity: int = 4, + min_trajectory_argument_validity: int = 4, + min_trajectory_completeness: int = 3, + min_trajectory_efficiency: int = 3, + verbose: bool = True, +) -> pd.DataFrame: + """Filter generated data with dual-level quality control. + + Stage 1 checks user-query quality. + Stage 2 checks trajectory quality. + Records must pass both stages. 
+ """ + out = df.copy() + out["_query_scores"] = out["user_query_judge"].apply(_parse_scores) + out["_traj_scores"] = out["trajectory_judge"].apply(_parse_scores) + + # Stage 1: user query quality + query_is_valid = out["_query_scores"].apply(lambda x: x.get("is_valid", False)) == True # noqa: E712 + query_feasibility_ok = out["_query_scores"].apply(lambda x: x.get("feasibility", 0)) >= min_query_feasibility + query_schema_ok = out["_query_scores"].apply(lambda x: x.get("schema_compliance", 0)) >= min_query_schema_compliance + query_natural_ok = out["_query_scores"].apply(lambda x: x.get("naturalness", 0)) >= min_query_naturalness + query_passed = query_is_valid & query_feasibility_ok & query_schema_ok & query_natural_ok + + # Stage 2: trajectory quality + traj_is_valid = out["_traj_scores"].apply(lambda x: x.get("is_valid", False)) == True # noqa: E712 + traj_tool_ok = out["_traj_scores"].apply(lambda x: x.get("tool_validity", 0)) >= min_trajectory_tool_validity + traj_args_ok = out["_traj_scores"].apply(lambda x: x.get("argument_validity", 0)) >= min_trajectory_argument_validity + traj_complete_ok = out["_traj_scores"].apply(lambda x: x.get("completeness", 0)) >= min_trajectory_completeness + traj_efficient_ok = out["_traj_scores"].apply(lambda x: x.get("efficiency", 0)) >= min_trajectory_efficiency + traj_passed = traj_is_valid & traj_tool_ok & traj_args_ok & traj_complete_ok & traj_efficient_ok + + final_passed = query_passed & traj_passed + + if verbose: + n = len(out) + print(f"\n=== Quality Filtering Results ===") + print(f"Total records: {n}") + print(f"\nStage 1 (User Query): {query_passed.sum()}/{n} passed ({query_passed.mean() * 100:.0f}%)") + print(f" is_valid: {query_is_valid.sum()} | feasibility>={min_query_feasibility}: {query_feasibility_ok.sum()} " + f"| schema>={min_query_schema_compliance}: {query_schema_ok.sum()} | naturalness>={min_query_naturalness}: {query_natural_ok.sum()}") + print(f"\nStage 2 (Trajectory): {traj_passed.sum()}/{n} passed 
({traj_passed.mean() * 100:.0f}%)") + print(f" is_valid: {traj_is_valid.sum()} | tool_validity>={min_trajectory_tool_validity}: {traj_tool_ok.sum()} " + f"| arg_validity>={min_trajectory_argument_validity}: {traj_args_ok.sum()} " + f"| completeness>={min_trajectory_completeness}: {traj_complete_ok.sum()} " + f"| efficiency>={min_trajectory_efficiency}: {traj_efficient_ok.sum()}") + print(f"\nFinal: {final_passed.sum()}/{n} passed ({final_passed.mean() * 100:.0f}%)") + + return out[final_passed].drop(columns=["_query_scores", "_traj_scores"]).reset_index(drop=True) + + +def show_rejection_reasons(df: pd.DataFrame, num_examples: int = 5) -> None: + """Print example rejection reasons from both judges.""" + query_scores = df["user_query_judge"].apply(_parse_scores) + traj_scores = df["trajectory_judge"].apply(_parse_scores) + + for label, scores in [("User Query", query_scores), ("Trajectory", traj_scores)]: + rejected = scores[scores.apply(lambda x: not x.get("is_valid", True))] + print(f"\n=== {label} Issues ({len(rejected)}/{len(df)} rejected) ===") + if len(rejected) == 0: + print(" No issues found.") + continue + for i, (idx, s) in enumerate(rejected.head(num_examples).items()): + print(f" [{i + 1}] {df.loc[idx, 'user_query'][:100]}...") + print(f" Issues: {s.get('issues', 'N/A')}")
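Taken together, these utilities emit one JSONL record per surviving row. As a quick standalone sanity check of the record layout, the sketch below rebuilds a single record by hand; the toy row, tool schema, and system prompt are invented for illustration, and the field layout mirrors `convert_to_nemo_gym_format` rather than importing it:

```python
import json

# Hypothetical single generated row; keys mirror the notebook's columns
# (user_query, trajectory, category) but every value here is made up.
system_prompt = "Today's date is Thursday, 2026-01-29."
tools = [
    {
        "type": "function",
        "name": "company_directory_find_email_address",
        "description": "Finds a person's email address by name.",
        "parameters": {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
            "additionalProperties": False,
        },
        "strict": False,
    }
]
row = {
    "user_query": "Send an email to John about the quarterly review.",
    "trajectory": {
        "reasoning_trace": [
            {
                "tool_call": {
                    "name": "company_directory_find_email_address",
                    "arguments": json.dumps({"name": "John"}),
                }
            }
        ]
    },
    "category": "email",
}

# Assemble the record the same way convert_to_nemo_gym_format does:
# system + user messages, tool schemas, and the simulated trajectory
# flattened into a ground-truth list of tool calls.
record = {
    "id": 0,
    "responses_create_params": {
        "input": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": row["user_query"]},
        ],
        "tools": tools,
        "parallel_tool_calls": False,
        "temperature": 1.0,
    },
    "ground_truth": [
        {"name": s["tool_call"]["name"], "arguments": s["tool_call"]["arguments"]}
        for s in row["trajectory"]["reasoning_trace"]
    ],
    "category": f"workplace_assistant_{row['category']}",
    "environment_name": "workplace_assistant",
}

# One JSONL line, as save_for_nemo_gym would write it.
line = json.dumps(record)
print(record["ground_truth"][0]["name"])
```

Note that `ground_truth` keeps `arguments` as a JSON-encoded string, not a nested object, matching how the converter stores tool-call arguments.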