diff --git a/.gitignore b/.gitignore index a292cd68..0400e75c 100644 --- a/.gitignore +++ b/.gitignore @@ -7,5 +7,5 @@ terraform/.tfvars try* -*-2.db +*.db *_code*.py \ No newline at end of file diff --git a/DEFAULT_AGENTS_SETUP.md b/DEFAULT_AGENTS_SETUP.md new file mode 100644 index 00000000..df341e55 --- /dev/null +++ b/DEFAULT_AGENTS_SETUP.md @@ -0,0 +1,47 @@ +# Default Agents Setup Guide + +## Overview + +The system includes 4 default agents automatically loaded on app startup: + +1. **Data Preprocessing Agent** - Data cleaning and preparation +2. **Statistical Analytics Agent** - Statistical analysis using statsmodels +3. **Machine Learning Agent** - ML modeling using scikit-learn +4. **Data Visualization Agent** - Interactive visualizations using Plotly + +## Automatic Setup + +Default agents are automatically initialized when the application starts. You'll see: + +``` +Initializing default agents on startup... +βœ… Default agents initialized successfully +``` + +## User Template Preferences + +- **Default agents (preprocessing, statistical_analytics, sk_learn, data_viz) are ENABLED by default** for all users +- **Other templates are DISABLED by default** and must be explicitly enabled +- Only enabled templates appear in the planner + +## Key Features + +- Agents load automatically on startup +- No manual setup required +- Default agents are active by default for better user experience +- Users can still disable default agents if desired +- Planner shows helpful messages when no agents enabled +- Full API support for template management + +## API Endpoints + +- `GET /templates/user/{user_id}` - Get user preferences +- `POST /templates/user/{user_id}/template/{template_id}/toggle` - Enable/disable templates +- `GET /templates/user/{user_id}/enabled` - Get enabled templates only + +## Manual Script (Optional) + +For manual updates: +```bash +python load_default_agents.py +``` \ No newline at end of file diff --git a/auto-analyst-backend/.gitignore 
b/auto-analyst-backend/.gitignore index 5f058853..43b85dde 100644 --- a/auto-analyst-backend/.gitignore +++ b/auto-analyst-backend/.gitignore @@ -25,7 +25,7 @@ migrations/ alembic.ini -*-2.db +*.db schema*.md diff --git a/auto-analyst-backend/DEFAULT_AGENTS_SETUP.md b/auto-analyst-backend/DEFAULT_AGENTS_SETUP.md new file mode 100644 index 00000000..1242c834 --- /dev/null +++ b/auto-analyst-backend/DEFAULT_AGENTS_SETUP.md @@ -0,0 +1,237 @@ +# Default Agents Setup Guide + +This guide explains how to set up and use the default agents system in the Auto-Analyst backend. + +## Overview + +The system now includes 4 default agents that are stored in the database as templates: + +1. **Data Preprocessing Agent** (`preprocessing_agent`) - Data cleaning and preparation +2. **Statistical Analytics Agent** (`statistical_analytics_agent`) - Statistical analysis using statsmodels +3. **Machine Learning Agent** (`sk_learn_agent`) - ML modeling using scikit-learn +4. **Data Visualization Agent** (`data_viz_agent`) - Interactive visualizations using Plotly + +## Setup Instructions + +### 1. Load Default Agents into Database + +Run the setup script to populate the database with default agents: + +```bash +cd Auto-Analyst-CS/auto-analyst-backend +python load_default_agents.py +``` + +**Or** use the API endpoint: + +```bash +curl -X POST "http://localhost:8000/templates/load-default-agents" \ + -H "Content-Type: application/json" \ + -d '{"force_update": false}' +``` + +### 2. 
Agent Properties + +All default agents are created with: +- `is_active = True` (available for use) +- `is_premium_only = False` (free to use) +- Proper categories (Data Manipulation, Statistical Analysis, Modelling, Visualization) + +## User Preferences System + +### Default Behavior +- **Default agents (preprocessing, statistical_analytics, sk_learn, data_viz) are ENABLED by default** for all users +- **Other templates are DISABLED by default** and must be explicitly enabled +- Templates can be used directly via `@template_name` regardless of preferences + +### Managing User Preferences + +#### Enable/Disable Templates +```bash +# Enable a template for a user +curl -X POST "http://localhost:8000/templates/user/1/template/1/toggle" \ + -H "Content-Type: application/json" \ + -d '{"is_enabled": true}' + +# Disable a template for a user +curl -X POST "http://localhost:8000/templates/user/1/template/1/toggle" \ + -H "Content-Type: application/json" \ + -d '{"is_enabled": false}' +``` + +#### Bulk Enable/Disable +```bash +# Enable multiple templates at once +curl -X POST "http://localhost:8000/templates/user/1/bulk-toggle" \ + -H "Content-Type: application/json" \ + -d '{ + "template_preferences": { + "1": true, + "2": true, + "3": false + } + }' +``` + +#### Get User's Template Preferences +```bash +# Get all templates with user's enabled/disabled status +curl "http://localhost:8000/templates/user/1" + +# Get only enabled templates for user +curl "http://localhost:8000/templates/user/1/enabled" + +# Get enabled templates for planner (max 10, ordered by usage) +curl "http://localhost:8000/templates/user/1/enabled/planner" +``` + +## Planner Integration + +### How It Works +1. **Template Loading**: Only user-enabled templates are loaded into the planner +2. **No Agents Available**: If no templates are enabled, planner returns a helpful message +3. 
**Usage Tracking**: Template usage is tracked for prioritization + +### Planner Response When No Agents Enabled +```json +{ + "complexity": "no_agents_available", + "plan": "no_agents_available", + "plan_instructions": { + "message": "No agents are currently enabled for analysis. Please enable at least one agent (preprocessing, statistical analysis, machine learning, or visualization) in your template preferences to proceed with data analysis." + } +} +``` + +## API Endpoints + +### Template Management +- `GET /templates/` - Get all available templates +- `GET /templates/template/{template_id}` - Get specific template +- `POST /templates/load-default-agents` - Load default agents into database + +### User Preferences +- `GET /templates/user/{user_id}` - Get user's template preferences +- `GET /templates/user/{user_id}/enabled` - Get user's enabled templates +- `GET /templates/user/{user_id}/enabled/planner` - Get templates for planner (max 10) +- `POST /templates/user/{user_id}/template/{template_id}/toggle` - Toggle template preference +- `POST /templates/user/{user_id}/bulk-toggle` - Bulk toggle preferences + +### Categories +- `GET /templates/categories/list` - Get all categories +- `GET /templates/categories` - Get templates grouped by category +- `GET /templates/category/{category}` - Get templates in specific category + +## Usage Examples + +### Frontend Integration +```typescript +// Enable preprocessing and visualization agents for user +const enableAgents = async (userId: number) => { + await fetch(`/templates/user/${userId}/bulk-toggle`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + template_preferences: { + "1": true, // preprocessing_agent + "4": true // data_viz_agent + } + }) + }); +}; + +// Get user's enabled templates +const getUserTemplates = async (userId: number) => { + const response = await fetch(`/templates/user/${userId}/enabled`); + return await response.json(); +}; +``` + +### Direct Agent 
Usage +Users can still use any agent directly regardless of preferences: +``` +@preprocessing_agent clean this data +@data_viz_agent create a scatter plot of sales vs price +``` + +### Planner Usage +Only enabled agents will be available to the planner: +``` +User: "Clean the data and create a visualization" +System: Uses only enabled agents to create the plan +``` + +## Database Schema + +### AgentTemplate Table +```sql +CREATE TABLE agent_templates ( + template_id SERIAL PRIMARY KEY, + template_name VARCHAR UNIQUE NOT NULL, + display_name VARCHAR, + description TEXT, + prompt_template TEXT, + category VARCHAR, + is_premium_only BOOLEAN DEFAULT FALSE, + is_active BOOLEAN DEFAULT TRUE, + created_at TIMESTAMP, + updated_at TIMESTAMP +); +``` + +### UserTemplatePreference Table +```sql +CREATE TABLE user_template_preferences ( + user_id INTEGER, + template_id INTEGER, + is_enabled BOOLEAN DEFAULT FALSE, + usage_count INTEGER DEFAULT 0, + last_used_at TIMESTAMP, + created_at TIMESTAMP, + updated_at TIMESTAMP, + PRIMARY KEY (user_id, template_id), + FOREIGN KEY (user_id) REFERENCES users(user_id), + FOREIGN KEY (template_id) REFERENCES agent_templates(template_id) +); +``` + +## Troubleshooting + +### Common Issues + +1. **No agents available in planner** + - Check if user has enabled any templates: `GET /templates/user/{user_id}/enabled` + - Enable templates using the toggle endpoint + +2. **Default agents not found** + - Run the load script: `python load_default_agents.py` + - Check if agents exist: `GET /templates/` + +3. **Import errors in load script** + - Ensure you're in the backend directory + - Check that all dependencies are installed + - Verify database connection + +### Logs +Check the application logs for detailed error messages: +```bash +tail -f logs/templates_routes.log +tail -f logs/agents.log +``` + +## Migration from Old System + +If migrating from the previous custom agents system: + +1. 
**Data Migration**: Existing custom agents should be migrated to the new template system +2. **User Preferences**: Users will need to re-enable their preferred agents +3. **API Updates**: Update frontend code to use new template endpoints +4. **Testing**: Verify planner works with enabled templates only + +## Support + +For issues or questions: +1. Check the logs for error messages +2. Verify database connections +3. Ensure proper API endpoint usage +4. Test with the load script first \ No newline at end of file diff --git a/auto-analyst-backend/Dockerfile b/auto-analyst-backend/Dockerfile index d373f797..403ae0fe 100644 --- a/auto-analyst-backend/Dockerfile +++ b/auto-analyst-backend/Dockerfile @@ -14,6 +14,8 @@ COPY --chown=user . /app # Make entrypoint script executable USER root RUN chmod +x /app/entrypoint.sh +# Make populate script executable +RUN chmod +x /app/scripts/populate_agent_templates.py USER user # Use the entrypoint script instead of directly running uvicorn diff --git a/auto-analyst-backend/app.py b/auto-analyst-backend/app.py index 6ecee146..bfe73b8b 100644 --- a/auto-analyst-backend/app.py +++ b/auto-analyst-backend/app.py @@ -255,19 +255,7 @@ def clear_console(): logger.log_message(f"Housing.csv not found at {os.path.abspath(housing_csv_path)}", level=logging.ERROR) raise FileNotFoundError(f"Housing.csv not found at {os.path.abspath(housing_csv_path)}") -AVAILABLE_AGENTS = { - "data_viz_agent": data_viz_agent, - "sk_learn_agent": sk_learn_agent, - "statistical_analytics_agent": statistical_analytics_agent, - "preprocessing_agent": preprocessing_agent, -} - -PLANNER_AGENTS = { - "planner_preprocessing_agent": planner_preprocessing_agent, - "planner_sk_learn_agent": planner_sk_learn_agent, - "planner_statistical_analytics_agent": planner_statistical_analytics_agent, - "planner_data_viz_agent": planner_data_viz_agent, -} +# All agents are now loaded from database - no hardcoded dictionaries needed # Add session header X_SESSION_ID = 
APIKeyHeader(name="X-Session-ID", auto_error=False) @@ -275,7 +263,7 @@ def clear_console(): # Update AppState class to use SessionManager class AppState: def __init__(self): - self._session_manager = SessionManager(styling_instructions, PLANNER_AGENTS) + self._session_manager = SessionManager(styling_instructions, {}) # Empty dict, agents loaded from DB self.model_config = DEFAULT_MODEL_CONFIG.copy() # Update the SessionManager with the current model_config self._session_manager._app_model_config = self.model_config @@ -326,17 +314,82 @@ def get_chat_history_name_agent(self): def get_deep_analyzer(self, session_id: str): """Get or create deep analysis module for a session""" session_state = self.get_session_state(session_id) - if not hasattr(session_state, 'deep_analyzer') or session_state.get('deep_analyzer') is None: - # Create agents dictionary for deep analysis - deep_agents = { - "planner_data_viz_agent": dspy.asyncify(dspy.ChainOfThought(planner_data_viz_agent)), - "planner_statistical_analytics_agent": dspy.asyncify(dspy.ChainOfThought(planner_statistical_analytics_agent)), - "planner_sk_learn_agent": dspy.asyncify(dspy.ChainOfThought(planner_sk_learn_agent)), - "planner_preprocessing_agent": dspy.asyncify(dspy.ChainOfThought(planner_preprocessing_agent)) - } + user_id = session_state.get("user_id") + + # Check if we need to recreate the deep analyzer (user changed or doesn't exist) + current_analyzer = session_state.get('deep_analyzer') + analyzer_user_id = session_state.get('deep_analyzer_user_id') + + logger.log_message(f"Deep analyzer check - session: {session_id}, current_user: {user_id}, analyzer_user: {analyzer_user_id}, has_analyzer: {current_analyzer is not None}", level=logging.INFO) + + if (not current_analyzer or + analyzer_user_id != user_id or + not hasattr(session_state, 'deep_analyzer')): + + logger.log_message(f"Creating/recreating deep analyzer for session {session_id}, user_id: {user_id} (reason: analyzer_exists={current_analyzer is not 
None}, user_match={analyzer_user_id == user_id})", level=logging.INFO) + + # Load user-enabled agents from database using preference system + from src.db.init_db import session_factory + from src.agents.agents import load_user_enabled_templates_for_planner_from_db + + db_session = session_factory() + try: + # Load user-enabled agents for planner (respects preferences) + if user_id: + enabled_agents_dict = load_user_enabled_templates_for_planner_from_db(user_id, db_session) + logger.log_message(f"Deep analyzer loaded {len(enabled_agents_dict)} enabled agents for user {user_id}: {list(enabled_agents_dict.keys())}", level=logging.INFO) + + if not enabled_agents_dict: + logger.log_message(f"WARNING: No enabled agents found for user {user_id}, falling back to defaults", level=logging.WARNING) + # Fallback to default agents if no enabled agents + from src.agents.agents import preprocessing_agent, statistical_analytics_agent, sk_learn_agent, data_viz_agent + enabled_agents_dict = { + "preprocessing_agent": preprocessing_agent, + "statistical_analytics_agent": statistical_analytics_agent, + "sk_learn_agent": sk_learn_agent, + "data_viz_agent": data_viz_agent + } + else: + # Fallback to default agents if no user_id + logger.log_message("No user_id in session, loading default agents for deep analysis", level=logging.WARNING) + from src.agents.agents import preprocessing_agent, statistical_analytics_agent, sk_learn_agent, data_viz_agent + enabled_agents_dict = { + "preprocessing_agent": preprocessing_agent, + "statistical_analytics_agent": statistical_analytics_agent, + "sk_learn_agent": sk_learn_agent, + "data_viz_agent": data_viz_agent + } + + # Create agents dictionary for deep analysis using enabled agents + deep_agents = {} + deep_agents_desc = {} + + for agent_name, signature in enabled_agents_dict.items(): + deep_agents[agent_name] = dspy.asyncify(dspy.ChainOfThought(signature)) + # Get agent description from database + deep_agents_desc[agent_name] = 
get_agent_description(agent_name) + + logger.log_message(f"Deep analyzer initialized with {len(deep_agents)} agents: {list(deep_agents.keys())}", level=logging.INFO) + + except Exception as e: + logger.log_message(f"Error loading agents for deep analysis: {str(e)}", level=logging.ERROR) + # Fallback to minimal set + from src.agents.agents import preprocessing_agent, statistical_analytics_agent, sk_learn_agent, data_viz_agent + deep_agents = { + "preprocessing_agent": dspy.asyncify(dspy.ChainOfThought(preprocessing_agent)), + "statistical_analytics_agent": dspy.asyncify(dspy.ChainOfThought(statistical_analytics_agent)), + "sk_learn_agent": dspy.asyncify(dspy.ChainOfThought(sk_learn_agent)), + "data_viz_agent": dspy.asyncify(dspy.ChainOfThought(data_viz_agent)) + } + deep_agents_desc = {name: get_agent_description(name) for name in deep_agents.keys()} + logger.log_message(f"Using fallback agents: {list(deep_agents.keys())}", level=logging.WARNING) + finally: + db_session.close() - deep_agents_desc = PLANNER_AGENTS_WITH_DESCRIPTION session_state['deep_analyzer'] = deep_analysis_module(agents=deep_agents, agents_desc=deep_agents_desc) + session_state['deep_analyzer_user_id'] = user_id # Track which user this analyzer was created for + else: + logger.log_message(f"Using existing deep analyzer for session {session_id}, user_id: {user_id}", level=logging.INFO) return session_state['deep_analyzer'] @@ -344,6 +397,7 @@ def get_deep_analyzer(self, session_id: str): app = FastAPI(title="AI Analytics API", version="1.0") app.state = AppState() + # Configure middleware # Use a wildcard for local development or read from environment is_development = os.getenv("ENVIRONMENT", "development").lower() == "development" @@ -412,25 +466,34 @@ async def chat_with_agent( session_id: str = Depends(get_session_id_dependency) ): session_state = app.state.get_session_state(session_id) + logger.log_message(f"[DEBUG] chat_with_agent called with agent: '{agent_name}', query: 
'{request.query[:100]}...'", level=logging.DEBUG) try: # Extract and validate query parameters + logger.log_message(f"[DEBUG] Updating session from query params", level=logging.DEBUG) _update_session_from_query_params(request_obj, session_state) + logger.log_message(f"[DEBUG] Session state after query params: user_id={session_state.get('user_id')}, chat_id={session_state.get('chat_id')}", level=logging.DEBUG) # Validate dataset and agent name if session_state["current_df"] is None: + logger.log_message(f"[DEBUG] No dataset loaded", level=logging.DEBUG) raise HTTPException(status_code=400, detail=RESPONSE_ERROR_NO_DATASET) + logger.log_message(f"[DEBUG] About to validate agent name: '{agent_name}'", level=logging.DEBUG) _validate_agent_name(agent_name, session_state) + logger.log_message(f"[DEBUG] Agent validation completed successfully", level=logging.DEBUG) # Record start time for timing start_time = time.time() # Get chat context and prepare query + logger.log_message(f"[DEBUG] Preparing query with context", level=logging.DEBUG) enhanced_query = _prepare_query_with_context(request.query, session_state) + logger.log_message(f"[DEBUG] Enhanced query length: {len(enhanced_query)}", level=logging.DEBUG) # Initialize agent - handle standard, template, and custom agents if "," in agent_name: + logger.log_message(f"[DEBUG] Processing multiple agents: {agent_name}", level=logging.DEBUG) # Multiple agents case agent_list = [agent.strip() for agent in agent_name.split(",")] @@ -439,83 +502,84 @@ async def chat_with_agent( template_agents = [agent for agent in agent_list if _is_template_agent(agent)] custom_agents = [agent for agent in agent_list if not _is_standard_agent(agent) and not _is_template_agent(agent)] + logger.log_message(f"[DEBUG] Agent categorization - standard: {standard_agents}, template: {template_agents}, custom: {custom_agents}", level=logging.DEBUG) if custom_agents: # If any custom agents, use session AI system for all ai_system = 
session_state["ai_system"] session_lm = get_session_lm(session_state) + logger.log_message(f"[DEBUG] Using custom agent execution path", level=logging.DEBUG) with dspy.context(lm=session_lm): response = await asyncio.wait_for( _execute_custom_agents(ai_system, agent_list, enhanced_query), timeout=REQUEST_TIMEOUT_SECONDS ) + logger.log_message(f"[DEBUG] Custom agents response type: {type(response)}, keys: {list(response.keys()) if isinstance(response, dict) else 'not a dict'}", level=logging.DEBUG) else: - # All standard/template agents - use auto_analyst_ind - standard_agent_sigs = [AVAILABLE_AGENTS[agent] for agent in standard_agents] + # All standard/template agents - use auto_analyst_ind which loads from DB user_id = session_state.get("user_id") + logger.log_message(f"[DEBUG] Using auto_analyst_ind for multiple standard/template agents with user_id: {user_id}", level=logging.DEBUG) - # Create database session for template loading + # Create database session for agent loading from src.db.init_db import session_factory db_session = session_factory() try: - agent = auto_analyst_ind(agents=standard_agent_sigs, retrievers=session_state["retrievers"], user_id=user_id, db_session=db_session) + # auto_analyst_ind will load all agents from database + logger.log_message(f"[DEBUG] Creating auto_analyst_ind instance", level=logging.DEBUG) + agent = auto_analyst_ind(agents=[], retrievers=session_state["retrievers"], user_id=user_id, db_session=db_session) session_lm = get_session_lm(session_state) + logger.log_message(f"[DEBUG] About to call agent.forward with query and agent list", level=logging.DEBUG) with dspy.context(lm=session_lm): response = await asyncio.wait_for( agent.forward(enhanced_query, ",".join(agent_list)), timeout=REQUEST_TIMEOUT_SECONDS ) + logger.log_message(f"[DEBUG] auto_analyst_ind response type: {type(response)}, content: {str(response)[:200]}...", level=logging.DEBUG) finally: db_session.close() else: + logger.log_message(f"[DEBUG] Processing single 
agent: {agent_name}", level=logging.DEBUG) # Single agent case - if _is_standard_agent(agent_name): - # Standard agent - use auto_analyst_ind - user_id = session_state.get("user_id") - - # Create database session for template loading - from src.db.init_db import session_factory - db_session = session_factory() - try: - agent = auto_analyst_ind(agents=[AVAILABLE_AGENTS[agent_name]], retrievers=session_state["retrievers"], user_id=user_id, db_session=db_session) - session_lm = get_session_lm(session_state) - with dspy.context(lm=session_lm): - response = await asyncio.wait_for( - agent.forward(enhanced_query, agent_name), - timeout=REQUEST_TIMEOUT_SECONDS - ) - finally: - db_session.close() - elif _is_template_agent(agent_name): - # Template agent - use auto_analyst_ind with empty agents list (templates loaded in init) + if _is_standard_agent(agent_name) or _is_template_agent(agent_name): + # Standard or template agent - use auto_analyst_ind which loads from DB user_id = session_state.get("user_id") + logger.log_message(f"[DEBUG] Using auto_analyst_ind for single standard/template agent '{agent_name}' with user_id: {user_id}", level=logging.DEBUG) - # Create database session for template loading + # Create database session for agent loading from src.db.init_db import session_factory db_session = session_factory() try: + # auto_analyst_ind will load all agents from database + logger.log_message(f"[DEBUG] Creating auto_analyst_ind instance for single agent", level=logging.DEBUG) agent = auto_analyst_ind(agents=[], retrievers=session_state["retrievers"], user_id=user_id, db_session=db_session) session_lm = get_session_lm(session_state) + logger.log_message(f"[DEBUG] About to call agent.forward for single agent '{agent_name}'", level=logging.DEBUG) with dspy.context(lm=session_lm): response = await asyncio.wait_for( agent.forward(enhanced_query, agent_name), timeout=REQUEST_TIMEOUT_SECONDS ) + logger.log_message(f"[DEBUG] Single agent response type: {type(response)}, 
content: {str(response)[:200]}...", level=logging.DEBUG) finally: db_session.close() else: # Custom agent - use session AI system ai_system = session_state["ai_system"] session_lm = get_session_lm(session_state) + logger.log_message(f"[DEBUG] Using custom agent execution for '{agent_name}'", level=logging.DEBUG) with dspy.context(lm=session_lm): response = await asyncio.wait_for( _execute_custom_agents(ai_system, [agent_name], enhanced_query), timeout=REQUEST_TIMEOUT_SECONDS ) + logger.log_message(f"[DEBUG] Custom single agent response type: {type(response)}, content: {str(response)[:200]}...", level=logging.DEBUG) + logger.log_message(f"[DEBUG] About to format response to markdown. Response type: {type(response)}", level=logging.DEBUG) formatted_response = format_response_to_markdown(response, agent_name, session_state["current_df"]) + logger.log_message(f"[DEBUG] Formatted response type: {type(formatted_response)}, length: {len(str(formatted_response))}", level=logging.DEBUG) if formatted_response == RESPONSE_ERROR_INVALID_QUERY: + logger.log_message(f"[DEBUG] Response was invalid query error", level=logging.DEBUG) return { "agent_name": agent_name, "query": request.query, @@ -525,6 +589,7 @@ async def chat_with_agent( # Track usage statistics if session_state.get("user_id"): + logger.log_message(f"[DEBUG] Tracking model usage", level=logging.DEBUG) _track_model_usage( session_state=session_state, enhanced_query=enhanced_query, @@ -532,6 +597,7 @@ async def chat_with_agent( processing_time_ms=int((time.time() - start_time) * 1000) ) + logger.log_message(f"[DEBUG] chat_with_agent completed successfully", level=logging.DEBUG) return { "agent_name": agent_name, "query": request.query, # Return original query without context @@ -540,13 +606,19 @@ async def chat_with_agent( } except HTTPException: # Re-raise HTTP exceptions to preserve status codes + logger.log_message(f"[DEBUG] HTTPException caught and re-raised", level=logging.DEBUG) raise except 
asyncio.TimeoutError: + logger.log_message(f"[ERROR] Timeout error in chat_with_agent", level=logging.ERROR) raise HTTPException(status_code=504, detail="Request timed out. Please try a simpler query.") except Exception as e: + logger.log_message(f"[ERROR] Unexpected error in chat_with_agent: {str(e)}", level=logging.ERROR) + logger.log_message(f"[ERROR] Exception type: {type(e)}, traceback: {str(e)}", level=logging.ERROR) + import traceback + logger.log_message(f"[ERROR] Full traceback: {traceback.format_exc()}", level=logging.ERROR) raise HTTPException(status_code=500, detail="An unexpected error occurred. Please try again later.") - - + + @app.post("/chat", response_model=dict) async def chat_with_all( request: QueryRequest, @@ -614,48 +686,48 @@ def _update_session_from_query_params(request_obj: Request, session_state: dict) def _validate_agent_name(agent_name: str, session_state: dict = None): - """Validate that the requested agent(s) exist in either standard agents or user's custom agents""" + """Validate that the agent name(s) are available""" + logger.log_message(f"[DEBUG] Validating agent name: '{agent_name}'", level=logging.DEBUG) + if "," in agent_name: + # Multiple agents agent_list = [agent.strip() for agent in agent_name.split(",")] + logger.log_message(f"[DEBUG] Multiple agents detected: {agent_list}", level=logging.DEBUG) for agent in agent_list: - if not _is_agent_available(agent, session_state): + is_available = _is_agent_available(agent, session_state) + logger.log_message(f"[DEBUG] Agent '{agent}' availability: {is_available}", level=logging.DEBUG) + if not is_available: available_agents = _get_available_agents_list(session_state) + logger.log_message(f"[DEBUG] Agent '{agent}' not found. Available: {available_agents}", level=logging.DEBUG) raise HTTPException( - status_code=404, + status_code=400, detail=f"Agent '{agent}' not found. 
Available agents: {available_agents}" ) - elif not _is_agent_available(agent_name, session_state): - available_agents = _get_available_agents_list(session_state) - raise HTTPException( - status_code=404, - detail=f"Agent '{agent_name}' not found. Available agents: {available_agents}" - ) + else: + # Single agent + is_available = _is_agent_available(agent_name, session_state) + logger.log_message(f"[DEBUG] Single agent '{agent_name}' availability: {is_available}", level=logging.DEBUG) + if not is_available: + available_agents = _get_available_agents_list(session_state) + logger.log_message(f"[DEBUG] Agent '{agent_name}' not found. Available: {available_agents}", level=logging.DEBUG) + raise HTTPException( + status_code=400, + detail=f"Agent '{agent_name}' not found. Available agents: {available_agents}" + ) + + logger.log_message(f"[DEBUG] Agent validation passed for: '{agent_name}'", level=logging.DEBUG) def _is_agent_available(agent_name: str, session_state: dict = None) -> bool: - """Check if agent is available in either standard agents, template agents, or user's custom agents""" - # Check standard agents - if agent_name in AVAILABLE_AGENTS: + """Check if an agent is available (standard, template, or custom)""" + # Check if it's a standard agent + if _is_standard_agent(agent_name): return True - # Check template agents - try: - from src.db.init_db import session_factory - from src.db.schemas.models import AgentTemplate - - db_session = session_factory() - try: - template = db_session.query(AgentTemplate).filter( - AgentTemplate.template_name == agent_name, - AgentTemplate.is_active == True - ).first() - if template: - return True - finally: - db_session.close() - except Exception as e: - logger.log_message(f"Error checking template availability for {agent_name}: {str(e)}", level=logging.ERROR) + # Check if it's a template agent + if _is_template_agent(agent_name): + return True - # Check custom agents if session has an AI system with custom agents + # Check if 
it's a custom agent in session if session_state and "ai_system" in session_state: ai_system = session_state["ai_system"] if hasattr(ai_system, 'agents') and agent_name in ai_system.agents: @@ -664,22 +736,32 @@ def _is_agent_available(agent_name: str, session_state: dict = None) -> bool: return False def _get_available_agents_list(session_state: dict = None) -> list: - """Get list of all available agents (standard + custom)""" - available = list(AVAILABLE_AGENTS.keys()) + """Get list of all available agents from database""" + from src.db.init_db import session_factory + from src.agents.agents import load_all_available_templates_from_db - # Add custom agents if available - if session_state and "ai_system" in session_state: - ai_system = session_state["ai_system"] - if hasattr(ai_system, 'agents'): - custom_agents = [name for name in ai_system.agents.keys() - if name not in AVAILABLE_AGENTS and name != 'basic_qa_agent'] - available.extend(custom_agents) + # Core agents (always available) + available = ["preprocessing_agent", "statistical_analytics_agent", "sk_learn_agent", "data_viz_agent"] + + # Add template agents from database + db_session = session_factory() + try: + template_agents_dict = load_all_available_templates_from_db(db_session) + # template_agents_dict is a dict with template_name as keys + template_names = [template_name for template_name in template_agents_dict.keys() + if template_name not in available and template_name != 'basic_qa_agent'] + available.extend(template_names) + except Exception as e: + logger.log_message(f"Error loading template agents: {str(e)}", level=logging.ERROR) + finally: + db_session.close() return available def _is_standard_agent(agent_name: str) -> bool: - """Check if agent is a standard agent (not custom or template)""" - return agent_name in AVAILABLE_AGENTS + """Check if agent is one of the 4 core standard agents""" + standard_agents = ["preprocessing_agent", "statistical_analytics_agent", "sk_learn_agent", 
"data_viz_agent"] + return agent_name in standard_agents def _is_template_agent(agent_name: str) -> bool: """Check if agent is a template agent""" @@ -999,66 +1081,54 @@ async def _execute_plan_with_timeout(ai_system, enhanced_query, plan_response): # Add an endpoint to list available agents @app.get("/agents", response_model=dict) async def list_agents(request: Request, session_id: str = Depends(get_session_id_dependency)): + """Get all available agents (standard, template, and custom)""" session_state = app.state.get_session_state(session_id) - # Check if user_id is provided in query params to associate with session - user_id_param = request.query_params.get("user_id") - if user_id_param: - try: - user_id = int(user_id_param) - # Associate the user with this session to load custom agents - app.state.set_session_user(session_id, user_id) - # Refresh session state after user association - session_state = app.state.get_session_state(session_id) - except (ValueError, TypeError): - logger.log_message(f"Invalid user_id in agents endpoint: {user_id_param}", level=logging.WARNING) - - # Get user-specific agent list including custom agents - available_agents_list = _get_available_agents_list(session_state) - standard_agents = list(AVAILABLE_AGENTS.keys()) - planner_agents = list(PLANNER_AGENTS.keys()) - - # Get template agents from database - template_agents = [] try: + # Get all available agents from database and session + available_agents_list = _get_available_agents_list(session_state) + + # Categorize agents + standard_agents = ["preprocessing_agent", "statistical_analytics_agent", "sk_learn_agent", "data_viz_agent"] + + # Get template agents from database from src.db.init_db import session_factory - from src.db.schemas.models import AgentTemplate + from src.agents.agents import load_all_available_templates_from_db db_session = session_factory() try: - templates = db_session.query(AgentTemplate).filter( - AgentTemplate.is_active == True - ).all() - - template_agents = 
[template.template_name for template in templates] - logger.log_message(f"Found {len(template_agents)} template agents", level=logging.DEBUG) - + template_agents_dict = load_all_available_templates_from_db(db_session) + # template_agents_dict is a dict with template_name as keys + template_agents = [template_name for template_name in template_agents_dict.keys() + if template_name not in standard_agents and template_name != 'basic_qa_agent'] + except Exception as e: + logger.log_message(f"Error loading template agents in /agents endpoint: {str(e)}", level=logging.ERROR) + template_agents = [] finally: db_session.close() + + # Get custom agents from session + custom_agents = [] + if session_state and "ai_system" in session_state: + ai_system = session_state["ai_system"] + if hasattr(ai_system, 'agents'): + custom_agents = [agent for agent in available_agents_list + if agent not in standard_agents and agent not in template_agents] + + # Ensure template agents are in the available list + for template_agent in template_agents: + if template_agent not in available_agents_list: + available_agents_list.append(template_agent) + + return { + "available_agents": available_agents_list, + "standard_agents": standard_agents, + "template_agents": template_agents, + "custom_agents": custom_agents + } except Exception as e: - logger.log_message(f"Error fetching template agents: {str(e)}", level=logging.ERROR) - - # Custom agents are user-created agents (not standard, not planner, not template) - custom_agents = [agent for agent in available_agents_list - if agent not in standard_agents and agent not in planner_agents and agent not in template_agents] - - # Add template agents to available agents list if they're not already there - for template_agent in template_agents: - if template_agent not in available_agents_list: - available_agents_list.append(template_agent) - - return { - "available_agents": available_agents_list, - "standard_agents": standard_agents, - "custom_agents": 
custom_agents, - "template_agents": template_agents, - "planner_agents": planner_agents, - "deep_analysis": { - "available": True, - "description": "Comprehensive multi-step analysis with automated planning" - }, - "description": "List of available specialized agents that can be called using @agent_name" - } + logger.log_message(f"Error getting agents list: {str(e)}", level=logging.ERROR) + raise HTTPException(status_code=500, detail=f"Error getting agents list: {str(e)}") @app.get("/health", response_model=dict) async def health(): @@ -1162,7 +1232,7 @@ async def deep_analysis_streaming( session_lm = dspy.LM(model="anthropic/claude-sonnet-4-20250514", max_tokens=7000, temperature=0.5) return StreamingResponse( - _generate_deep_analysis_stream(session_state, request.goal, session_lm), + _generate_deep_analysis_stream(session_state, request.goal, session_lm, session_id), media_type='text/event-stream', headers={ 'Cache-Control': 'no-cache', @@ -1179,7 +1249,7 @@ async def deep_analysis_streaming( logger.log_message(f"Streaming deep analysis failed: {str(e)}", level=logging.ERROR) raise HTTPException(status_code=500, detail=f"Streaming deep analysis failed: {str(e)}") -async def _generate_deep_analysis_stream(session_state: dict, goal: str, session_lm): +async def _generate_deep_analysis_stream(session_state: dict, goal: str, session_lm, session_id: str): """Generate streaming responses for deep analysis""" # Track the start time for duration calculation start_time = datetime.now(UTC) @@ -1285,8 +1355,9 @@ async def update_report_in_db(status, progress, step=None, content=None): # Update DB status to running await update_report_in_db("running", 5) - # Get deep analyzer - deep_analyzer = app.state.get_deep_analyzer(session_state.get("session_id", "default")) + # Get deep analyzer - use the correct session_id from the session_state + logger.log_message(f"Getting deep analyzer for session_id: {session_id}, user_id: {user_id}", level=logging.INFO) + deep_analyzer = 
app.state.get_deep_analyzer(session_id) # Make the dataset available globally for code execution globals()['df'] = df @@ -1510,6 +1581,68 @@ async def download_html_report( logger.log_message(f"Failed to generate HTML report: {str(e)}", level=logging.ERROR) raise HTTPException(status_code=500, detail=f"Failed to generate report: {str(e)}") +@app.get("/debug/deep_analysis_agents") +async def debug_deep_analysis_agents(session_id: str = Depends(get_session_id_dependency)): + """Debug endpoint to show which agents are loaded for deep analysis""" + session_state = app.state.get_session_state(session_id) + user_id = session_state.get("user_id") + + try: + # Get the deep analyzer for this session + deep_analyzer = app.state.get_deep_analyzer(session_id) + + # Get the agents from the deep analyzer + available_agents = list(deep_analyzer.agents.keys()) if hasattr(deep_analyzer, 'agents') else [] + + # Also get the raw enabled agents from database + from src.db.init_db import session_factory + from src.agents.agents import load_user_enabled_templates_for_planner_from_db + + db_session = session_factory() + try: + if user_id: + enabled_agents_dict = load_user_enabled_templates_for_planner_from_db(user_id, db_session) + db_enabled_agents = list(enabled_agents_dict.keys()) + else: + db_enabled_agents = ["No user_id - using defaults"] + finally: + db_session.close() + + return { + "session_id": session_id, + "user_id": user_id, + "deep_analyzer_agents": available_agents, + "db_enabled_agents": db_enabled_agents, + "agents_match": set(available_agents) == set(db_enabled_agents) if user_id else "N/A" + } + + except Exception as e: + logger.log_message(f"Error in debug endpoint: {str(e)}", level=logging.ERROR) + return { + "error": str(e), + "session_id": session_id, + "user_id": user_id + } + +@app.post("/debug/clear_deep_analyzer") +async def clear_deep_analyzer_cache(session_id: str = Depends(get_session_id_dependency)): + """Debug endpoint to clear the deep analyzer cache and 
force reload""" + session_state = app.state.get_session_state(session_id) + + # Clear the cached deep analyzer + if 'deep_analyzer' in session_state: + del session_state['deep_analyzer'] + if 'deep_analyzer_user_id' in session_state: + del session_state['deep_analyzer_user_id'] + + logger.log_message(f"Cleared deep analyzer cache for session {session_id}", level=logging.INFO) + + return { + "message": "Deep analyzer cache cleared", + "session_id": session_id, + "user_id": session_state.get("user_id") + } + # In the section where routers are included, add the session_router app.include_router(chat_router) app.include_router(analytics_router) diff --git a/auto-analyst-backend/chat_database.db b/auto-analyst-backend/chat_database.db index c8fe109a..e99a3734 100644 --- a/auto-analyst-backend/chat_database.db +++ b/auto-analyst-backend/chat_database.db @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5182504f09b433865500dea5844a934662270e3aba11fae15c58df19855b2ba1 -size 81920 +oid sha256:061a038faca46d439d7178af5bc140ee7a523d4ef00cf57ca75d2d9e708fcdb2 +size 69632 diff --git a/auto-analyst-backend/entrypoint.sh b/auto-analyst-backend/entrypoint.sh index 8e82044f..2ce6c9ac 100644 --- a/auto-analyst-backend/entrypoint.sh +++ b/auto-analyst-backend/entrypoint.sh @@ -70,6 +70,59 @@ except Exception as e: # Don't exit on database connectivity issues - let app try to start } +# Function to populate agents and templates for development (SQLite only) +populate_agents_templates() { + echo "πŸ”§ Checking if agents/templates need to be populated..." 
+ python -c " +try: + from src.db.init_db import DATABASE_URL + from src.db.schemas.models import AgentTemplate + from src.db.init_db import session_factory + + # Check database type + if DATABASE_URL.startswith('sqlite'): + print('πŸ” SQLite database detected - checking template population') + + session = session_factory() + try: + template_count = session.query(AgentTemplate).count() + + if template_count == 0: + print('πŸ“‹ No templates found - populating agents and templates...') + session.close() + exit(1) # Signal that population is needed + else: + print(f'βœ… Found {template_count} templates - population not needed') + session.close() + exit(0) # Signal that population is not needed + except Exception as e: + print(f'⚠️ Error checking templates: {e}') + print('πŸ“‹ Will attempt to populate anyway') + session.close() + exit(1) # Signal that population is needed + else: + print('πŸ” PostgreSQL/RDS detected - skipping auto-population') + exit(0) # Signal that population is not needed + +except Exception as e: + print(f'❌ Error during template check: {e}') + exit(0) # Don't fail startup, just skip population +" + + # Check if population is needed (exit code 1 means yes) + if [ $? -eq 1 ]; then + echo "πŸš€ Running agent/template population for SQLite..." + python scripts/populate_agent_templates.py auto + + if [ $? -eq 0 ]; then + echo "βœ… Agent/template population completed successfully" + else + echo "⚠️ Agent/template population had issues, but continuing..." + echo "πŸ“‹ You may need to populate templates manually" + fi + fi +} + # Main startup sequence echo "πŸ”§ Initializing production environment..." @@ -82,6 +135,9 @@ init_production_database # Test database connectivity (non-failing) verify_database_connectivity +# Populate agents and templates for development (SQLite only) +populate_agents_templates + echo "🎯 Starting FastAPI application..." 
echo "🌐 Application will be available on port 7860" diff --git a/auto-analyst-backend/scripts/populate_agent_templates.py b/auto-analyst-backend/scripts/populate_agent_templates.py index 97894af3..d1d7b87d 100644 --- a/auto-analyst-backend/scripts/populate_agent_templates.py +++ b/auto-analyst-backend/scripts/populate_agent_templates.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 """ -Script to populate agent templates. -These templates are available to all users but usable only by paid users. +Enhanced Script to populate agent templates for development. +Includes both default agents (free) and premium templates. +Automatically detects database type and populates accordingly. """ import sys @@ -11,13 +12,200 @@ # Add the project root to the Python path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from src.db.init_db import session_factory +from src.db.init_db import session_factory, DATABASE_URL from src.db.schemas.models import AgentTemplate from sqlalchemy.exc import IntegrityError -# Template agent definitions -AGENT_TEMPLATES = { - "Visualization": [ +def get_database_type(): + """Detect database type from DATABASE_URL""" + if DATABASE_URL.startswith('postgresql'): + return "postgresql" + elif DATABASE_URL.startswith('sqlite'): + return "sqlite" + else: + return "unknown" + +# Default agents (free for all users) +DEFAULT_AGENTS = { + "Data Manipulation": [ + { + "template_name": "preprocessing_agent", + "display_name": "Data Preprocessing Agent", + "description": "Cleans and prepares a DataFrame using Pandas and NumPyβ€”handles missing values, detects column types, and converts date strings to datetime.", + "icon_url": "/icons/templates/pandas.svg", + "prompt_template": """You are a AI data-preprocessing agent. Generate clean and efficient Python code using NumPy and Pandas to perform introductory data preprocessing on a pre-loaded DataFrame df, based on the user's analysis goals. +Preprocessing Requirements: +1. 
Identify Column Types +- Separate columns into numeric and categorical using: + categorical_columns = df.select_dtypes(include=[object, 'category']).columns.tolist() + numeric_columns = df.select_dtypes(include=[np.number]).columns.tolist() +2. Handle Missing Values +- Numeric columns: Impute missing values using the mean of each column +- Categorical columns: Impute missing values using the mode of each column +3. Convert Date Strings to Datetime +- For any column suspected to represent dates (in string format), convert it to datetime using: + def safe_to_datetime(date): + try: + return pd.to_datetime(date, errors='coerce', cache=False) + except (ValueError, TypeError): + return pd.NaT + df['datetime_column'] = df['datetime_column'].apply(safe_to_datetime) +- Replace 'datetime_column' with the actual column names containing date-like strings +Important Notes: +- Do NOT create a correlation matrix β€” correlation analysis is outside the scope of preprocessing +- Do NOT generate any plots or visualizations +Output Instructions: +1. Include the full preprocessing Python code +2. Provide a brief bullet-point summary of the steps performed. Example: +β€’ Identified 5 numeric and 4 categorical columns +β€’ Filled missing numeric values with column means +β€’ Filled missing categorical values with column modes +β€’ Converted 1 date column to datetime format + Respond in the user's language for all summary and reasoning but keep the code in english""" + } + ], + "Data Modelling": [ + { + "template_name": "statistical_analytics_agent", + "display_name": "Statistical Analytics Agent", + "description": "Performs statistical analysis (e.g., regression, seasonal decomposition) using statsmodels, with proper handling of categorical data and missing values.", + "icon_url": "/icons/templates/statsmodels.svg", + "prompt_template": """You are a statistical analytics agent. 
Your task is to take a dataset and a user-defined goal and output Python code that performs the appropriate statistical analysis to achieve that goal. Follow these guidelines: +IMPORTANT: You may be provided with previous interaction history. The section marked "### Current Query:" contains the user's current request. Any text in "### Previous Interaction History:" is for context only and is NOT part of the current request. +Data Handling: +Always handle strings as categorical variables in a regression using statsmodels C(string_column). +Do not change the index of the DataFrame. +Convert X and y into float when fitting a model. +Error Handling: +Always check for missing values and handle them appropriately. +Ensure that categorical variables are correctly processed. +Provide clear error messages if the model fitting fails. +Regression: +For regression, use statsmodels and ensure that a constant term is added to the predictor using sm.add_constant(X). +Handle categorical variables using C(column_name) in the model formula. +Fit the model with model = sm.OLS(y.astype(float), X.astype(float)).fit(). +Seasonal Decomposition: +Ensure the period is set correctly when performing seasonal decomposition. +Verify the number of observations works for the decomposition. +Output: +Ensure the code is executable and as intended. +Also choose the correct type of model for the problem +Avoid adding data visualization code. 
+Use code like this to prevent failing: +import pandas as pd +import numpy as np +import statsmodels.api as sm +def statistical_model(X, y, goal, period=None): + try: + # Check for missing values and handle them + X = X.dropna() + y = y.loc[X.index].dropna() + # Ensure X and y are aligned + X = X.loc[y.index] + # Convert categorical variables + for col in X.select_dtypes(include=['object', 'category']).columns: + X[col] = X[col].astype('category') + # Add a constant term to the predictor + X = sm.add_constant(X) + # Fit the model + if goal == 'regression': + # Handle categorical variables in the model formula + formula = 'y ~ ' + ' + '.join([f'C({col})' if X[col].dtype.name == 'category' else col for col in X.columns]) + model = sm.OLS(y.astype(float), X.astype(float)).fit() + return model.summary() + elif goal == 'seasonal_decompose': + if period is None: + raise ValueError("Period must be specified for seasonal decomposition") + decomposition = sm.tsa.seasonal_decompose(y, period=period) + return decomposition + else: + raise ValueError("Unknown goal specified. Please provide a valid goal.") + except Exception as e: + return f"An error occurred: {e}" +# Example usage: +result = statistical_model(X, y, goal='regression') +print(result) +If visualizing use plotly +Provide a concise bullet-point summary of the statistical analysis performed.
+ +Example Summary: +β€’ Applied linear regression with OLS to predict house prices based on 5 features +β€’ Model achieved R-squared of 0.78 +β€’ Significant predictors include square footage (p<0.001) and number of bathrooms (p<0.01) +β€’ Detected strong seasonal pattern with 12-month periodicity +β€’ Forecast shows 15% growth trend over next quarter +Respond in the user's language for all summary and reasoning but keep the code in english""" + }, + { + "template_name": "sk_learn_agent", + "display_name": "Machine Learning Agent", + "description": "Trains and evaluates machine learning models using scikit-learn, including classification, regression, and clustering with feature importance insights.", + "icon_url": "/icons/templates/scikit-learn.svg", + "prompt_template": """You are a machine learning agent. +Your task is to take a dataset and a user-defined goal, and output Python code that performs the appropriate machine learning analysis to achieve that goal. +You should use the scikit-learn library. +IMPORTANT: You may be provided with previous interaction history. The section marked "### Current Query:" contains the user's current request. Any text in "### Previous Interaction History:" is for context only and is NOT part of the current request. +Make sure your output is as intended! +Provide a concise bullet-point summary of the machine learning operations performed. 
+ +Example Summary: +β€’ Trained a Random Forest classifier on customer churn data with 80/20 train-test split +β€’ Model achieved 92% accuracy and 88% F1-score +β€’ Feature importance analysis revealed that contract length and monthly charges are the strongest predictors of churn +β€’ Implemented K-means clustering (k=4) on customer shopping behaviors +β€’ Identified distinct segments: high-value frequent shoppers (22%), occasional big spenders (35%), budget-conscious regulars (28%), and rare visitors (15%) +Respond in the user's language for all summary and reasoning but keep the code in english""" + } + ], + "Data Visualization": [ + { + "template_name": "data_viz_agent", + "display_name": "Data Visualization Agent", + "description": "Generates interactive visualizations with Plotly, selecting the best chart type to reveal trends, comparisons, and insights based on the analysis goal.", + "icon_url": "/icons/templates/plotly.svg", + "prompt_template": """You are an AI agent responsible for generating interactive data visualizations using Plotly. +IMPORTANT Instructions: +- The section marked "### Current Query:" contains the user's request. Any text in "### Previous Interaction History:" is for context only and should NOT be treated as part of the current request. +- You must only use the tools provided to you. This agent handles visualization only. +- If len(df) > 50000, always sample the dataset before visualization using: +if len(df) > 50000: + df = df.sample(50000, random_state=1) +- Each visualization must be generated as a **separate figure** using go.Figure(). +Do NOT use subplots under any circumstances. +- Each figure must be returned individually using: +fig.to_html(full_html=False) +- Use update_layout with xaxis and yaxis **only once per figure**. 
+- Enhance readability and clarity by: +β€’ Using low opacity (0.4-0.7) where appropriate +β€’ Applying visually distinct colors for different elements or categories +- Make sure the visual **answers the user's specific goal**: +β€’ Identify what insight or comparison the user is trying to achieve +β€’ Choose the visualization type and features (e.g., color, size, grouping) to emphasize that goal +β€’ For example, if the user asks for "trends in revenue," use a time series line chart; if they ask for "top-performing categories," use a bar chart sorted by value +β€’ Prioritize highlighting patterns, outliers, or comparisons relevant to the question +- Never include the dataset or styling index in the output. +- If there are no relevant columns for the requested visualization, respond with: +"No relevant columns found to generate this visualization." +- Use only one number format consistently: either 'K', 'M', or comma-separated values like 1,000/1,000,000. Do not mix formats. +- Only include trendlines in scatter plots if the user explicitly asks for them. +- Output only the code and a concise bullet-point summary of what the visualization reveals. +- Always end each visualization with: +fig.to_html(full_html=False) +Respond in the user's language for all summary and reasoning but keep the code in english +Example Summary: +β€’ Created an interactive scatter plot of sales vs. 
marketing spend with color-coded product categories +β€’ Included a trend line showing positive correlation (r=0.72) +β€’ Highlighted outliers where high marketing spend resulted in low sales +β€’ Generated a time series chart of monthly revenue from 2020-2023 +β€’ Added annotations for key business events +β€’ Visualization reveals 35% YoY growth with seasonal peaks in Q4""" + } + ] +} + +# Premium template agent definitions +PREMIUM_TEMPLATES = { + "Data Visualization": [ { "template_name": "matplotlib_agent", "display_name": "Matplotlib Visualization Agent", @@ -27,7 +215,7 @@ You are a matplotlib/seaborn visualization expert. Your task is to create high-quality static visualizations using matplotlib and seaborn libraries. IMPORTANT Instructions: -- You must only use matplotlib, seaborn, and numpy/polars for visualizations +- You must only use matplotlib, seaborn, and numpy/pandas for visualizations - Always use plt.style.use('seaborn-v0_8') or a clean style for better aesthetics - Include proper titles, axis labels, and legends - Use appropriate color palettes and consider accessibility @@ -133,56 +321,106 @@ ] } -def populate_templates(): - """Populate the database with agent templates.""" +def populate_agents_and_templates(include_defaults=True, include_premiums=True): + """Populate the database with default agents and premium templates.""" session = session_factory() + db_type = get_database_type() try: # Track statistics - created_count = 0 + default_created = 0 + premium_created = 0 skipped_count = 0 - for category, templates in AGENT_TEMPLATES.items(): - print(f"\n--- Processing {category} Templates ---") - - for template_data in templates: - template_name = template_data["template_name"] - - # Check if template already exists - existing = session.query(AgentTemplate).filter( - AgentTemplate.template_name == template_name - ).first() - - if existing: - print(f"⏭️ Skipping {template_name} (already exists)") - skipped_count += 1 - continue + print(f"πŸ” 
Detected {db_type.upper()} database") + print(f"πŸ“‹ Database URL: {DATABASE_URL}") + + # Populate default agents (free) + if include_defaults: + print(f"\nπŸ†“ --- Processing Default Agents (Free) ---") + for category, agents in DEFAULT_AGENTS.items(): + print(f"\nπŸ“ {category}:") - # Create new template - template = AgentTemplate( - template_name=template_name, - display_name=template_data["display_name"], - description=template_data["description"], - icon_url=template_data["icon_url"], - prompt_template=template_data["prompt_template"], - category=category, - is_premium_only=True, # All templates require premium - is_active=True, - created_at=datetime.now(UTC), - updated_at=datetime.now(UTC) - ) + for agent_data in agents: + template_name = agent_data["template_name"] + + # Check if agent already exists + existing = session.query(AgentTemplate).filter( + AgentTemplate.template_name == template_name + ).first() + + if existing: + print(f"⏭️ Skipping {template_name} (already exists)") + skipped_count += 1 + continue + + # Create new default agent + template = AgentTemplate( + template_name=template_name, + display_name=agent_data["display_name"], + description=agent_data["description"], + icon_url=agent_data["icon_url"], + prompt_template=agent_data["prompt_template"], + category=category, + is_premium_only=False, # Default agents are free + is_active=True, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC) + ) + + session.add(template) + print(f"βœ… Created default agent: {template_name}") + default_created += 1 + + # Populate premium templates (paid) + if include_premiums: + print(f"\nπŸ”’ --- Processing Premium Templates (Paid) ---") + for category, templates in PREMIUM_TEMPLATES.items(): + print(f"\nπŸ“ {category}:") - session.add(template) - print(f"βœ… Created template: {template_name}") - created_count += 1 + for template_data in templates: + template_name = template_data["template_name"] + + # Check if template already exists + existing = 
session.query(AgentTemplate).filter( + AgentTemplate.template_name == template_name + ).first() + + if existing: + print(f"⏭️ Skipping {template_name} (already exists)") + skipped_count += 1 + continue + + # Create new premium template + template = AgentTemplate( + template_name=template_name, + display_name=template_data["display_name"], + description=template_data["description"], + icon_url=template_data["icon_url"], + prompt_template=template_data["prompt_template"], + category=category, + is_premium_only=True, # Premium templates require subscription + is_active=True, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC) + ) + + session.add(template) + print(f"βœ… Created premium template: {template_name}") + premium_created += 1 # Commit all changes session.commit() - print(f"\n--- Summary ---") - print(f"Created: {created_count} templates") - print(f"Skipped: {skipped_count} templates") - print(f"Total templates in database: {created_count + skipped_count}") + print(f"\nπŸ“Š --- Summary ---") + print(f"πŸ†“ Default agents created: {default_created}") + print(f"πŸ”’ Premium templates created: {premium_created}") + print(f"⏭️ Skipped (already exist): {skipped_count}") + print(f"πŸ“ˆ Total new templates: {default_created + premium_created}") + + # Show total count in database + total_count = session.query(AgentTemplate).count() + print(f"πŸ—„οΈ Total templates in database: {total_count}") except Exception as e: session.rollback() @@ -191,6 +429,10 @@ def populate_templates(): finally: session.close() +def populate_templates(): + """Legacy function for backward compatibility - now populates both default agents and premium templates.""" + populate_agents_and_templates(include_defaults=True, include_premiums=True) + def list_templates(): """List all existing templates.""" session = session_factory() @@ -236,18 +478,42 @@ def remove_all_templates(): finally: session.close() +def auto_populate_for_database(): + """Automatically populate based on database type.""" + db_type =
get_database_type() + + if db_type == "sqlite": + print("πŸ” SQLite detected - populating both default agents and premium templates") + populate_agents_and_templates(include_defaults=True, include_premiums=True) + elif db_type == "postgresql": + print("πŸ” PostgreSQL detected - populating only premium templates") + populate_agents_and_templates(include_defaults=False, include_premiums=True) + else: + print(f"⚠️ Unknown database type: {db_type}") + print("Populating both default agents and premium templates") + populate_agents_and_templates(include_defaults=True, include_premiums=True) + if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="Manage agent templates") - parser.add_argument("action", choices=["populate", "list", "remove-all"], + parser.add_argument("action", choices=["populate", "populate-all", "populate-defaults", "auto", "list", "remove-all"], help="Action to perform") args = parser.parse_args() if args.action == "populate": - print("πŸš€ Populating agent templates...") + print("πŸš€ Populating premium templates only...") populate_templates() + elif args.action == "populate-all": + print("πŸš€ Populating both default agents and premium templates...") + populate_agents_and_templates(include_defaults=True, include_premiums=True) + elif args.action == "populate-defaults": + print("πŸš€ Populating default agents only...") + populate_agents_and_templates(include_defaults=True, include_premiums=False) + elif args.action == "auto": + print("πŸš€ Auto-populating based on database type...") + auto_populate_for_database() elif args.action == "list": list_templates() elif args.action == "remove-all": diff --git a/auto-analyst-backend/src/agents/agents.py b/auto-analyst-backend/src/agents/agents.py index 16ef0a66..8de9d96f 100644 --- a/auto-analyst-backend/src/agents/agents.py +++ b/auto-analyst-backend/src/agents/agents.py @@ -9,7 +9,7 @@ logger = Logger("agents", see_time=True, console_log=False) # === CUSTOM AGENT 
FUNCTIONALITY === -def create_custom_agent_signature(agent_name, description, prompt_template): +def create_custom_agent_signature(agent_name, description, prompt_template, category=None): """ Dynamically creates a dspy.Signature class for custom agents. @@ -17,21 +17,33 @@ def create_custom_agent_signature(agent_name, description, prompt_template): agent_name: Name of the custom agent (e.g., 'pytorch_agent') description: Short description for agent selection prompt_template: Main prompt/instructions for agent behavior + category: Agent category from database (e.g., 'Visualization', 'Modelling', 'Data Manipulation') Returns: A dspy.Signature class with the custom prompt and standard input/output fields """ - # Standard input/output fields that match standard agents (like data_viz_agent) + # Check if this is a visualization agent to determine input fields + # First check category, then fallback to name-based detection + if category and category.lower() == 'visualization': + is_viz_agent = True + else: + is_viz_agent = 'viz' in agent_name.lower() or 'visual' in agent_name.lower() or 'plot' in agent_name.lower() or 'chart' in agent_name.lower() + + # Standard input/output fields that match the unified agent signatures class_attributes = { '__doc__': prompt_template, # The custom prompt becomes the docstring 'goal': dspy.InputField(desc="User-defined goal which includes information about data and task they want to perform"), 'dataset': dspy.InputField(desc="Provides information about the data in the data frame. 
Only use column names and dataframe_name as in this context"), - 'styling_index': dspy.InputField(desc='Provides instructions on how to style outputs and formatting'), + 'plan_instructions': dspy.InputField(desc="Agent-level instructions about what to create and receive (optional for individual use)", default=""), 'code': dspy.OutputField(desc="Generated Python code for the analysis"), 'summary': dspy.OutputField(desc="A concise bullet-point summary of what was done and key results") } + # Add styling_index for visualization agents + if is_viz_agent: + class_attributes['styling_index'] = dspy.InputField(desc='Provides instructions on how to style outputs and formatting') + # Create the dynamic signature class CustomAgentSignature = type(agent_name, (dspy.Signature,), class_attributes) return CustomAgentSignature @@ -39,7 +51,7 @@ def create_custom_agent_signature(agent_name, description, prompt_template): def load_user_enabled_templates_from_db(user_id, db_session): """ Load template agents that are enabled for a specific user from the database. - All templates are enabled by default unless explicitly disabled by user preference. + Default agents are enabled by default unless explicitly disabled by user preference. 
Args: user_id: ID of the user @@ -56,6 +68,14 @@ def load_user_enabled_templates_from_db(user_id, db_session): if not user_id: return agent_signatures + # Get list of default agent names that should be enabled by default + default_agent_names = [ + "preprocessing_agent", + "statistical_analytics_agent", + "sk_learn_agent", + "data_viz_agent" + ] + # Get all active templates all_templates = db_session.query(AgentTemplate).filter( AgentTemplate.is_active == True @@ -68,16 +88,20 @@ def load_user_enabled_templates_from_db(user_id, db_session): UserTemplatePreference.template_id == template.template_id ).first() - # Template is disabled by default unless explicitly enabled - # Only enabled if preference record exists and is_enabled=True - is_enabled = preference.is_enabled if preference else False + # Determine if template should be enabled by default + is_default_agent = template.template_name in default_agent_names + default_enabled = is_default_agent # Default agents enabled by default, others disabled + # Template is enabled by default for default agents, disabled for others + is_enabled = preference.is_enabled if preference else default_enabled + if is_enabled: # Create dynamic signature for each enabled template signature = create_custom_agent_signature( template.template_name, template.description, - template.prompt_template + template.prompt_template, + template.category # Pass the category from database ) agent_signatures[template.template_name] = signature @@ -90,6 +114,7 @@ def load_user_enabled_templates_from_db(user_id, db_session): def load_user_enabled_templates_for_planner_from_db(user_id, db_session): """ Load template agents that are enabled for planner use (max 10, prioritized by usage). + Default agents are enabled by default unless explicitly disabled by user preference. 
Args: user_id: ID of the user @@ -100,36 +125,63 @@ def load_user_enabled_templates_for_planner_from_db(user_id, db_session): """ try: from src.db.schemas.models import AgentTemplate, UserTemplatePreference + from datetime import datetime, UTC agent_signatures = {} if not user_id: return agent_signatures - # Get enabled templates ordered by usage (most used first) and limit to 10 - enabled_preferences = db_session.query(UserTemplatePreference).filter( - UserTemplatePreference.user_id == user_id, - UserTemplatePreference.is_enabled == True - ).order_by( - UserTemplatePreference.usage_count.desc(), - UserTemplatePreference.last_used_at.desc() - ).limit(10).all() + # Get list of default agent names that should be enabled by default + default_agent_names = [ + "preprocessing_agent", + "statistical_analytics_agent", + "sk_learn_agent", + "data_viz_agent" + ] - for preference in enabled_preferences: - # Get template details - template = db_session.query(AgentTemplate).filter( - AgentTemplate.template_id == preference.template_id, - AgentTemplate.is_active == True + # Get all active templates + all_templates = db_session.query(AgentTemplate).filter( + AgentTemplate.is_active == True + ).all() + + enabled_templates = [] + for template in all_templates: + # Check if user has a preference record for this template + preference = db_session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == user_id, + UserTemplatePreference.template_id == template.template_id ).first() - if template: - # Create dynamic signature for each enabled template - signature = create_custom_agent_signature( - template.template_name, - template.description, - template.prompt_template - ) - agent_signatures[template.template_name] = signature + # Determine if template should be enabled by default + is_default_agent = template.template_name in default_agent_names + default_enabled = is_default_agent # Default agents enabled by default, others disabled + + # Template is enabled by 
default for default agents, disabled for others + is_enabled = preference.is_enabled if preference else default_enabled + + if is_enabled: + enabled_templates.append({ + 'template': template, + 'preference': preference, + 'usage_count': preference.usage_count if preference else 0, + 'last_used_at': preference.last_used_at if preference else None + }) + + # Sort by usage (most used first) and limit to 10 + enabled_templates.sort(key=lambda x: (x['usage_count'], x['last_used_at'] or datetime.min.replace(tzinfo=UTC)), reverse=True) + enabled_templates = enabled_templates[:10] + + for item in enabled_templates: + template = item['template'] + # Create dynamic signature for each enabled template + signature = create_custom_agent_signature( + template.template_name, + template.description, + template.prompt_template, + template.category # Pass the category from database + ) + agent_signatures[template.template_name] = signature logger.log_message(f"Loaded {len(agent_signatures)} templates for planner", level=logging.DEBUG) return agent_signatures @@ -247,11 +299,11 @@ def load_all_available_templates_from_db(db_session): signature = create_custom_agent_signature( template.template_name, template.description, - template.prompt_template + template.prompt_template, + template.category # Pass the category from database ) agent_signatures[template.template_name] = signature - logger.log_message(f"Loaded {len(agent_signatures)} templates", level=logging.INFO) return agent_signatures except Exception as e: @@ -262,43 +314,30 @@ def load_all_available_templates_from_db(db_session): # === END CUSTOM AGENT FUNCTIONALITY === -AGENTS_WITH_DESCRIPTION = { - "preprocessing_agent": "Cleans and prepares a DataFrame using Pandas and NumPyβ€”handles missing values, detects column types, and converts date strings to datetime.", - "statistical_analytics_agent": "Performs statistical analysis (e.g., regression, seasonal decomposition) using statsmodels, with proper handling of categorical 
data and missing values.", - "sk_learn_agent": "Trains and evaluates machine learning models using scikit-learn, including classification, regression, and clustering with feature importance insights.", - "data_viz_agent": "Generates interactive visualizations with Plotly, selecting the best chart type to reveal trends, comparisons, and insights based on the analysis goal." -} - -PLANNER_AGENTS_WITH_DESCRIPTION = { - "planner_preprocessing_agent": ( - "Cleans and prepares a DataFrame using Pandas and NumPy" - "handles missing values, detects column types, and converts date strings to datetime. " - "Outputs a cleaned DataFrame for the planner_statistical_analytics_agent." - ), - "planner_statistical_analytics_agent": ( - "Takes the cleaned DataFrame from preprocessing, performs statistical analysis " - "(e.g., regression, seasonal decomposition) using statsmodels with proper handling " - "of categorical data and remaining missing values. " - "Produces summary statistics and model diagnostics for the planner_sk_learn_agent." - ), - "planner_sk_learn_agent": ( - "Receives summary statistics and the cleaned data, trains and evaluates machine " - "learning models using scikit-learn (classification, regression, clustering), " - "and generates performance metrics and feature importance. " - "Passes the trained models and evaluation results to the planner_data_viz_agent." - ), - "planner_data_viz_agent": ( - "Consumes trained models and evaluation results to create interactive visualizations " - "with Plotlyβ€”selects the best chart type, applies styling, and annotates insights. " - "Delivers ready-to-share figures that communicate model performance and key findings." 
- ), -} - def get_agent_description(agent_name, is_planner=False): - if is_planner: - return PLANNER_AGENTS_WITH_DESCRIPTION[agent_name.lower()] if agent_name.lower() in PLANNER_AGENTS_WITH_DESCRIPTION else "No description available for this agent" - else: - return AGENTS_WITH_DESCRIPTION[agent_name.lower()] if agent_name.lower() in AGENTS_WITH_DESCRIPTION else "No description available for this agent" + """ + Get agent description from database instead of hardcoded dictionaries. + This function is kept for backward compatibility but will fetch from DB. + """ + try: + from src.db.init_db import session_factory + from src.db.schemas.models import AgentTemplate + + db_session = session_factory() + try: + template = db_session.query(AgentTemplate).filter( + AgentTemplate.template_name == agent_name, + AgentTemplate.is_active == True + ).first() + + if template: + return template.description + else: + return "No description available for this agent" + finally: + db_session.close() + except Exception as e: + return "No description available for this agent" # Agent to make a Chat history name from a query @@ -421,6 +460,11 @@ class custom_agent_instruction_generator(dspy.Signature): class advanced_query_planner(dspy.Signature): """ You are a advanced data analytics planner agent. Your task is to generate the most efficient planβ€”using the fewest necessary agents and variablesβ€”to achieve a user-defined goal. The plan must preserve data integrity, avoid unnecessary steps, and ensure clear data flow between agents. + +**CRITICAL**: Before planning, check if any agents are available in Agent_desc. If Agent_desc is empty or contains no active agents, respond with: +plan: no_agents_available +plan_instructions: {"message": "No agents are currently enabled for analysis. Please enable at least one agent (preprocessing, statistical analysis, machine learning, or visualization) in your template preferences to proceed with data analysis."} + **Inputs**: 1. 
Datasets (raw or preprocessed) 2. Agent descriptions (roles, variables they create/use, constraints) @@ -437,9 +481,9 @@ class advanced_query_planner(dspy.Signature): Example: 1 agent use goal: "Generate a bar plot showing sales by category after cleaning the raw data and calculating the average of the 'sales' column" Output: - plan: planner_data_viz_agent + plan: data_viz_agent { - "planner_data_viz_agent": { + "data_viz_agent": { "create": [ "cleaned_data: DataFrame - cleaned version of df (pd.Dataframe) after removing null values" ], @@ -451,9 +495,9 @@ class advanced_query_planner(dspy.Signature): } Example 3 Agent goal:"Clean the dataset, run a linear regression to model the relationship between marketing budget and sales, and visualize the regression line with confidence intervals." -plan: planner_preprocessing_agent -> planner_statistical_analytics_agent -> planner_data_viz_agent +plan: preprocessing_agent -> statistical_analytics_agent -> data_viz_agent { - "planner_preprocessing_agent": { + "preprocessing_agent": { "create": [ "cleaned_data: DataFrame - cleaned version of df with missing values handled and proper data types inferred" ], @@ -462,7 +506,7 @@ class advanced_query_planner(dspy.Signature): ], "instruction": "Clean df by handling missing values and converting column types (e.g., dates). Output cleaned_data for modeling." }, - "planner_statistical_analytics_agent": { + "statistical_analytics_agent": { "create": [ "regression_results: dict - model summary including coefficients, p-values, RΒ², and confidence intervals" ], @@ -471,7 +515,7 @@ class advanced_query_planner(dspy.Signature): ], "instruction": "Perform linear regression using cleaned_data to model sales as a function of marketing budget. Return regression_results including coefficients and confidence intervals." 
}, - "planner_data_viz_agent": { + "data_viz_agent": { "create": [ "regression_plot: PlotlyFigure - visual plot showing regression line with confidence intervals" ], @@ -496,20 +540,25 @@ class basic_query_planner(dspy.Signature): """ You are the basic query planner in the system, you pick one agent, to answer the user's goal. Use the Agent_desc that describes the names and actions of agents available. + + **CRITICAL**: Before planning, check if any agents are available in Agent_desc. If Agent_desc is empty or contains no active agents, respond with: + plan: no_agents_available + plan_instructions: {"message": "No agents are currently enabled for analysis. Please enable at least one agent (preprocessing, statistical analysis, machine learning, or visualization) in your template preferences to proceed with data analysis."} + Example: Visualize height and salary? - plan:planner_data_viz_agent + plan:data_viz_agent plan_instructions: { - "planner_data_viz_agent": { + "data_viz_agent": { "create": ["scatter_plot"], "use": ["original_data"], "instruction": "use the original_data to create scatter_plot of height & salary, using plotly" } } Example: Tell me the correlation between X and Y - plan:planner_preprocessing_agent + plan:preprocessing_agent plan_instructions:{ - "planner_data_viz_agent": { + "data_viz_agent": { "create": ["correlation"], "use": ["original_data"], "instruction": "use the original_data to measure correlation of X & Y, using pandas" @@ -535,6 +584,11 @@ class intermediate_query_planner(dspy.Signature): 3. User-defined Goal You take these three inputs to develop a comprehensive plan to achieve the user-defined goal from the data & Agents available. In case you think the user-defined goal is infeasible you can ask the user to redefine or add more description to the goal. + + **CRITICAL**: Before planning, check if any agents are available in Agent_desc. 
If Agent_desc is empty or contains no active agents, respond with: + plan: no_agents_available + plan_instructions: {"message": "No agents are currently enabled for analysis. Please enable at least one agent (preprocessing, statistical analysis, machine learning, or visualization) in your template preferences to proceed with data analysis."} + Give your output in this format: plan: Agent1->Agent2 plan_instructions = { @@ -583,53 +637,108 @@ def __init__(self): self.allocator = dspy.Predict("goal,planner_desc,dataset->exact_word_complexity,reasoning") - async def forward(self, goal,dataset,Agent_desc): - complexity = self.allocator(goal=goal, planner_desc= str(self.planner_desc), dataset=str(dataset)) - # print(complexity) - if complexity.exact_word_complexity.strip() != "unrelated": + async def forward(self, goal, dataset, Agent_desc): + # Check if we have any agents available + if not Agent_desc or Agent_desc == "[]" or len(str(Agent_desc).strip()) < 10: + logger.log_message("No agents available for planning", level=logging.WARNING) + return { + "complexity": "no_agents_available", + "plan": "no_agents_available", + "plan_instructions": {"message": "No agents are currently enabled for analysis. 
Please enable at least one agent (preprocessing, statistical analysis, machine learning, or visualization) in your template preferences to proceed with data analysis."} + } + + try: + complexity = self.allocator(goal=goal, planner_desc=str(self.planner_desc), dataset=str(dataset)) + # If complexity is unrelated, return basic_qa_agent + if complexity.exact_word_complexity.strip() == "unrelated": + return { + "complexity": complexity.exact_word_complexity.strip(), + "plan": "basic_qa_agent", + "plan_instructions": "{'basic_qa_agent':'Not a data related query, please ask a data related-query'}" + } + + # Try to get plan with determined complexity try: + logger.log_message(f"Attempting to plan with complexity: {complexity.exact_word_complexity.strip()}", level=logging.DEBUG) plan = await self.planners[complexity.exact_word_complexity.strip()](goal=goal, dataset=dataset, Agent_desc=Agent_desc) + logger.log_message(f"Plan generated successfully: {plan}", level=logging.DEBUG) + + # Check if the planner returned no_agents_available + if hasattr(plan, 'plan') and 'no_agents_available' in str(plan.plan): + logger.log_message("Planner returned no_agents_available", level=logging.WARNING) + output = { + "complexity": "no_agents_available", + "plan": "no_agents_available", + "plan_instructions": {"message": "No agents are currently enabled for analysis. 
Please enable at least one agent (preprocessing, statistical analysis, machine learning, or visualization) in your template preferences to proceed with data analysis."} + } + else: + output = { + "complexity": complexity.exact_word_complexity.strip(), + "plan": dict(plan) + } except Exception as e: + logger.log_message(f"Error with {complexity.exact_word_complexity.strip()} planner, falling back to intermediate: {str(e)}", level=logging.WARNING) + + # Fallback to intermediate planner plan = await self.planners["intermediate"](goal=goal, dataset=dataset, Agent_desc=Agent_desc) - - output = {"complexity":complexity.exact_word_complexity.strip() - ,"plan":dict(plan)} - else: - output = {"complexity":complexity.exact_word_complexity.strip() - ,"plan":dict(plan="basic_qa_agent", plan_instructions="""{'basic_qa_agent':'Not a data related query, please ask a data related-query'}""") - } - # print(output) + logger.log_message(f"Fallback plan generated: {plan}", level=logging.DEBUG) + + # Check if the fallback planner also returned no_agents_available + if hasattr(plan, 'plan') and 'no_agents_available' in str(plan.plan): + logger.log_message("Fallback planner also returned no_agents_available", level=logging.WARNING) + output = { + "complexity": "no_agents_available", + "plan": "no_agents_available", + "plan_instructions": {"message": "No agents are currently enabled for analysis. 
Please enable at least one agent (preprocessing, statistical analysis, machine learning, or visualization) in your template preferences to proceed with data analysis."} + } + else: + output = { + "complexity": "intermediate", + "plan": dict(plan) + } + + except Exception as e: + logger.log_message(f"Error in planner forward: {str(e)}", level=logging.ERROR) + # Return error response + return { + "complexity": "error", + "plan": "basic_qa_agent", + "plan_instructions": {"error": f"Planning error: {str(e)}"} + } + return output -class planner_preprocessing_agent(dspy.Signature): +class preprocessing_agent(dspy.Signature): """ -You are a preprocessing agent in a multi-agent data analytics system. +You are a preprocessing agent that can work both individually and in multi-agent data analytics systems. You are given: -* A dataset (already loaded as `df`). -* A user-defined analysis goal (e.g., predictive modeling, exploration, cleaning). -* Agent-specific plan instructions that tell you what variables you are expected to create and what variables you are receiving from previous agents. -* processed_df is just an arbitrary name, it can be anything the planner says to clean! +* A dataset (already loaded as `df`). +* A user-defined analysis goal (e.g., predictive modeling, exploration, cleaning). +* Optional plan instructions that tell you what variables you are expected to create and what variables you are receiving from previous agents. + ### Your Responsibilities: -* Follow the provided plan and create only the required variables listed in the 'create' section of the plan instructions. -* Do not create fake data or introduce variables not explicitly part of the instructions. -* Do not read data from CSV ; the dataset (`df`) is already loaded and ready for processing. -* Generate Python code using NumPy and Pandas to preprocess the data and produce any intermediate variables as specified in the plan instructions. 
+* If plan_instructions are provided, follow the provided plan and create only the required variables listed in the 'create' section. +* If no plan_instructions are provided, perform standard data preprocessing based on the goal. +* Do not create fake data or introduce variables not explicitly part of the instructions. +* Do not read data from CSV; the dataset (`df`) is already loaded and ready for processing. +* Generate Python code using NumPy and Pandas to preprocess the data and produce any intermediate variables as specified. + ### Best Practices for Preprocessing: -1. Create a copy of the original DataFrame : It will always be stored as df, it already exists use it! +1. Create a copy of the original DataFrame: It will always be stored as df, it already exists use it! ```python processed_df = df.copy() ``` -2. Separate column types : +2. Separate column types: ```python numeric_cols = processed_df.select_dtypes(include='number').columns categorical_cols = processed_df.select_dtypes(include='object').columns ``` -3. Handle missing values : +3. Handle missing values: ```python for col in numeric_cols: processed_df[col] = processed_df[col].fillna(processed_df[col].median()) @@ -637,7 +746,7 @@ class planner_preprocessing_agent(dspy.Signature): for col in categorical_cols: processed_df[col] = processed_df[col].fillna(processed_df[col].mode()[0] if not processed_df[col].mode().empty else 'Unknown') ``` -4. Convert string columns to datetime safely : +4. Convert string columns to datetime safely: ```python def safe_to_datetime(x): try: @@ -647,138 +756,141 @@ def safe_to_datetime(x): cleaned_df['date_column'] = cleaned_df['date_column'].apply(safe_to_datetime) ``` -> Replace `processed_df`,'cleaned_df' and `date_column` with whatever names the user or planner provides. -5. Do not alter the DataFrame index : - Avoid using `reset_index()`, `set_index()`, or reindexing unless explicitly instructed. -6. 
Log assumptions and corrections in comments to clarify any choices made during preprocessing. -7. Do not mutate global state : Avoid in-place modifications unless clearly necessary (e.g., using `.copy()`). -8. Handle data types properly : +5. Do not alter the DataFrame index unless explicitly instructed. +6. Log assumptions and corrections in comments to clarify any choices made during preprocessing. +7. Do not mutate global state: Avoid in-place modifications unless clearly necessary (e.g., using `.copy()`). +8. Handle data types properly: * Avoid coercing types blindly (e.g., don't compare timestamps to strings or floats). * Use `pd.to_datetime(..., errors='coerce')` for safe datetime parsing. -9. Preserve column structure : Only drop or rename columns if explicitly instructed. +9. Preserve column structure: Only drop or rename columns if explicitly instructed. + ### Output: -1. Code : Python code that performs the requested preprocessing steps as per the plan instructions. -2. Summary : A brief explanation of what preprocessing was done (e.g., columns handled, missing value treatment). +1. Code: Python code that performs the requested preprocessing steps. +2. Summary: A brief explanation of what preprocessing was done (e.g., columns handled, missing value treatment). + ### Principles to Follow: --Never alter the DataFrame index unless explicitly instructed. --Handle missing data explicitly, filling with default values when necessary. --Preserve column structure and avoid unnecessary modifications. --Ensure data types are appropriate (e.g., dates parsed correctly). --Log assumptions in the code. +- Never alter the DataFrame index unless explicitly instructed. +- Handle missing data explicitly, filling with default values when necessary. +- Preserve column structure and avoid unnecessary modifications. +- Ensure data types are appropriate (e.g., dates parsed correctly). +- Log assumptions in the code. 
Respond in the user's language for all summary and reasoning but keep the code in english """ dataset = dspy.InputField(desc="The dataset, preloaded as df") goal = dspy.InputField(desc="User-defined goal for the analysis") - plan_instructions = dspy.InputField(desc="Agent-level instructions about what to create and receive") + plan_instructions = dspy.InputField(desc="Agent-level instructions about what to create and receive (optional for individual use)", default="") code = dspy.OutputField(desc="Generated Python code for preprocessing") summary = dspy.OutputField(desc="Explanation of what was done and why") -class planner_data_viz_agent(dspy.Signature): +class data_viz_agent(dspy.Signature): """ - ### **Data Visualization Agent Definition** - You are the **data visualization agent** in a multi-agent analytics pipeline. Your primary responsibility is to **generate visualizations** based on the **user-defined goal** and the **plan instructions**. +You are a data visualization agent that can work both individually and in multi-agent analytics pipelines. +Your primary responsibility is to generate visualizations based on the user-defined goal. + You are provided with: * **goal**: A user-defined goal outlining the type of visualization the user wants (e.g., "plot sales over time with trendline"). - * **dataset**: The dataset (e.g., `df_cleaned`) which will be passed to you by other agents in the pipeline. **Do not assume or create any variables** β€” **the data is already present and valid** when you receive it. +* **dataset**: The dataset (e.g., `df_cleaned`) which will be passed to you by other agents in the pipeline. Do not assume or create any variables β€” the data is already present and valid when you receive it. * **styling_index**: Specific styling instructions (e.g., axis formatting, color schemes) for the visualization. 
- * **plan_instructions**: A dictionary containing: - * **'create'**: List of **visualization components** you must generate (e.g., 'scatter_plot', 'bar_chart'). - * **'use'**: List of **variables you must use** to generate the visualizations. This includes datasets and any other variables provided by the other agents. - * **'instructions'**: A list of additional instructions related to the creation of the visualizations, such as requests for trendlines or axis formats. - --- - ### **Responsibilities**: +* **plan_instructions**: Optional dictionary containing: + * **'create'**: List of visualization components you must generate (e.g., 'scatter_plot', 'bar_chart'). + * **'use'**: List of variables you must use to generate the visualizations. + * **'instructions'**: Additional instructions related to the creation of the visualizations. + +### Responsibilities: 1. **Strict Use of Provided Variables**: - * You must **never create fake data**. Only use the variables and datasets that are explicitly **provided** to you in the `plan_instructions['use']` section. All the required data **must already be available**. - * If any variable listed in `plan_instructions['use']` is missing or invalid, **you must return an error** and not proceed with any visualization. + * You must never create fake data. Only use the variables and datasets that are explicitly provided. + * If plan_instructions are provided and any variable listed in plan_instructions['use'] is missing, return an error. + * If no plan_instructions are provided, work with the available dataset directly. + 2. **Visualization Creation**: - * Based on the **'create'** section of the `plan_instructions`, generate the **required visualization** using **Plotly**. For example, if the goal is to plot a time series, you might generate a line chart. - * Respect the **user-defined goal** in determining which type of visualization to create. 
+ * Based on the goal and optional 'create' section of plan_instructions, generate the required visualization using Plotly. + * Respect the user-defined goal in determining which type of visualization to create. + 3. **Performance Optimization**: - * If the dataset contains **more than 50,000 rows**, you **must sample** the data to **5,000 rows** to improve performance. Use this method: + * If the dataset contains more than 50,000 rows, you must sample the data to 5,000 rows to improve performance: ```python if len(df) > 50000: df = df.sample(5000, random_state=42) ``` + 4. **Layout and Styling**: - * Apply formatting and layout adjustments as defined by the **styling_index**. This may include: - * Axis labels and title formatting. - * Tick formats for axes. - * Color schemes or color maps for visual elements. - * You must ensure that all axes (x and y) have **consistent formats** (e.g., using `K`, `M`, or 1,000 format, but not mixing formats). + * Apply formatting and layout adjustments as defined by the styling_index. + * Ensure that all axes (x and y) have consistent formats (e.g., using `K`, `M`, or 1,000 format, but not mixing formats). + 5. **Trendlines**: - * Trendlines should **only be included** if explicitly requested in the **'instructions'** section of `plan_instructions`. + * Trendlines should only be included if explicitly requested in the goal or plan_instructions. + 6. **Displaying the Visualization**: * Use Plotly's `fig.show()` method to display the created chart. - * **Never** output raw datasets or the **goal** itself. Only the visualization code and the chart should be returned. + * Never output raw datasets or the goal itself. Only the visualization code and the chart should be returned. + 7. **Error Handling**: - * If the required dataset or variables are missing or invalid (i.e., not included in `plan_instructions['use']`), return an error message indicating which specific variable is missing or invalid. 
- * If the **goal** or **create** instructions are ambiguous or invalid, return an error stating the issue. + * If required dataset or variables are missing, return an error message indicating which specific variable is missing. + * If the goal or create instructions are ambiguous, return an error stating the issue. + 8. **No Data Modification**: - * **Never** modify the provided dataset or generate new data. If the data needs preprocessing or cleaning, assume it's already been done by other agents. - --- - ### **Strict Conditions**: - * You **never** create any data. - * You **only** use the data and variables passed to you. - * If any required data or variable is missing or invalid, **you must stop** and return a clear error message. - * Respond in the user's language for all summary and reasoning but keep the code in english - * it should be update_yaxes, update_xaxes, not axis - By following these conditions and responsibilities, your role is to ensure that the **visualizations** are generated as per the user goal, using the valid data and instructions given to you. + * Never modify the provided dataset or generate new data. If the data needs preprocessing, assume it's already been done by other agents. 
+ +### Important Notes: +- Use update_yaxes, update_xaxes, not axis +- Each visualization must be generated as a separate figure using go.Figure() +- Do NOT use subplots under any circumstances +- Each figure must be returned individually using: fig.to_html(full_html=False) +- Use update_layout with xaxis and yaxis only once per figure +- Enhance readability with low opacity (0.4-0.7) where appropriate +- Apply visually distinct colors for different elements or categories +- Use only one number format consistently: either 'K', 'M', or comma-separated values +- Only include trendlines in scatter plots if the user explicitly asks for them +- Always end each visualization with: fig.to_html(full_html=False) + +Respond in the user's language for all summary and reasoning but keep the code in english """ goal = dspy.InputField(desc="User-defined chart goal (e.g. trendlines, scatter plots)") dataset = dspy.InputField(desc="Details of the dataframe (`df`) and its columns") styling_index = dspy.InputField(desc="Instructions for plot styling and layout formatting") - plan_instructions = dspy.InputField(desc="Variables to create and receive for visualization purposes") + plan_instructions = dspy.InputField(desc="Variables to create and receive for visualization purposes (optional for individual use)", default="") code = dspy.OutputField(desc="Plotly Python code for the visualization") summary = dspy.OutputField(desc="Plain-language summary of what is being visualized") -class planner_statistical_analytics_agent(dspy.Signature): +class statistical_analytics_agent(dspy.Signature): """ -**Agent Definition:** -You are a statistical analytics agent in a multi-agent data analytics pipeline. +You are a statistical analytics agent that can work both individually and in multi-agent data analytics pipelines. You are given: * A dataset (usually a cleaned or transformed version like `df_cleaned`). * A user-defined goal (e.g., regression, seasonal decomposition). 
-* Agent-specific **plan instructions** specifying: - * Which **variables** you are expected to **CREATE** (e.g., `regression_model`). - * Which **variables** you will **USE** (e.g., `df_cleaned`, `target_variable`). - * A set of **instructions** outlining additional processing or handling for these variables (e.g., handling missing values, adding constants, transforming features, etc.). -**Your Responsibilities:** +* Optional plan instructions specifying: + * Which variables you are expected to CREATE (e.g., `regression_model`). + * Which variables you will USE (e.g., `df_cleaned`, `target_variable`). + * A set of instructions outlining additional processing or handling for these variables. + +### Your Responsibilities: * Use the `statsmodels` library to implement the required statistical analysis. * Ensure that all strings are handled as categorical variables via `C(col)` in model formulas. * Always add a constant using `sm.add_constant()`. -* Do **not** modify the DataFrame's index. +* Do not modify the DataFrame's index. * Convert `X` and `y` to float before fitting the model. * Handle missing values before modeling. * Avoid any data visualization (that is handled by another agent). * Write output to the console using `print()`. -**If the goal is regression:** + +### If the goal is regression: * Use `statsmodels.OLS` with proper handling of categorical variables and adding a constant term. * Handle missing values appropriately. -**If the goal is seasonal decomposition:** + +### If the goal is seasonal decomposition: * Use `statsmodels.tsa.seasonal_decompose`. * Ensure the time series and period are correctly provided (i.e., `period` should not be `None`). -**You must not:** -* You must always create the variables in `plan_instructions['CREATE']`. -* **Never create the `df` variable**. Only work with the variables passed via the `plan_instructions`. -* Rely on hardcoded column names β€” use those passed via `plan_instructions`. 
-* Introduce or modify intermediate variables unless they are explicitly listed in `plan_instructions['CREATE']`. -**Instructions to Follow:** -1. **CREATE** only the variables specified in `plan_instructions['CREATE']`. Do not create any intermediate or new variables. -2. **USE** only the variables specified in `plan_instructions['USE']` to carry out the task. -3. Follow any **additional instructions** in `plan_instructions['INSTRUCTIONS']` (e.g., preprocessing steps, encoding, handling missing values). -4. **Do not reassign or modify** any variables passed via `plan_instructions`. These should be used as-is. -**Example Workflow:** -Given that the `plan_instructions` specifies variables to **CREATE** and **USE**, and includes instructions, your approach should look like this: -1. Use `df_cleaned` and the variables like `X` and `y` from `plan_instructions` for analysis. -2. Follow instructions for preprocessing (e.g., handle missing values or scale features). -3. If the goal is regression: - * Use `sm.OLS` for model fitting. - * Handle categorical variables via `C(col)` and add a constant term. -4. If the goal is seasonal decomposition: - * Ensure `period` is provided and use `sm.tsa.seasonal_decompose`. -5. Store the output variable as specified in `plan_instructions['CREATE']`. + +### Instructions to Follow: +1. If plan_instructions are provided: + * CREATE only the variables specified in plan_instructions['CREATE']. Do not create any intermediate or new variables. + * USE only the variables specified in plan_instructions['USE'] to carry out the task. + * Follow any additional instructions in plan_instructions['INSTRUCTIONS']. + * Do not reassign or modify any variables passed via plan_instructions. +2. If no plan_instructions are provided, perform standard statistical analysis based on the goal and available data. 
+ ### Example Code Structure: ```python import statsmodels.api as sm @@ -808,109 +920,89 @@ def statistical_model(X, y, goal, period=None): except Exception as e: return f"An error occurred: {e}" ``` -**Summary:** -1. Always **USE** the variables passed in `plan_instructions['USE']` to carry out the task. -2. Only **CREATE** the variables specified in `plan_instructions['CREATE']`. Do not create any additional variables. -3. Follow any **additional instructions** in `plan_instructions['INSTRUCTIONS']` (e.g., handling missing values, adding constants). + +### Summary: +1. Always USE the variables passed in plan_instructions['USE'] to carry out the task (if provided). +2. Only CREATE the variables specified in plan_instructions['CREATE'] (if provided). +3. Follow any additional instructions in plan_instructions['INSTRUCTIONS'] (if provided). 4. Ensure reproducibility by setting the random state appropriately and handling categorical variables. 5. Focus on statistical analysis and avoid any unnecessary data manipulation. -**Output:** -* The **code** implementing the statistical analysis, including all required steps. -* A **summary** of what the statistical analysis does, how it's performed, and why it fits the goal. + +### Output: +* The code implementing the statistical analysis, including all required steps. +* A summary of what the statistical analysis does, how it's performed, and why it fits the goal. 
* Respond in the user's language for all summary and reasoning but keep the code in english """ dataset = dspy.InputField(desc="Preprocessed dataset, often named df_cleaned") goal = dspy.InputField(desc="The user's statistical analysis goal, e.g., regression or seasonal_decompose") - plan_instructions = dspy.InputField(desc="Instructions on variables to create and receive for statistical modeling") + plan_instructions = dspy.InputField(desc="Instructions on variables to create and receive for statistical modeling (optional for individual use)", default="") code = dspy.OutputField(desc="Python code for statistical modeling using statsmodels") summary = dspy.OutputField(desc="A concise bullet-point summary of the statistical analysis performed and key findings") - - -class planner_sk_learn_agent(dspy.Signature): +class sk_learn_agent(dspy.Signature): """ - **Agent Definition:** - You are a machine learning agent in a multi-agent data analytics pipeline. +You are a machine learning agent that can work both individually and in multi-agent data analytics pipelines. You are given: * A dataset (often cleaned and feature-engineered). * A user-defined goal (e.g., classification, regression, clustering). - * Agent-specific **plan instructions** specifying: - * Which **variables** you are expected to **CREATE** (e.g., `trained_model`, `predictions`). - * Which **variables** you will **USE** (e.g., `df_cleaned`, `target_variable`, `feature_columns`). - * A set of **instructions** outlining additional processing or handling for these variables (e.g., handling missing values, applying transformations, or other task-specific guidelines). - **Your Responsibilities:** +* Optional plan instructions specifying: + * Which variables you are expected to CREATE (e.g., `trained_model`, `predictions`). + * Which variables you will USE (e.g., `df_cleaned`, `target_variable`, `feature_columns`). + * A set of instructions outlining additional processing or handling for these variables. 
+ +### Your Responsibilities: * Use the scikit-learn library to implement the appropriate ML pipeline. * Always split data into training and testing sets where applicable. * Use `print()` for all outputs. * Ensure your code is: - * **Reproducible**: Set `random_state=42` wherever applicable. - * **Modular**: Avoid deeply nested code. - * **Focused on model building**, not visualization (leave plotting to the `data_viz_agent`). + * Reproducible: Set `random_state=42` wherever applicable. + * Modular: Avoid deeply nested code. + * Focused on model building, not visualization (leave plotting to the `data_viz_agent`). * Your task may include: * Preprocessing inputs (e.g., encoding). * Model selection and training. * Evaluation (e.g., accuracy, RMSE, classification report). - **You must not:** + +### You must not: * Visualize anything (that's another agent's job). - * Rely on hardcoded column names β€” use those passed via `plan_instructions`. - * **Never create or modify any variables not explicitly mentioned in `plan_instructions['CREATE']`.** - * **Never create the `df` variable**. You will **only** work with the variables passed via the `plan_instructions`. - * Do not introduce intermediate variables unless they are listed in `plan_instructions['CREATE']`. - **Instructions to Follow:** - 1. **CREATE** only the variables specified in the `plan_instructions['CREATE']` list. Do not create any intermediate or new variables. - 2. **USE** only the variables specified in the `plan_instructions['USE']` list. You are **not allowed** to create or modify any variables not listed in the plan instructions. - 3. Follow any **processing instructions** in the `plan_instructions['INSTRUCTIONS']` list. This might include tasks like handling missing values, scaling features, or encoding categorical variables. Always perform these steps on the variables specified in the `plan_instructions`. - 4. Do **not reassign or modify** any variables passed via `plan_instructions`. 
These should be used as-is. - **Example Workflow:** - Given that the `plan_instructions` specifies variables to **CREATE** and **USE**, and includes instructions, your approach should look like this: - 1. Use `df_cleaned` and `feature_columns` from the `plan_instructions` to extract your features (`X`). - 2. Use `target_column` from `plan_instructions` to extract your target (`y`). +* Rely on hardcoded column names β€” use those passed via plan_instructions or infer from data. +* Never create or modify any variables not explicitly mentioned in plan_instructions['CREATE'] (if provided). +* Never create the `df` variable. You will only work with the variables passed via the plan_instructions. +* Do not introduce intermediate variables unless they are listed in plan_instructions['CREATE'] (if provided). + +### Instructions to Follow: +1. If plan_instructions are provided: + * CREATE only the variables specified in the plan_instructions['CREATE'] list. + * USE only the variables specified in the plan_instructions['USE'] list. + * Follow any processing instructions in the plan_instructions['INSTRUCTIONS'] list. + * Do not reassign or modify any variables passed via plan_instructions. +2. If no plan_instructions are provided, perform standard machine learning analysis based on the goal and available data. + +### Example Workflow: +Given that the plan_instructions specifies variables to CREATE and USE, and includes instructions, your approach should look like this: +1. Use `df_cleaned` and `feature_columns` from the plan_instructions to extract your features (`X`). +2. Use `target_column` from plan_instructions to extract your target (`y`). 3. If instructions are provided (e.g., scale or encode), follow them. 4. Split data into training and testing sets using `train_test_split`. 5. Train the model based on the received goal (classification, regression, etc.). - 6. Store the output variables as specified in `plan_instructions['CREATE']`. 
- ### Example Code Structure: - ```python - from sklearn.model_selection import train_test_split - from sklearn.linear_model import LogisticRegression - from sklearn.metrics import classification_report - from sklearn.preprocessing import StandardScaler - # Ensure that all variables follow plan instructions: - # Use received inputs: df_cleaned, feature_columns, target_column - X = df_cleaned[feature_columns] - y = df_cleaned[target_column] - # Apply any preprocessing instructions (e.g., scaling if instructed) - if 'scale' in plan_instructions['INSTRUCTIONS']: - scaler = StandardScaler() - X = scaler.fit_transform(X) - # Split the data into training and testing sets - X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) - # Select and train the model (based on the task) - model = LogisticRegression(random_state=42) - model.fit(X_train, y_train) - # Generate predictions - predictions = model.predict(X_test) - # Create the variable specified in 'plan_instructions': 'metrics' - metrics = classification_report(y_test, predictions) - # Print the results - print(metrics) - # Ensure the 'metrics' variable is returned as requested in the plan - ``` - **Summary:** - 1. Always **USE** the variables passed in `plan_instructions['USE']` to build the pipeline. - 2. Only **CREATE** the variables specified in `plan_instructions['CREATE']`. Do not create any additional variables. - 3. Follow any **additional instructions** in `plan_instructions['INSTRUCTIONS']` (e.g., preprocessing steps). - 4. Ensure reproducibility by setting `random_state=42` wherever necessary. +6. Store the output variables as specified in plan_instructions['CREATE']. + +### Summary: +1. Always USE the variables passed in plan_instructions['USE'] to build the pipeline (if provided). +2. Only CREATE the variables specified in plan_instructions['CREATE'] (if provided). +3. Follow any additional instructions in plan_instructions['INSTRUCTIONS'] (if provided). +4. 
Ensure reproducibility by setting random_state=42 wherever necessary. 5. Focus on model building, evaluation, and saving the required outputsβ€”avoid any unnecessary variables. - **Output:** - * The **code** implementing the ML task, including all required steps. - * A **summary** of what the model does, how it is evaluated, and why it fits the goal. + +### Output: +* The code implementing the ML task, including all required steps. +* A summary of what the model does, how it is evaluated, and why it fits the goal. * Respond in the user's language for all summary and reasoning but keep the code in english """ dataset = dspy.InputField(desc="Input dataset, often cleaned and feature-selected (e.g., df_cleaned)") goal = dspy.InputField(desc="The user's machine learning goal (e.g., classification or regression)") - plan_instructions = dspy.InputField(desc="Instructions indicating what to create and what variables to receive") + plan_instructions = dspy.InputField(desc="Instructions indicating what to create and what variables to receive (optional for individual use)", default="") code = dspy.OutputField(desc="Scikit-learn based machine learning code") summary = dspy.OutputField(desc="Explanation of the ML approach and evaluation") @@ -924,140 +1016,7 @@ class goal_refiner_agent(dspy.Signature): goal = dspy.InputField(desc="The user defined goal ") refined_goal = dspy.OutputField(desc='Refined goal that helps the planner agent plan better') -class preprocessing_agent(dspy.Signature): - """You are a AI data-preprocessing agent. Generate clean and efficient Python code using NumPy and Pandas to perform introductory data preprocessing on a pre-loaded DataFrame df, based on the user's analysis goals. - Preprocessing Requirements: - 1. 
Identify Column Types - - Separate columns into numeric and categorical using: - categorical_columns = df.select_dtypes(include=[object, 'category']).columns.tolist() - numeric_columns = df.select_dtypes(include=[np.number]).columns.tolist() - 2. Handle Missing Values - - Numeric columns: Impute missing values using the mean of each column - - Categorical columns: Impute missing values using the mode of each column - 3. Convert Date Strings to Datetime - - For any column suspected to represent dates (in string format), convert it to datetime using: - def safe_to_datetime(date): - try: - return pd.to_datetime(date, errors='coerce', cache=False) - except (ValueError, TypeError): - return pd.NaT - df['datetime_column'] = df['datetime_column'].apply(safe_to_datetime) - - Replace 'datetime_column' with the actual column names containing date-like strings - Important Notes: - - Do NOT create a correlation matrix β€” correlation analysis is outside the scope of preprocessing - - Do NOT generate any plots or visualizations - Output Instructions: - 1. Include the full preprocessing Python code - 2. Provide a brief bullet-point summary of the steps performed. 
Example: - β€’ Identified 5 numeric and 4 categorical columns - β€’ Filled missing numeric values with column means - β€’ Filled missing categorical values with column modes - β€’ Converted 1 date column to datetime format - Respond in the user's language for all summary and reasoning but keep the code in english - """ - dataset = dspy.InputField(desc="Available datasets loaded in the system, use this df, column_names set df as copy of df") - goal = dspy.InputField(desc="The user defined goal could ") - code = dspy.OutputField(desc ="The code that does the data preprocessing and introductory analysis") - summary = dspy.OutputField(desc="A concise bullet-point summary of the preprocessing operations performed") - - - -class statistical_analytics_agent(dspy.Signature): - # Statistical Analysis Agent, builds statistical models using StatsModel Package - """ - You are a statistical analytics agent. Your task is to take a dataset and a user-defined goal and output Python code that performs the appropriate statistical analysis to achieve that goal. Follow these guidelines: - IMPORTANT: You may be provided with previous interaction history. The section marked "### Current Query:" contains the user's current request. Any text in "### Previous Interaction History:" is for context only and is NOT part of the current request. - Data Handling: - Always handle strings as categorical variables in a regression using statsmodels C(string_column). - Do not change the index of the DataFrame. - Convert X and y into float when fitting a model. - Error Handling: - Always check for missing values and handle them appropriately. - Ensure that categorical variables are correctly processed. - Provide clear error messages if the model fitting fails. - Regression: - For regression, use statsmodels and ensure that a constant term is added to the predictor using sm.add_constant(X). - Handle categorical variables using C(column_name) in the model formula. 
- Fit the model with model = sm.OLS(y.astype(float), X.astype(float)).fit(). - Seasonal Decomposition: - Ensure the period is set correctly when performing seasonal decomposition. - Verify the number of observations works for the decomposition. - Output: - Ensure the code is executable and as intended. - Also choose the correct type of model for the problem - Avoid adding data visualization code. - Use code like this to prevent failing: - import pandas as pd - import numpy as np - import statsmodels.api as sm - def statistical_model(X, y, goal, period=None): - try: - # Check for missing values and handle them - X = X.dropna() - y = y.loc[X.index].dropna() - # Ensure X and y are aligned - X = X.loc[y.index] - # Convert categorical variables - for col in X.select_dtypes(include=['object', 'category']).columns: - X[col] = X[col].astype('category') - # Add a constant term to the predictor - X = sm.add_constant(X) - # Fit the model - if goal == 'regression': - # Handle categorical variables in the model formula - formula = 'y ~ ' + ' + '.join([f'C({col})' if X[col].dtype.name == 'category' else col for col in X.columns]) - model = sm.OLS(y.astype(float), X.astype(float)).fit() - return model.summary() - elif goal == 'seasonal_decompose': - if period is None: - raise ValueError("Period must be specified for seasonal decomposition") - decomposition = sm.tsa.seasonal_decompose(y, period=period) - return decomposition - else: - raise ValueError("Unknown goal specified. Please provide a valid goal.") - except Exception as e: - return f"An error occurred: {e}" - # Example usage: - result = statistical_analysis(X, y, goal='regression') - print(result) - If visualizing use plotly - Provide a concise bullet-point summary of the statistical analysis performed. 
- - Example Summary: - β€’ Applied linear regression with OLS to predict house prices based on 5 features - β€’ Model achieved R-squared of 0.78 - β€’ Significant predictors include square footage (p<0.001) and number of bathrooms (p<0.01) - β€’ Detected strong seasonal pattern with 12-month periodicity - β€’ Forecast shows 15% growth trend over next quarter - Respond in the user's language for all summary and reasoning but keep the code in english - """ - dataset = dspy.InputField(desc="Available datasets loaded in the system, use this df,columns set df as copy of df") - goal = dspy.InputField(desc="The user defined goal for the analysis to be performed") - code = dspy.OutputField(desc ="The code that does the statistical analysis using statsmodel") - summary = dspy.OutputField(desc="A concise bullet-point summary of the statistical analysis performed and key findings") - -class sk_learn_agent(dspy.Signature): - # Machine Learning Agent, performs task using sci-kit learn - """You are a machine learning agent. - Your task is to take a dataset and a user-defined goal, and output Python code that performs the appropriate machine learning analysis to achieve that goal. - You should use the scikit-learn library. - IMPORTANT: You may be provided with previous interaction history. The section marked "### Current Query:" contains the user's current request. Any text in "### Previous Interaction History:" is for context only and is NOT part of the current request. - Make sure your output is as intended! - Provide a concise bullet-point summary of the machine learning operations performed. 
- - Example Summary: - β€’ Trained a Random Forest classifier on customer churn data with 80/20 train-test split - β€’ Model achieved 92% accuracy and 88% F1-score - β€’ Feature importance analysis revealed that contract length and monthly charges are the strongest predictors of churn - β€’ Implemented K-means clustering (k=4) on customer shopping behaviors - β€’ Identified distinct segments: high-value frequent shoppers (22%), occasional big spenders (35%), budget-conscious regulars (28%), and rare visitors (15%) - Respond in the user's language for all summary and reasoning but keep the code in english - """ - dataset = dspy.InputField(desc="Available datasets loaded in the system, use this df,columns. set df as copy of df") - goal = dspy.InputField(desc="The user defined goal ") - code = dspy.OutputField(desc ="The code that does the Exploratory data analysis") - summary = dspy.OutputField(desc="A concise bullet-point summary of the machine learning analysis performed and key results") @@ -1094,55 +1053,6 @@ class code_combiner_agent(dspy.Signature): -class data_viz_agent(dspy.Signature): - # Visualizes data using Plotly - """ - You are an AI agent responsible for generating interactive data visualizations using Plotly. - IMPORTANT Instructions: - - The section marked "### Current Query:" contains the user's request. Any text in "### Previous Interaction History:" is for context only and should NOT be treated as part of the current request. - - You must only use the tools provided to you. This agent handles visualization only. - - If len(df) > 50000, always sample the dataset before visualization using: - if len(df) > 50000: - df = df.sample(50000, random_state=1) - - Each visualization must be generated as a **separate figure** using go.Figure(). - Do NOT use subplots under any circumstances. - - Each figure must be returned individually using: - fig.to_html(full_html=False) - - Use update_layout with xaxis and yaxis **only once per figure**. 
- - Enhance readability and clarity by: - β€’ Using low opacity (0.4-0.7) where appropriate - β€’ Applying visually distinct colors for different elements or categories - - Make sure the visual **answers the user's specific goal**: - β€’ Identify what insight or comparison the user is trying to achieve - β€’ Choose the visualization type and features (e.g., color, size, grouping) to emphasize that goal - β€’ For example, if the user asks for "trends in revenue," use a time series line chart; if they ask for "top-performing categories," use a bar chart sorted by value - β€’ Prioritize highlighting patterns, outliers, or comparisons relevant to the question - - Never include the dataset or styling index in the output. - - If there are no relevant columns for the requested visualization, respond with: - "No relevant columns found to generate this visualization." - - Use only one number format consistently: either 'K', 'M', or comma-separated values like 1,000/1,000,000. Do not mix formats. - - Only include trendlines in scatter plots if the user explicitly asks for them. - - Output only the code and a concise bullet-point summary of what the visualization reveals. - - Always end each visualization with: - fig.to_html(full_html=False) - Respond in the user's language for all summary and reasoning but keep the code in english - Example Summary: - β€’ Created an interactive scatter plot of sales vs. 
marketing spend with color-coded product categories - β€’ Included a trend line showing positive correlation (r=0.72) - β€’ Highlighted outliers where high marketing spend resulted in low sales - β€’ Generated a time series chart of monthly revenue from 2020-2023 - β€’ Added annotations for key business events - β€’ Visualization reveals 35% YoY growth with seasonal peaks in Q4 - - """ - goal = dspy.InputField(desc="user defined goal which includes information about data and chart they want to plot") - dataset = dspy.InputField(desc=" Provides information about the data in the data frame. Only use column names and dataframe_name as in this context") - styling_index = dspy.InputField(desc='Provides instructions on how to style your Plotly plots') - code= dspy.OutputField(desc="Plotly code that visualizes what the user needs according to the query & dataframe_index & styling_context") - summary = dspy.OutputField(desc="A concise bullet-point summary of the visualization created and key insights revealed") - - - class code_fix(dspy.Signature): """ You are an expert AI developer and data analyst assistant, skilled at identifying and resolving issues in Python code related to data analytics. Another agent has attempted to generate Python code for a data analytics task but produced code that is broken or throws an error. 
@@ -1218,42 +1128,190 @@ def __init__(self, agents, retrievers, user_id=None, db_session=None): self.agent_inputs = {} self.agent_desc = [] - # Create modules from agent signatures - for i, a in enumerate(agents): - name = a.__pydantic_core_schema__['schema']['model_name'] - self.agents[name] = dspy.asyncify(dspy.ChainOfThoughtWithHint(a)) - self.agent_inputs[name] = {x.strip() for x in str(agents[i].__pydantic_core_schema__['cls']).split('->')[0].split('(')[1].split(',')} - self.agent_desc.append(get_agent_description(name)) - - # Load ALL available template agents for direct access (regardless of user preferences) - if db_session: + logger.log_message(f"[INIT] Initializing auto_analyst_ind with user_id={user_id}, agents={len(agents) if agents else 0}", level=logging.INFO) + + # Load core agents based on user preferences (not always loaded) + if not agents and user_id and db_session: + try: + # Get user preferences for core agents + from src.db.schemas.models import AgentTemplate, UserTemplatePreference + + core_agent_names = ['preprocessing_agent', 'statistical_analytics_agent', 'sk_learn_agent', 'data_viz_agent'] + + for agent_name in core_agent_names: + logger.log_message(f"[INIT] Processing core agent: {agent_name}", level=logging.DEBUG) + + # Look up this core agent's template and confirm it is active (user-level preference is not checked here) + template = db_session.query(AgentTemplate).filter( + AgentTemplate.template_name == agent_name, + AgentTemplate.is_active == True + ).first() + + if not template: + logger.log_message(f"[INIT] Core agent template '{agent_name}' not found in database", level=logging.WARNING) + continue + + # Get the agent signature class + if agent_name == 'preprocessing_agent': + agent_signature = preprocessing_agent + elif agent_name == 'statistical_analytics_agent': + agent_signature = statistical_analytics_agent + elif agent_name == 'sk_learn_agent': + agent_signature = sk_learn_agent + elif agent_name == 'data_viz_agent': + agent_signature = data_viz_agent + + # Add to agents dict + 
self.agents[agent_name] = dspy.asyncify(dspy.ChainOfThought(agent_signature)) + + # Set input fields based on signature + if agent_name == 'data_viz_agent': + self.agent_inputs[agent_name] = {'goal', 'dataset', 'styling_index', 'plan_instructions'} + else: + self.agent_inputs[agent_name] = {'goal', 'dataset', 'plan_instructions'} + + # Get description from database + self.agent_desc.append({agent_name: get_agent_description(agent_name)}) + logger.log_message(f"[INIT] Successfully loaded core agent: {agent_name} with inputs: {self.agent_inputs[agent_name]}", level=logging.INFO) + + except Exception as e: + logger.log_message(f"[INIT] Error loading core agents based on preferences: {str(e)}", level=logging.ERROR) + # Fallback to loading all core agents if preference system fails + self._load_default_agents_fallback() + elif not agents: + # If no user_id/db_session provided, load all core agents as fallback + logger.log_message(f"[INIT] No agents provided and no user_id/db_session, loading fallback agents", level=logging.INFO) + self._load_default_agents_fallback() + else: + # Load standard agents from provided list (legacy support) + logger.log_message(f"[INIT] Loading agents from provided list (legacy support)", level=logging.INFO) + for i, a in enumerate(agents): + name = a.__pydantic_core_schema__['schema']['model_name'] + self.agents[name] = dspy.asyncify(dspy.ChainOfThought(a)) + self.agent_inputs[name] = {x.strip() for x in str(agents[i].__pydantic_core_schema__['cls']).split('->')[0].split('(')[1].split(',')} + logger.log_message(f"[INIT] Added legacy agent: {name}, inputs: {self.agent_inputs[name]}", level=logging.DEBUG) + self.agent_desc.append({name: get_agent_description(name)}) + + # Load ALL available template agents if user_id and db_session are provided + # For individual agent execution (@agent_name), users should be able to access any available agent + if user_id and db_session: try: + # For individual use, load ALL available templates regardless of 
user preferences template_signatures = load_all_available_templates_from_db(db_session) + logger.log_message(f"[INIT] Loaded {len(template_signatures)} template signatures from database", level=logging.INFO) + for template_name, signature in template_signatures.items(): + # Skip if this is a core agent - we'll load it separately + if template_name in ['preprocessing_agent', 'statistical_analytics_agent', 'sk_learn_agent', 'data_viz_agent']: + logger.log_message(f"[INIT] Skipping template {template_name} as it's a core agent", level=logging.DEBUG) + continue + # Add template agent to agents dict - self.agents[template_name] = dspy.asyncify(dspy.ChainOfThoughtWithHint(signature)) + self.agents[template_name] = dspy.asyncify(dspy.ChainOfThought(signature)) - # Extract input fields from signature - templates use standard fields - self.agent_inputs[template_name] = {'goal', 'dataset', 'styling_index', 'hint'} + # Determine if this is a visualization agent based on database category + is_viz_agent = False + try: + from src.db.schemas.models import AgentTemplate + + # Find template record to check category + template_record = db_session.query(AgentTemplate).filter( + AgentTemplate.template_name == template_name + ).first() + + if template_record and template_record.category and template_record.category.lower() == 'visualization': + is_viz_agent = True + else: + # Fallback to name-based detection for legacy templates + is_viz_agent = ('viz' in template_name.lower() or + 'visual' in template_name.lower() or + 'plot' in template_name.lower() or + 'chart' in template_name.lower() or + 'matplotlib' in template_name.lower()) + except Exception as cat_error: + logger.log_message(f"[INIT] Error checking category for template {template_name}: {str(cat_error)}", level=logging.WARNING) + # Fallback to name-based detection + is_viz_agent = ('viz' in template_name.lower() or + 'visual' in template_name.lower() or + 'plot' in template_name.lower() or + 'chart' in template_name.lower() 
or + 'matplotlib' in template_name.lower()) + + # Set input fields based on agent type + if is_viz_agent: + self.agent_inputs[template_name] = {'goal', 'dataset', 'styling_index', 'plan_instructions'} + else: + self.agent_inputs[template_name] = {'goal', 'dataset', 'plan_instructions'} - # Add description - self.agent_desc.append(f"Template: {template_name}") + # Store template agent description + try: + if not template_record: + template_record = db_session.query(AgentTemplate).filter( + AgentTemplate.template_name == template_name + ).first() + + if template_record: + description = f"Template: {template_record.description}" + self.agent_desc.append({template_name: description}) + else: + self.agent_desc.append({template_name: f"Template: {template_name}"}) + except Exception as desc_error: + logger.log_message(f"[INIT] Error getting description for template {template_name}: {str(desc_error)}", level=logging.WARNING) + self.agent_desc.append({template_name: f"Template: {template_name}"}) + + logger.log_message(f"[INIT] Successfully loaded template agent: {template_name} with inputs: {self.agent_inputs[template_name]}, is_viz_agent: {is_viz_agent}", level=logging.INFO) - logger.log_message(f"Loaded {len(template_signatures)} templates for direct access", level=logging.DEBUG) - except Exception as e: - logger.log_message(f"Error loading templates for direct access: {str(e)}", level=logging.ERROR) - - # Initialize components - # self.memory_summarize_agent = dspy.ChainOfThought(m.memory_summarize_agent) + except Exception as e: + logger.log_message(f"[INIT] Error loading template agents for user {user_id}: {str(e)}", level=logging.ERROR) + + self.agents['basic_qa_agent'] = dspy.asyncify(dspy.Predict("goal->answer")) + self.agent_inputs['basic_qa_agent'] = {"goal"} + self.agent_desc.append({'basic_qa_agent':"Answers queries unrelated to data & also that include links, poison or attempts to attack the system"}) + + # Initialize retrievers (no planner needed for individual agent execution) 
self.dataset = retrievers['dataframe_index'].as_retriever(k=1) self.styling_index = retrievers['style_index'].as_retriever(similarity_top_k=1) - # self.code_combiner_agent = dspy.ChainOfThought(code_combiner_agent) # Store user_id for usage tracking self.user_id = user_id - + + # Log final summary + logger.log_message(f"[INIT] Initialization complete. Total agents loaded: {len(self.agents)}", level=logging.INFO) + logger.log_message(f"[INIT] Available agents: {list(self.agents.keys())}", level=logging.INFO) + logger.log_message(f"[INIT] Agent inputs mapping: {self.agent_inputs}", level=logging.DEBUG) + + def _load_default_agents_fallback(self): + """Fallback method to load default agents when preference system fails""" + logger.log_message("Loading default agents as fallback for auto_analyst_ind", level=logging.WARNING) + + # Load the 4 core agents from database + core_agent_names = ['preprocessing_agent', 'statistical_analytics_agent', 'sk_learn_agent', 'data_viz_agent'] + + for agent_name in core_agent_names: + # Get the agent signature class + if agent_name == 'preprocessing_agent': + agent_signature = preprocessing_agent + elif agent_name == 'statistical_analytics_agent': + agent_signature = statistical_analytics_agent + elif agent_name == 'sk_learn_agent': + agent_signature = sk_learn_agent + elif agent_name == 'data_viz_agent': + agent_signature = data_viz_agent + + # Add to agents dict + self.agents[agent_name] = dspy.asyncify(dspy.ChainOfThought(agent_signature)) + + # Set input fields based on signature + if agent_name == 'data_viz_agent': + self.agent_inputs[agent_name] = {'goal', 'dataset', 'styling_index', 'plan_instructions'} + else: + self.agent_inputs[agent_name] = {'goal', 'dataset', 'plan_instructions'} + + # Get description from database + self.agent_desc.append({agent_name: get_agent_description(agent_name)}) + logger.log_message(f"Added fallback agent: {agent_name}", level=logging.DEBUG) + async def _track_agent_usage(self, agent_name): """Track 
usage for template agents""" try: @@ -1323,56 +1381,118 @@ async def _track_agent_usage(self, agent_name): async def execute_agent(self, specified_agent, inputs): """Execute agent and generate memory summary in parallel""" try: + logger.log_message(f"[EXECUTE] Starting execution of agent: {specified_agent}", level=logging.INFO) + logger.log_message(f"[EXECUTE] Agent inputs: {inputs}", level=logging.DEBUG) + # Execute main agent agent_result = await self.agents[specified_agent.strip()](**inputs) # Track usage for custom agents and templates await self._track_agent_usage(specified_agent.strip()) + logger.log_message(f"[EXECUTE] Agent {specified_agent} execution completed successfully", level=logging.INFO) return specified_agent.strip(), dict(agent_result) except Exception as e: + logger.log_message(f"[EXECUTE] Error executing agent {specified_agent}: {str(e)}", level=logging.ERROR) + import traceback + logger.log_message(f"[EXECUTE] Full traceback: {traceback.format_exc()}", level=logging.ERROR) return specified_agent.strip(), {"error": str(e)} - async def forward(self, query, specified_agent): + async def forward(self, query, specified_agent): try: + logger.log_message(f"[FORWARD] Processing query with specified agent: {specified_agent}", level=logging.INFO) + logger.log_message(f"[FORWARD] Query: {query}", level=logging.DEBUG) + # If specified_agent contains multiple agents separated by commas # This is for handling multiple @agent mentions in one query if "," in specified_agent: agent_list = [agent.strip() for agent in specified_agent.split(",")] + logger.log_message(f"[FORWARD] Multiple agents detected: {agent_list}", level=logging.INFO) return await self.execute_multiple_agents(query, agent_list) # Process query with specified agent (single agent case) dict_ = {} dict_['dataset'] = self.dataset.retrieve(query)[0].text dict_['styling_index'] = self.styling_index.retrieve(query)[0].text + dict_['hint'] = [] dict_['goal'] = query dict_['Agent_desc'] = 
str(self.agent_desc) - - # Prepare inputs - inputs = {x:dict_[x] for x in self.agent_inputs[specified_agent.strip()]} - inputs['hint'] = str(dict_['hint']).replace('[','').replace(']','') - # Execute agent - result = await self.agents[specified_agent.strip()](**inputs) + logger.log_message(f"[FORWARD] Retrieved context - dataset length: {len(dict_['dataset'])}, styling_index length: {len(dict_['styling_index'])}", level=logging.DEBUG) + + if specified_agent.strip() not in self.agent_inputs: + logger.log_message(f"[FORWARD] ERROR: Agent '{specified_agent.strip()}' not found in agent_inputs", level=logging.ERROR) + logger.log_message(f"[FORWARD] Available agents: {list(self.agent_inputs.keys())}", level=logging.ERROR) + return {"response": f"Agent '{specified_agent.strip()}' not found in agent inputs"} + + # Create inputs that match exactly what the agent expects + inputs = {} + required_fields = self.agent_inputs[specified_agent.strip()] + logger.log_message(f"[FORWARD] Required fields for {specified_agent.strip()}: {required_fields}", level=logging.INFO) + + for field in required_fields: + if field == 'goal': + inputs['goal'] = query + elif field == 'dataset': + inputs['dataset'] = dict_['dataset'] + elif field == 'styling_index': + inputs['styling_index'] = dict_['styling_index'] + elif field == 'plan_instructions': + inputs['plan_instructions'] = "" # Empty for individual agent use + elif field == 'hint': + inputs['hint'] = "" # Empty string for hint + else: + # For any other fields, try to get from dict_ if available + if field in dict_: + inputs[field] = dict_[field] + else: + logger.log_message(f"[FORWARD] WARNING: Field '{field}' required by agent but not available in dict_", level=logging.WARNING) + inputs[field] = "" # Provide empty string as fallback + + logger.log_message(f"[FORWARD] Prepared inputs for {specified_agent.strip()}: {list(inputs.keys())}", level=logging.INFO) + + if specified_agent.strip() not in self.agents: + logger.log_message(f"[FORWARD] 
ERROR: Agent '{specified_agent.strip()}' not found in agents", level=logging.ERROR) + logger.log_message(f"[FORWARD] Available agents: {list(self.agents.keys())}", level=logging.ERROR) + return {"response": f"Agent '{specified_agent.strip()}' not found in agents"} + + logger.log_message(f"[FORWARD] About to execute agent {specified_agent.strip()}", level=logging.INFO) + result = await self.agents[specified_agent.strip()](**inputs) + # Track usage for template agents await self._track_agent_usage(specified_agent.strip()) - output_dict = {specified_agent.strip(): dict(result)} + try: + result_dict = dict(result) + logger.log_message(f"[FORWARD] Agent execution successful, result keys: {list(result_dict.keys())}", level=logging.INFO) + except Exception as dict_error: + logger.log_message(f"[FORWARD] Error converting agent result to dict: {str(dict_error)}", level=logging.ERROR) + return {"response": f"Error converting agent result to dict: {str(dict_error)}"} + + output_dict = {specified_agent.strip(): result_dict} - if "error" in output_dict: - return {"response": f"Error executing agent: {output_dict['error']}"} + # Check for errors in the agent's response (not in the outer dict) + if "error" in result_dict: + logger.log_message(f"[FORWARD] Agent returned error: {result_dict['error']}", level=logging.ERROR) + return {"response": f"Error executing agent: {result_dict['error']}"} + logger.log_message(f"[FORWARD] Successfully processed agent {specified_agent.strip()}", level=logging.INFO) return output_dict except Exception as e: + logger.log_message(f"[FORWARD] Exception in auto_analyst_ind.forward: {str(e)}", level=logging.ERROR) + import traceback + logger.log_message(f"[FORWARD] Full traceback: {traceback.format_exc()}", level=logging.ERROR) return {"response": f"This is the error from the system: {str(e)}"} async def execute_multiple_agents(self, query, agent_list): """Execute multiple agents sequentially on the same query""" try: + logger.log_message(f"[MULTI] 
Executing multiple agents: {agent_list}", level=logging.INFO) + # Initialize resources dict_ = {} dict_['dataset'] = self.dataset.retrieve(query)[0].text @@ -1386,29 +1506,63 @@ async def execute_multiple_agents(self, query, agent_list): # Execute each agent sequentially for agent_name in agent_list: + logger.log_message(f"[MULTI] Processing agent: {agent_name}", level=logging.INFO) + if agent_name not in self.agents: + logger.log_message(f"[MULTI] Agent '{agent_name}' not found", level=logging.ERROR) results[agent_name] = {"error": f"Agent '{agent_name}' not found"} continue - # Prepare inputs for this agent - inputs = {x:dict_[x] for x in self.agent_inputs[agent_name] if x in dict_} - inputs['hint'] = str(dict_['hint']).replace('[','').replace(']','') + # Create inputs that match exactly what the agent expects + inputs = {} + required_fields = self.agent_inputs[agent_name] - # Execute agent - agent_result = await self.agents[agent_name](**inputs) - agent_dict = dict(agent_result) - results[agent_name] = agent_dict + logger.log_message(f"[MULTI] Required fields for {agent_name}: {required_fields}", level=logging.DEBUG) - # Track usage for template agents - await self._track_agent_usage(agent_name) + for field in required_fields: + if field == 'goal': + inputs['goal'] = query + elif field == 'dataset': + inputs['dataset'] = dict_['dataset'] + elif field == 'styling_index': + inputs['styling_index'] = dict_['styling_index'] + elif field == 'plan_instructions': + inputs['plan_instructions'] = "" # Empty for individual agent use + elif field == 'hint': + inputs['hint'] = "" # Empty string for hint + else: + # For any other fields, try to get from dict_ if available + if field in dict_: + inputs[field] = dict_[field] + else: + logger.log_message(f"[MULTI] WARNING: Field '{field}' required by agent but not available in dict_", level=logging.WARNING) + + logger.log_message(f"[MULTI] Prepared inputs for {agent_name}: {list(inputs.keys())}", level=logging.DEBUG) - # 
Collect code for later combination - if 'code' in agent_dict: - code_list.append(agent_dict['code']) + # Execute agent + try: + agent_result = await self.agents[agent_name](**inputs) + agent_dict = dict(agent_result) + results[agent_name] = agent_dict + + # Track usage for template agents + await self._track_agent_usage(agent_name) + + # Collect code for later combination + if 'code' in agent_dict: + code_list.append(agent_dict['code']) + + logger.log_message(f"[MULTI] Successfully executed agent: {agent_name}", level=logging.INFO) + + except Exception as agent_error: + logger.log_message(f"[MULTI] Error executing agent {agent_name}: {str(agent_error)}", level=logging.ERROR) + results[agent_name] = {"error": str(agent_error)} + logger.log_message(f"[MULTI] Completed multiple agent execution. Results: {list(results.keys())}", level=logging.INFO) return results except Exception as e: + logger.log_message(f"[MULTI] Error executing multiple agents: {str(e)}", level=logging.ERROR) return {"response": f"Error executing multiple agents: {str(e)}"} @@ -1422,34 +1576,63 @@ def __init__(self, agents, retrievers, user_id=None, db_session=None): self.agent_inputs = {} self.agent_desc = [] - # Load standard agents - for i, a in enumerate(agents): - name = a.__pydantic_core_schema__['schema']['model_name'] - self.agents[name] = dspy.asyncify(dspy.ChainOfThought(a)) - self.agent_inputs[name] = {x.strip() for x in str(agents[i].__pydantic_core_schema__['cls']).split('->')[0].split('(')[1].split(',')} - self.agent_desc.append({name: get_agent_description(name)}) - # Load user-enabled template agents if user_id and db_session are provided + logger.log_message(f"Loading user-enabled template agents for user {user_id}", level=logging.INFO) if user_id and db_session: try: + # For planner use, load planner-enabled templates (max 10, prioritized by usage) template_signatures = load_user_enabled_templates_for_planner_from_db(user_id, db_session) + logger.log_message(f"Loaded 
{len(template_signatures)} templates for planner use", level=logging.INFO) for template_name, signature in template_signatures.items(): + # Skip if this is a core agent - we'll load it separately + if template_name in ['preprocessing_agent', 'statistical_analytics_agent', 'sk_learn_agent', 'data_viz_agent']: + continue + # Add template agent to agents dict self.agents[template_name] = dspy.asyncify(dspy.ChainOfThought(signature)) - # Extract input fields from signature - templates use standard fields like data_viz_agent - self.agent_inputs[template_name] = {'goal', 'dataset', 'styling_index'} - - # Store template agent description + # Determine if this is a visualization agent based on database category + is_viz_agent = False try: from src.db.schemas.models import AgentTemplate - # Find template record + # Find template record to check category template_record = db_session.query(AgentTemplate).filter( AgentTemplate.template_name == template_name ).first() + if template_record and template_record.category and template_record.category.lower() == 'visualization': + is_viz_agent = True + else: + # Fallback to name-based detection for legacy templates + is_viz_agent = ('viz' in template_name.lower() or + 'visual' in template_name.lower() or + 'plot' in template_name.lower() or + 'chart' in template_name.lower() or + 'matplotlib' in template_name.lower()) + except Exception as cat_error: + logger.log_message(f"Error checking category for template {template_name}: {str(cat_error)}", level=logging.WARNING) + # Fallback to name-based detection + is_viz_agent = ('viz' in template_name.lower() or + 'visual' in template_name.lower() or + 'plot' in template_name.lower() or + 'chart' in template_name.lower() or + 'matplotlib' in template_name.lower()) + + # Set input fields based on agent type + if is_viz_agent: + self.agent_inputs[template_name] = {'goal', 'dataset', 'styling_index', 'plan_instructions'} + else: + self.agent_inputs[template_name] = {'goal', 'dataset', 
'plan_instructions'} + + # Store template agent description + try: + if not template_record: + template_record = db_session.query(AgentTemplate).filter( + AgentTemplate.template_name == template_name + ).first() + if template_record: description = f"Template: {template_record.description}" self.agent_desc.append({template_name: description}) @@ -1458,23 +1641,87 @@ def __init__(self, agents, retrievers, user_id=None, db_session=None): except Exception as desc_error: logger.log_message(f"Error getting description for template {template_name}: {str(desc_error)}", level=logging.WARNING) self.agent_desc.append({template_name: f"Template: {template_name}"}) - - logger.log_message(f"Loaded {len(template_signatures)} enabled templates for planner", level=logging.DEBUG) - + except Exception as e: logger.log_message(f"Error loading template agents for user {user_id}: {str(e)}", level=logging.ERROR) + # Load core agents based on user preferences (not always loaded) + if not agents and user_id and db_session: + try: + # Get user preferences for core agents + from src.db.schemas.models import AgentTemplate, UserTemplatePreference + + core_agent_names = ['preprocessing_agent', 'statistical_analytics_agent', 'sk_learn_agent', 'data_viz_agent'] + + for agent_name in core_agent_names: + # Check if user has enabled this core agent + template = db_session.query(AgentTemplate).filter( + AgentTemplate.template_name == agent_name, + AgentTemplate.is_active == True + ).first() + + if not template: + logger.log_message(f"Core agent template '{agent_name}' not found in database", level=logging.WARNING) + continue + + # Check user preference + preference = db_session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == user_id, + UserTemplatePreference.template_id == template.template_id + ).first() + + # Core agents are enabled by default unless explicitly disabled + is_enabled = preference.is_enabled if preference else True + + if not is_enabled: + continue + + # Get 
the agent signature class + if agent_name == 'preprocessing_agent': + agent_signature = preprocessing_agent + elif agent_name == 'statistical_analytics_agent': + agent_signature = statistical_analytics_agent + elif agent_name == 'sk_learn_agent': + agent_signature = sk_learn_agent + elif agent_name == 'data_viz_agent': + agent_signature = data_viz_agent + + # Add to agents dict + self.agents[agent_name] = dspy.asyncify(dspy.ChainOfThought(agent_signature)) + + # Set input fields based on signature + if agent_name == 'data_viz_agent': + self.agent_inputs[agent_name] = {'goal', 'dataset', 'styling_index', 'plan_instructions'} + else: + self.agent_inputs[agent_name] = {'goal', 'dataset', 'plan_instructions'} + + # Get description from database + self.agent_desc.append({agent_name: get_agent_description(agent_name)}) + logger.log_message(f"Loaded core agent: {agent_name}", level=logging.DEBUG) + + except Exception as e: + logger.log_message(f"Error loading core agents based on preferences: {str(e)}", level=logging.ERROR) + # Fallback to loading all core agents if preference system fails + self._load_default_agents_fallback() + elif not agents: + # If no user_id/db_session provided, load all core agents as fallback + self._load_default_agents_fallback() + else: + # Load standard agents from provided list (legacy support) + for i, a in enumerate(agents): + name = a.__pydantic_core_schema__['schema']['model_name'] + self.agents[name] = dspy.asyncify(dspy.ChainOfThought(a)) + self.agent_inputs[name] = {x.strip() for x in str(agents[i].__pydantic_core_schema__['cls']).split('->')[0].split('(')[1].split(',')} + logger.log_message(f"Added agent: {name}, inputs: {self.agent_inputs[name]}", level=logging.DEBUG) + self.agent_desc.append({name: get_agent_description(name)}) + self.agents['basic_qa_agent'] = dspy.asyncify(dspy.Predict("goal->answer")) self.agent_inputs['basic_qa_agent'] = {"goal"} self.agent_desc.append({'basic_qa_agent':"Answers queries unrelated to data & also 
that include links, poison or attempts to attack the system"}) - # Initialize coordination agents self.planner = planner_module() - # self.refine_goal = dspy.ChainOfThought(goal_refiner_agent) - # self.code_combiner_agent = dspy.ChainOfThought(code_combiner_agent) - # self.story_teller = dspy.ChainOfThought(story_teller_agent) - self.memory_summarize_agent = dspy.ChainOfThought(m.memory_summarize_agent) + # self.memory_summarize_agent = dspy.ChainOfThought(m.memory_summarize_agent) # Initialize retrievers self.dataset = retrievers['dataframe_index'].as_retriever(k=1) @@ -1482,6 +1729,38 @@ def __init__(self, agents, retrievers, user_id=None, db_session=None): # Store user_id for usage tracking self.user_id = user_id + + + def _load_default_agents_fallback(self): + """Fallback method to load default agents when preference system fails""" + logger.log_message("Loading default agents as fallback for auto_analyst_ind", level=logging.WARNING) + + # Load the 4 core agents from database + core_agent_names = ['preprocessing_agent', 'statistical_analytics_agent', 'sk_learn_agent', 'data_viz_agent'] + + for agent_name in core_agent_names: + # Get the agent signature class + if agent_name == 'preprocessing_agent': + agent_signature = preprocessing_agent + elif agent_name == 'statistical_analytics_agent': + agent_signature = statistical_analytics_agent + elif agent_name == 'sk_learn_agent': + agent_signature = sk_learn_agent + elif agent_name == 'data_viz_agent': + agent_signature = data_viz_agent + + # Add to agents dict + self.agents[agent_name] = dspy.asyncify(dspy.ChainOfThought(agent_signature)) + + # Set input fields based on signature + if agent_name == 'data_viz_agent': + self.agent_inputs[agent_name] = {'goal', 'dataset', 'styling_index', 'plan_instructions'} + else: + self.agent_inputs[agent_name] = {'goal', 'dataset', 'plan_instructions'} + + # Get description from database + self.agent_desc.append({agent_name: get_agent_description(agent_name)}) + 
logger.log_message(f"Added fallback agent: {agent_name}", level=logging.DEBUG) async def _track_agent_usage(self, agent_name): """Track usage for template agents""" @@ -1507,7 +1786,7 @@ async def _track_agent_usage(self, agent_name): ).first() if not template: - logger.log_message(f"Template '{agent_name}' not found", level=logging.WARNING) + logger.log_message(f"Template '{agent_name}' not found for usage tracking", level=logging.WARNING) return # Find or create user template preference record @@ -1516,35 +1795,33 @@ async def _track_agent_usage(self, agent_name): UserTemplatePreference.template_id == template.template_id ).first() - if preference: - # Update existing preference - preference.usage_count += 1 - preference.last_used_at = datetime.now(UTC) - preference.updated_at = datetime.now(UTC) - else: - # Create new preference record when template is used directly (via @mention) - # Direct usage doesn't auto-enable for planner but tracks usage + if not preference: + # Create new preference record (disabled by default) preference = UserTemplatePreference( user_id=self.user_id, template_id=template.template_id, - is_enabled=False, # Default disabled for planner - usage_count=1, - last_used_at=datetime.now(UTC), + is_enabled=False, # Disabled by default + usage_count=0, + last_used_at=None, created_at=datetime.now(UTC), updated_at=datetime.now(UTC) ) session.add(preference) + # Update usage tracking + preference.usage_count += 1 + preference.last_used_at = datetime.now(UTC) + preference.updated_at = datetime.now(UTC) session.commit() logger.log_message( - f"Tracked usage for template '{agent_name}' for user {self.user_id} (count: {preference.usage_count})", + f"Tracked usage for template '{agent_name}' (count: {preference.usage_count})", level=logging.DEBUG ) except Exception as e: session.rollback() - logger.log_message(f"Error tracking template usage for {agent_name}: {str(e)}", level=logging.ERROR) + logger.log_message(f"Error tracking usage for template 
{agent_name}: {str(e)}", level=logging.ERROR) finally: session.close() @@ -1553,14 +1830,17 @@ async def _track_agent_usage(self, agent_name): async def execute_agent(self, agent_name, inputs): """Execute a single agent with given inputs""" + try: result = await self.agents[agent_name.strip()](**inputs) # Track usage for custom agents and templates await self._track_agent_usage(agent_name.strip()) + logger.log_message(f"Agent {agent_name} execution completed", level=logging.DEBUG) return agent_name.strip(), dict(result) except Exception as e: + logger.log_message(f"Error in execute_agent for {agent_name}: {str(e)}", level=logging.ERROR) return agent_name.strip(), {"error": str(e)} async def get_plan(self, query): @@ -1571,18 +1851,26 @@ async def get_plan(self, query): dict_['goal'] = query dict_['Agent_desc'] = str(self.agent_desc) - module_return = await self.planner(goal=dict_['goal'], dataset=dict_['dataset'], Agent_desc=dict_['Agent_desc']) - plan_dict = dict(module_return['plan']) - if 'complexity' in module_return: - complexity = module_return['complexity'] - else: - complexity = 'basic' - plan_dict['complexity'] = complexity + + try: + module_return = await self.planner(goal=dict_['goal'], dataset=dict_['dataset'], Agent_desc=dict_['Agent_desc']) + + plan_dict = dict(module_return['plan']) + if 'complexity' in module_return: + complexity = module_return['complexity'] + else: + complexity = 'basic' + plan_dict['complexity'] = complexity - return plan_dict + return plan_dict + + except Exception as e: + logger.log_message(f"Error in get_plan: {str(e)}", level=logging.ERROR) + raise async def execute_plan(self, query, plan): """Execute the plan and yield results as they complete""" + dict_ = {} dict_['dataset'] = self.dataset.retrieve(query)[0].text dict_['styling_index'] = self.styling_index.retrieve(query)[0].text @@ -1593,91 +1881,58 @@ async def execute_plan(self, query, plan): # Clean and split the plan string into agent names plan_text = plan.get("plan", 
"").replace("Plan", "").replace(":", "").strip() - if "basic_qa_agent" in plan_text: inputs = dict(goal=query) - agent_name, response = await self.execute_agent('basic_qa_agent', inputs) #! SHOULDN'T THIS BE **inputs ? + agent_name, response = await self.execute_agent('basic_qa_agent', inputs) yield agent_name, inputs, response return plan_list = [agent.strip() for agent in plan_text.split("->") if agent.strip()] - + logger.log_message(f"Plan list: {plan_list}", level=logging.INFO) # Parse the attached plan_instructions into a dict raw_instr = plan.get("plan_instructions", {}) if isinstance(raw_instr, str): try: plan_instructions = json.loads(raw_instr) - except Exception: + except Exception as e: + logger.log_message(f"Error parsing plan_instructions JSON: {str(e)}", level=logging.ERROR) plan_instructions = {} elif isinstance(raw_instr, dict): - plan_instructions = str(raw_instr) + plan_instructions = raw_instr else: plan_instructions = {} - # If no plan was produced, short-circuit - if not plan_list: - yield "plan_not_found", dict(plan), {"error": "No plan found"} - return - - # Create async tasks for each agent, similar to deep analysis approach - tasks = [] - task_info = [] + # Check if we have no valid agents to execute + if not plan_list or all(agent not in self.agents for agent in plan_list): + yield "plan_not_found", None, {"error": "No valid agents found in plan"} + return - for idx, agent_name in enumerate(plan_list): - key = agent_name.strip() - # gather input fields except plan_instructions - inputs = { - param: dict_[param] - for param in self.agent_inputs[key] - if param != "plan_instructions" - } - - # attach the specific instructions for this agent with prev/next format - if "plan_instructions" in self.agent_inputs[key]: - # Get current agent instructions - current_instructions = plan_instructions.get(key, {"create": [], "use": [], "instruction": ""}) - - # Format instructions with your_task first - formatted_instructions = {"your_task": 
current_instructions} + # Execute agents in sequence + for agent_name in plan_list: + if agent_name not in self.agents: + yield agent_name, {}, {"error": f"Agent '{agent_name}' not available"} + continue + + try: + # Prepare inputs for the agent + inputs = {x: dict_[x] for x in self.agent_inputs[agent_name] if x in dict_} - # Add previous agent instructions if available - if idx > 0: - prev_agent = plan_list[idx-1].strip() - prev_instructions = plan_instructions.get(prev_agent, {}).get("instruction", "") - formatted_instructions[f"Previous Agent {prev_agent}"] = prev_instructions + # Add plan instructions if available for this agent + if agent_name in plan_instructions: + inputs['plan_instructions'] = json.dumps(plan_instructions[agent_name]) + else: + inputs['plan_instructions'] = "" - # Add next agent instructions if available - if idx < len(plan_list) - 1: - next_agent = plan_list[idx+1].strip() - next_instructions = plan_instructions.get(next_agent, {}).get("instruction", "") - formatted_instructions[f"Next Agent {next_agent}"] = next_instructions + # logger.log_message(f"Agent inputs for {agent_name}: {inputs}", level=logging.INFO) - inputs["plan_instructions"] = str(formatted_instructions) - - - # Create async task directly from the asyncified agent - task = self.execute_agent(agent_name, inputs) - tasks.append(task) - task_info.append((agent_name, inputs)) - - # Execute all tasks concurrently and yield results as they complete - try: - # Execute all tasks concurrently - results = await asyncio.gather(*tasks, return_exceptions=True) - - # Yield results with their corresponding task info - for i, result in enumerate(results): - agent_name, inputs = task_info[i] + # Execute the agent + agent_result_name, response = await self.execute_agent(agent_name, inputs) - if isinstance(result, Exception): - yield agent_name, inputs, {"error": str(result)} - else: - name, response = result - yield name, inputs, response + yield agent_result_name, inputs, response - except 
Exception as e: - logger.log_message(f"Error in task execution: {str(e)}", level=logging.ERROR) - yield "error", {}, {"error": str(e)} + except Exception as e: + logger.log_message(f"Error executing agent {agent_name}: {str(e)}", level=logging.ERROR) + yield agent_name, {}, {"error": f"Error executing {agent_name}: {str(e)}"} diff --git a/auto-analyst-backend/src/agents/deep_agents.py b/auto-analyst-backend/src/agents/deep_agents.py index 7c62e91e..d1f752b9 100644 --- a/auto-analyst-backend/src/agents/deep_agents.py +++ b/auto-analyst-backend/src/agents/deep_agents.py @@ -728,6 +728,8 @@ class deep_code_fix(dspy.Signature): class deep_analysis_module(dspy.Module): def __init__(self,agents, agents_desc): + logger.log_message(f"Initializing deep_analysis_module with {len(agents)} agents: {list(agents.keys())}", level=logging.INFO) + self.agents = agents # Make all dspy operations async using asyncify self.deep_questions = dspy.asyncify(dspy.Predict(deep_questions)) @@ -741,6 +743,8 @@ def __init__(self,agents, agents_desc): self.styling_instructions = chart_instructions self.agents_desc = agents_desc self.final_conclusion = dspy.asyncify(dspy.ChainOfThought(final_conclusion)) + + logger.log_message(f"Deep analysis module initialized successfully with agents: {list(self.agents.keys())}", level=logging.INFO) async def execute_deep_analysis_streaming(self, goal, dataset_info, session_df=None): """ @@ -795,7 +799,8 @@ async def execute_deep_analysis_streaming(self, goal, dataset_info, session_df=N if not all(key in self.agents for key in keys): raise ValueError(f"Invalid agent key(s) in plan instructions. 
Available agents: {list(self.agents.keys())}") - + logger.log_message(f"Plan instructions: {plan_instructions}", logging.INFO) + logger.log_message(f"Keys: {keys}", logging.INFO) except (ValueError, SyntaxError, json.JSONDecodeError) as e: try: deep_plan = await self.deep_plan_fixer(plan_instructions=deep_plan.plan_instructions) @@ -803,6 +808,8 @@ async def execute_deep_analysis_streaming(self, goal, dataset_info, session_df=N if not isinstance(plan_instructions, dict): plan_instructions = json.loads(deep_plan.fixed_plan) keys = [key for key in plan_instructions.keys()] + logger.log_message(f"Plan instructions fixed: {plan_instructions}", logging.INFO) + logger.log_message(f"Keys: {keys}", logging.INFO) except (ValueError, SyntaxError, json.JSONDecodeError) as e: logger.log_message(f"Error parsing plan instructions: {e}", logging.ERROR) raise e @@ -828,19 +835,18 @@ async def execute_deep_analysis_streaming(self, goal, dataset_info, session_df=N dspy.Example( goal=questions.deep_questions, dataset=dataset_info, - **({"plan_instructions": str(plan_instructions[key])} if "planner" in key else {}), - **({"styling_index": "Sample styling guidelines"} if "data_viz" in key else {}) + plan_instructions=str(plan_instructions[key]), + **({"styling_index": "Sample styling guidelines"} if "data_viz" in key or "viz" in key.lower() or "visual" in key.lower() or "plot" in key.lower() or "chart" in key.lower() else {}) ).with_inputs( "goal", - "dataset", - *(["plan_instructions"] if "planner" in key else []), - *(["styling_index"] if "data_viz" in key else []) + "dataset", + "plan_instructions", + *(["styling_index"] if "data_viz" in key or "viz" in key.lower() or "visual" in key.lower() or "plot" in key.lower() or "chart" in key.lower() else []) ) for key in keys ] - tasks = [self.agents[key](**q) for q, key in zip(queries, keys)] - + # Await all tasks to complete summaries = [] codes = [] diff --git a/auto-analyst-backend/src/db/init_db.py 
# Determine database type and set appropriate engine configurations
if DATABASE_URL.startswith('postgresql'):
    # PostgreSQL-specific configuration.
    # NOTE(review): an interactive `input("Are you sure?")` gate was removed
    # here — it blocked non-interactive startup (servers, containers, CI) and
    # left `engine`/`is_postgresql` undefined unless the operator typed "yes",
    # causing a NameError on first use of the engine.
    engine = create_engine(
        DATABASE_URL,
        pool_size=10,
        max_overflow=20,
        pool_pre_ping=True,  # Check connection validity before use
        pool_recycle=300     # Recycle connections after 5 minutes
    )
    is_postgresql = True
    logger.log_message("Using PostgreSQL database engine", logging.INFO)
else:
    # SQLite configuration (default engine settings)
    engine = create_engine(DATABASE_URL)
+ + Returns: + Tuple (success: bool, message: str) + """ + try: + from src.db.schemas.models import AgentTemplate + + # Define default agents with their signatures and metadata + default_agents = { + "preprocessing_agent": { + "display_name": "Data Preprocessing Agent", + "description": "Cleans and prepares a DataFrame using Pandas and NumPy—handles missing values, detects column types, and converts date strings to datetime.", + "prompt_template": """You are an AI data-preprocessing agent. Generate clean and efficient Python code using NumPy and Pandas to perform introductory data preprocessing on a pre-loaded DataFrame df, based on the user's analysis goals. +Preprocessing Requirements: +1. Identify Column Types +- Separate columns into numeric and categorical using: + categorical_columns = df.select_dtypes(include=[object, 'category']).columns.tolist() + numeric_columns = df.select_dtypes(include=[np.number]).columns.tolist() +2. Handle Missing Values +- Numeric columns: Impute missing values using the mean of each column +- Categorical columns: Impute missing values using the mode of each column +3. Convert Date Strings to Datetime +- For any column suspected to represent dates (in string format), convert it to datetime using: + def safe_to_datetime(date): + try: + return pd.to_datetime(date, errors='coerce', cache=False) + except (ValueError, TypeError): + return pd.NaT + df['datetime_column'] = df['datetime_column'].apply(safe_to_datetime) +- Replace 'datetime_column' with the actual column names containing date-like strings +Important Notes: +- Do NOT create a correlation matrix — correlation analysis is outside the scope of preprocessing +- Do NOT generate any plots or visualizations +Output Instructions: +1. Include the full preprocessing Python code +2. Provide a brief bullet-point summary of the steps performed. 
Example: +β€’ Identified 5 numeric and 4 categorical columns +β€’ Filled missing numeric values with column means +β€’ Filled missing categorical values with column modes +β€’ Converted 1 date column to datetime format + Respond in the user's language for all summary and reasoning but keep the code in english""", + "category": "Data Manipulation", + "icon_url": "https://cdn.jsdelivr.net/gh/devicons/devicon/icons/pandas/pandas-original.svg" + }, + "statistical_analytics_agent": { + "display_name": "Statistical Analytics Agent", + "description": "Performs statistical analysis (e.g., regression, seasonal decomposition) using statsmodels, with proper handling of categorical data and missing values.", + "prompt_template": """ +You are a statistical analytics agent. Your task is to take a dataset and a user-defined goal and output Python code that performs the appropriate statistical analysis to achieve that goal. Follow these guidelines: +IMPORTANT: You may be provided with previous interaction history. The section marked "### Current Query:" contains the user's current request. Any text in "### Previous Interaction History:" is for context only and is NOT part of the current request. +Data Handling: +Always handle strings as categorical variables in a regression using statsmodels C(string_column). +Do not change the index of the DataFrame. +Convert X and y into float when fitting a model. +Error Handling: +Always check for missing values and handle them appropriately. +Ensure that categorical variables are correctly processed. +Provide clear error messages if the model fitting fails. +Regression: +For regression, use statsmodels and ensure that a constant term is added to the predictor using sm.add_constant(X). +Handle categorical variables using C(column_name) in the model formula. +Fit the model with model = sm.OLS(y.astype(float), X.astype(float)).fit(). +Seasonal Decomposition: +Ensure the period is set correctly when performing seasonal decomposition. 
+Verify the number of observations works for the decomposition. +Output: +Ensure the code is executable and as intended. +Also choose the correct type of model for the problem +Avoid adding data visualization code. +Use code like this to prevent failing: +import pandas as pd +import numpy as np +import statsmodels.api as sm +def statistical_model(X, y, goal, period=None): + try: + # Check for missing values and handle them + X = X.dropna() + y = y.loc[X.index].dropna() + # Ensure X and y are aligned + X = X.loc[y.index] + # Convert categorical variables + for col in X.select_dtypes(include=['object', 'category']).columns: + X[col] = X[col].astype('category') + # Add a constant term to the predictor + X = sm.add_constant(X) + # Fit the model + if goal == 'regression': + # Handle categorical variables in the model formula + formula = 'y ~ ' + ' + '.join([f'C({col})' if X[col].dtype.name == 'category' else col for col in X.columns]) + model = sm.OLS(y.astype(float), X.astype(float)).fit() + return model.summary() + elif goal == 'seasonal_decompose': + if period is None: + raise ValueError("Period must be specified for seasonal decomposition") + decomposition = sm.tsa.seasonal_decompose(y, period=period) + return decomposition + else: + raise ValueError("Unknown goal specified. Please provide a valid goal.") + except Exception as e: + return f"An error occurred: {e}" +# Example usage: +result = statistical_model(X, y, goal='regression') +print(result) +If visualizing use plotly +Provide a concise bullet-point summary of the statistical analysis performed. 
+ +Example Summary: +β€’ Applied linear regression with OLS to predict house prices based on 5 features +β€’ Model achieved R-squared of 0.78 +β€’ Significant predictors include square footage (p<0.001) and number of bathrooms (p<0.01) +β€’ Detected strong seasonal pattern with 12-month periodicity +β€’ Forecast shows 15% growth trend over next quarter +Respond in the user's language for all summary and reasoning but keep the code in english""", + "category": "Statistical Analysis", + "icon_url": "https://cdn.jsdelivr.net/gh/devicons/devicon/icons/statsmodels/statsmodels-original.svg" + }, + "sk_learn_agent": { + "display_name": "Machine Learning Agent", + "description": "Trains and evaluates machine learning models using scikit-learn, including classification, regression, and clustering with feature importance insights.", + "prompt_template": """You are a machine learning agent. +Your task is to take a dataset and a user-defined goal, and output Python code that performs the appropriate machine learning analysis to achieve that goal. +You should use the scikit-learn library. +IMPORTANT: You may be provided with previous interaction history. The section marked "### Current Query:" contains the user's current request. Any text in "### Previous Interaction History:" is for context only and is NOT part of the current request. +Make sure your output is as intended! +Provide a concise bullet-point summary of the machine learning operations performed. 
+ +Example Summary: +β€’ Trained a Random Forest classifier on customer churn data with 80/20 train-test split +β€’ Model achieved 92% accuracy and 88% F1-score +β€’ Feature importance analysis revealed that contract length and monthly charges are the strongest predictors of churn +β€’ Implemented K-means clustering (k=4) on customer shopping behaviors +β€’ Identified distinct segments: high-value frequent shoppers (22%), occasional big spenders (35%), budget-conscious regulars (28%), and rare visitors (15%) +Respond in the user's language for all summary and reasoning but keep the code in english""", + "category": "Modelling", + "icon_url": "https://cdn.jsdelivr.net/gh/devicons/devicon/icons/scikit-learn/scikit-learn-original.svg" + }, + "data_viz_agent": { + "display_name": "Data Visualization Agent", + "description": "Generates interactive visualizations with Plotly, selecting the best chart type to reveal trends, comparisons, and insights based on the analysis goal.", + "prompt_template": """ +You are an AI agent responsible for generating interactive data visualizations using Plotly. +IMPORTANT Instructions: +- The section marked "### Current Query:" contains the user's request. Any text in "### Previous Interaction History:" is for context only and should NOT be treated as part of the current request. +- You must only use the tools provided to you. This agent handles visualization only. +- If len(df) > 50000, always sample the dataset before visualization using: +if len(df) > 50000: + df = df.sample(50000, random_state=1) +- Each visualization must be generated as a **separate figure** using go.Figure(). +Do NOT use subplots under any circumstances. +- Each figure must be returned individually using: +fig.to_html(full_html=False) +- Use update_layout with xaxis and yaxis **only once per figure**. 
+- Enhance readability and clarity by: +β€’ Using low opacity (0.4-0.7) where appropriate +β€’ Applying visually distinct colors for different elements or categories +- Make sure the visual **answers the user's specific goal**: +β€’ Identify what insight or comparison the user is trying to achieve +β€’ Choose the visualization type and features (e.g., color, size, grouping) to emphasize that goal +β€’ For example, if the user asks for "trends in revenue," use a time series line chart; if they ask for "top-performing categories," use a bar chart sorted by value +β€’ Prioritize highlighting patterns, outliers, or comparisons relevant to the question +- Never include the dataset or styling index in the output. +- If there are no relevant columns for the requested visualization, respond with: +"No relevant columns found to generate this visualization." +- Use only one number format consistently: either 'K', 'M', or comma-separated values like 1,000/1,000,000. Do not mix formats. +- Only include trendlines in scatter plots if the user explicitly asks for them. +- Output only the code and a concise bullet-point summary of what the visualization reveals. +- Always end each visualization with: +fig.to_html(full_html=False) +Respond in the user's language for all summary and reasoning but keep the code in english +Example Summary: +β€’ Created an interactive scatter plot of sales vs. 
marketing spend with color-coded product categories +β€’ Included a trend line showing positive correlation (r=0.72) +β€’ Highlighted outliers where high marketing spend resulted in low sales +β€’ Generated a time series chart of monthly revenue from 2020-2023 +β€’ Added annotations for key business events +β€’ Visualization reveals 35% YoY growth with seasonal peaks in Q4""", + "category": "Visualization", + "icon_url": "https://cdn.jsdelivr.net/gh/devicons/devicon/icons/plotly/plotly-original.svg" + } + } + + created_count = 0 + updated_count = 0 + + for template_name, agent_data in default_agents.items(): + # Check if agent already exists + existing_agent = db_session.query(AgentTemplate).filter( + AgentTemplate.template_name == template_name + ).first() + + if existing_agent: + if force_update: + # Update existing agent + existing_agent.display_name = agent_data["display_name"] + existing_agent.description = agent_data["description"] + existing_agent.prompt_template = agent_data["prompt_template"] + existing_agent.category = agent_data["category"] + existing_agent.icon_url = agent_data["icon_url"] + existing_agent.is_premium_only = False + existing_agent.is_active = True + existing_agent.updated_at = datetime.now(UTC) + updated_count += 1 + else: + logger.log_message(f"Agent '{template_name}' already exists, skipping", level=logging.INFO) + continue + else: + # Create new agent + new_agent = AgentTemplate( + template_name=template_name, + display_name=agent_data["display_name"], + description=agent_data["description"], + prompt_template=agent_data["prompt_template"], + category=agent_data["category"], + icon_url=agent_data["icon_url"], + is_premium_only=False, + is_active=True, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC) + ) + db_session.add(new_agent) + created_count += 1 + + db_session.commit() + + message = f"Successfully loaded default agents. 
Created: {created_count}, Updated: {updated_count}" + logger.log_message(message, level=logging.INFO) + return True, message + + except Exception as e: + db_session.rollback() + error_msg = f"Error loading default agents: {str(e)}" + logger.log_message(error_msg, level=logging.ERROR) + return False, error_msg + +def initialize_default_agents(force_update=False): + """ + Initialize default agents during application startup. + + Args: + force_update: If True, update existing agents. If False, skip existing ones. + + Returns: + bool: True if successful, False otherwise + """ + try: + from src.db.init_db import session_factory + + session = session_factory() + try: + success, message = load_default_agents_to_db(session, force_update=force_update) + logger.log_message(f"Default agents initialization: {message}", level=logging.INFO) + return success + finally: + session.close() + + except Exception as e: + logger.log_message(f"Failed to initialize default agents: {str(e)}", level=logging.ERROR) + return False + +if __name__ == "__main__": + initialize_default_agents(force_update=True) \ No newline at end of file diff --git a/auto-analyst-backend/src/managers/session_manager.py b/auto-analyst-backend/src/managers/session_manager.py index 013801c5..9ccbd292 100644 --- a/auto-analyst-backend/src/managers/session_manager.py +++ b/auto-analyst-backend/src/managers/session_manager.py @@ -27,19 +27,21 @@ class SessionManager: def __init__(self, styling_instructions: List[str], available_agents: Dict): """ - Initialize session manager with styling instructions and agents + Initialize SessionManager with styling instructions and available agents Args: - styling_instructions: List of styling instructions - available_agents: Dictionary of available agents + styling_instructions: List of styling instructions for visualization + available_agents: Dictionary of available agents (deprecated - agents now loaded from DB) """ + self.styling_instructions = styling_instructions 
self._sessions = {} self._default_df = None self._default_retrievers = None self._default_ai_system = None - self._dataset_description = None self._make_data = None - self._default_name = "Housing Dataset" # Default dataset name + # Initialize chat manager + self._dataset_description = "Housing Dataset" + self._default_name = "Housing.csv" self._dataset_description = """This dataset contains residential property information with details about pricing, physical characteristics, and amenities. The data can be used for real estate market analysis, property valuation, and understanding the relationship between house features and prices. @@ -92,8 +94,8 @@ def initialize_default_dataset(self): self._default_df = pd.read_csv("Housing.csv") self._make_data = make_data(self._default_df, self._dataset_description) self._default_retrievers = self.initialize_retrievers(self.styling_instructions, [str(self._make_data)]) - self._default_ai_system = auto_analyst(agents=list(self.available_agents.values()), - retrievers=self._default_retrievers) + # Create default AI system - agents will be loaded from database + self._default_ai_system = auto_analyst(agents=[], retrievers=self._default_retrievers) except Exception as e: logger.log_message(f"Error initializing default dataset: {str(e)}", level=logging.ERROR) raise e @@ -311,7 +313,7 @@ def create_ai_system_for_user(self, retrievers, user_id=None): try: # Create AI system with user context to load custom agents ai_system = auto_analyst( - agents=list(self.available_agents.values()), + agents=[], retrievers=retrievers, user_id=user_id, db_session=db_session @@ -322,12 +324,12 @@ def create_ai_system_for_user(self, retrievers, user_id=None): db_session.close() else: # Create standard AI system without custom agents - return auto_analyst(agents=list(self.available_agents.values()), retrievers=retrievers) + return auto_analyst(agents=[], retrievers=retrievers) except Exception as e: logger.log_message(f"Error creating AI system for 
user {user_id}: {str(e)}", level=logging.ERROR) # Fallback to standard AI system - return auto_analyst(agents=list(self.available_agents.values()), retrievers=retrievers) + return auto_analyst(agents=[], retrievers=retrievers) def set_session_user(self, session_id: str, user_id: int, chat_id: int = None): """ diff --git a/auto-analyst-backend/src/managers/user_manager.py b/auto-analyst-backend/src/managers/user_manager.py index af836299..23f10f39 100644 --- a/auto-analyst-backend/src/managers/user_manager.py +++ b/auto-analyst-backend/src/managers/user_manager.py @@ -1,12 +1,13 @@ import logging import os from typing import Optional +from datetime import datetime, UTC from fastapi import Depends, HTTPException, Request, status from fastapi.security import APIKeyHeader from src.db.init_db import get_session -from src.db.schemas.models import User as DBUser +from src.db.schemas.models import User as DBUser, AgentTemplate, UserTemplatePreference from src.schemas.user_schemas import User from src.utils.logger import Logger @@ -100,6 +101,9 @@ def create_user(username: str, email: str) -> User: session.commit() session.refresh(new_user) + # Enable default agents for the new user + _enable_default_agents_for_user(new_user.user_id, session) + return User( user_id=new_user.user_id, username=new_user.username, @@ -131,3 +135,48 @@ def get_user_by_email(email: str) -> Optional[User]: return None finally: session.close() + +def _enable_default_agents_for_user(user_id: int, session): + """Enable default agents for a new user""" + try: + # Get all default agents (the 4 built-in agents) + default_agent_names = [ + "preprocessing_agent", + "statistical_analytics_agent", + "sk_learn_agent", + "data_viz_agent" + ] + + # Find these agents in the database + default_agents = session.query(AgentTemplate).filter( + AgentTemplate.template_name.in_(default_agent_names), + AgentTemplate.is_active == True + ).all() + + # Enable each default agent for the user + for agent in default_agents: 
+ # Check if preference already exists + existing_pref = session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == user_id, + UserTemplatePreference.template_id == agent.template_id + ).first() + + if not existing_pref: + # Create new preference with enabled=True + new_pref = UserTemplatePreference( + user_id=user_id, + template_id=agent.template_id, + is_enabled=True, # Enable by default + usage_count=0, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC) + ) + session.add(new_pref) + + session.commit() + logger.log_message(f"Enabled {len(default_agents)} default agents for user {user_id}", level=logging.INFO) + + except Exception as e: + session.rollback() + logger.log_message(f"Error enabling default agents for user {user_id}: {str(e)}", level=logging.ERROR) + raise diff --git a/auto-analyst-backend/src/routes/templates_routes.py b/auto-analyst-backend/src/routes/templates_routes.py index e7c354cf..06394a17 100644 --- a/auto-analyst-backend/src/routes/templates_routes.py +++ b/auto-analyst-backend/src/routes/templates_routes.py @@ -41,12 +41,15 @@ class UserTemplatePreferenceResponse(BaseModel): template_category: Optional[str] icon_url: Optional[str] is_premium_only: bool + is_active: bool is_enabled: bool usage_count: int last_used_at: Optional[datetime] + created_at: Optional[datetime] + updated_at: Optional[datetime] -class ToggleTemplateRequest(BaseModel): - is_enabled: bool = Field(..., description="Whether to enable or disable the template") +class TogglePreferenceRequest(BaseModel): + is_enabled: bool def get_global_usage_counts(session, template_ids: List[int] = None) -> Dict[int, int]: """ @@ -140,6 +143,14 @@ async def get_user_template_preferences(user_id: int): AgentTemplate.is_active == True ).all() + # Get list of default agent names that should be enabled by default + default_agent_names = [ + "preprocessing_agent", + "statistical_analytics_agent", + "sk_learn_agent", + "data_viz_agent" + ] + result = [] for 
template in templates: # Get user preference for this template if it exists @@ -148,6 +159,10 @@ async def get_user_template_preferences(user_id: int): UserTemplatePreference.template_id == template.template_id ).first() + # Determine if template should be enabled by default + is_default_agent = template.template_name in default_agent_names + default_enabled = is_default_agent # Default agents enabled by default, others disabled + result.append(UserTemplatePreferenceResponse( template_id=template.template_id, template_name=template.template_name, @@ -156,9 +171,12 @@ async def get_user_template_preferences(user_id: int): template_category=template.category, icon_url=template.icon_url, is_premium_only=template.is_premium_only, - is_enabled=preference.is_enabled if preference else False, # Default to disabled + is_active=template.is_active, + is_enabled=preference.is_enabled if preference else default_enabled, # Default agents enabled by default usage_count=preference.usage_count if preference else 0, - last_used_at=preference.last_used_at if preference else None + last_used_at=preference.last_used_at if preference else None, + created_at=preference.created_at if preference else None, + updated_at=preference.updated_at if preference else None )) return result @@ -189,6 +207,14 @@ async def get_user_enabled_templates(user_id: int): AgentTemplate.is_active == True ).all() + # Get list of default agent names that should be enabled by default + default_agent_names = [ + "preprocessing_agent", + "statistical_analytics_agent", + "sk_learn_agent", + "data_viz_agent" + ] + result = [] for template in all_templates: # Check if user has a preference record for this template @@ -197,8 +223,12 @@ async def get_user_enabled_templates(user_id: int): UserTemplatePreference.template_id == template.template_id ).first() - # Template is disabled by default unless explicitly enabled - is_enabled = preference.is_enabled if preference else False + # Determine if template should be 
enabled by default + is_default_agent = template.template_name in default_agent_names + default_enabled = is_default_agent # Default agents enabled by default, others disabled + + # Template is enabled by default for default agents, disabled for others + is_enabled = preference.is_enabled if preference else default_enabled if is_enabled: result.append(UserTemplatePreferenceResponse( @@ -209,9 +239,12 @@ async def get_user_enabled_templates(user_id: int): template_category=template.category, icon_url=template.icon_url, is_premium_only=template.is_premium_only, + is_active=template.is_active, is_enabled=True, usage_count=preference.usage_count if preference else 0, - last_used_at=preference.last_used_at if preference else None + last_used_at=preference.last_used_at if preference else None, + created_at=preference.created_at if preference else None, + updated_at=preference.updated_at if preference else None )) return result @@ -237,36 +270,66 @@ async def get_user_enabled_templates_for_planner(user_id: int): if not user: raise HTTPException(status_code=404, detail="User not found") - # Get enabled templates ordered by usage (most used first) and limit to 10 - enabled_preferences = session.query(UserTemplatePreference).filter( - UserTemplatePreference.user_id == user_id, - UserTemplatePreference.is_enabled == True - ).order_by( - UserTemplatePreference.usage_count.desc(), - UserTemplatePreference.last_used_at.desc() - ).limit(10).all() + # Get list of default agent names that should be enabled by default + default_agent_names = [ + "preprocessing_agent", + "statistical_analytics_agent", + "sk_learn_agent", + "data_viz_agent" + ] - result = [] - for preference in enabled_preferences: - # Get template details - template = session.query(AgentTemplate).filter( - AgentTemplate.template_id == preference.template_id, - AgentTemplate.is_active == True + # Get all active templates + all_templates = session.query(AgentTemplate).filter( + AgentTemplate.is_active == True + ).all() 
+ + enabled_templates = [] + for template in all_templates: + # Check if user has a preference record for this template + preference = session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == user_id, + UserTemplatePreference.template_id == template.template_id ).first() - if template: - result.append(UserTemplatePreferenceResponse( - template_id=template.template_id, - template_name=template.template_name, - display_name=template.display_name, - description=template.description, - template_category=template.category, - icon_url=template.icon_url, - is_premium_only=template.is_premium_only, - is_enabled=True, - usage_count=preference.usage_count, - last_used_at=preference.last_used_at - )) + # Determine if template should be enabled by default + is_default_agent = template.template_name in default_agent_names + default_enabled = is_default_agent # Default agents enabled by default, others disabled + + # Template is enabled by default for default agents, disabled for others + is_enabled = preference.is_enabled if preference else default_enabled + + if is_enabled: + enabled_templates.append({ + 'template': template, + 'preference': preference, + 'usage_count': preference.usage_count if preference else 0, + 'last_used_at': preference.last_used_at if preference else None + }) + + # Sort by usage (most used first) and limit to 10 + enabled_templates.sort(key=lambda x: (x['usage_count'], x['last_used_at'] or datetime.min.replace(tzinfo=UTC)), reverse=True) + enabled_templates = enabled_templates[:10] + + result = [] + for item in enabled_templates: + template = item['template'] + preference = item['preference'] + + result.append(UserTemplatePreferenceResponse( + template_id=template.template_id, + template_name=template.template_name, + display_name=template.display_name, + description=template.description, + template_category=template.category, + icon_url=template.icon_url, + is_premium_only=template.is_premium_only, + 
is_active=template.is_active, + is_enabled=True, + usage_count=preference.usage_count if preference else 0, + last_used_at=preference.last_used_at if preference else None, + created_at=preference.created_at if preference else None, + updated_at=preference.updated_at if preference else None + )) logger.log_message(f"Retrieved {len(result)} enabled templates for planner for user {user_id}", level=logging.INFO) return result @@ -281,7 +344,7 @@ async def get_user_enabled_templates_for_planner(user_id: int): raise HTTPException(status_code=500, detail=f"Failed to retrieve planner templates: {str(e)}") @router.post("/user/{user_id}/template/{template_id}/toggle") -async def toggle_template_preference(user_id: int, template_id: int, request: ToggleTemplateRequest): +async def toggle_template_preference(user_id: int, template_id: int, request: TogglePreferenceRequest): """Toggle a user's template preference (enable/disable for planner use)""" try: session = session_factory() diff --git a/auto-analyst-frontend/app/account/page.tsx b/auto-analyst-frontend/app/account/page.tsx index d0128229..e91ce7f5 100644 --- a/auto-analyst-frontend/app/account/page.tsx +++ b/auto-analyst-frontend/app/account/page.tsx @@ -86,7 +86,6 @@ export default function AccountPage() { const fetchUserData = async () => { try { - // logger.log('Fetching user data from API') setIsRefreshing(true) // Add cache-busting parameter and force flag to ensure fresh data @@ -96,13 +95,11 @@ export default function AccountPage() { } const data: UserDataResponse = await response.json() - // logger.log('Received user data:', data) setProfile(data.profile) setSubscription(data.subscription) // Enhanced credits handling using centralized config - // logger.log('Credits data:', data.credits) if (data.credits) { // Use centralized config to get plan-specific defaults diff --git a/auto-analyst-frontend/components/chat/AgentSuggestions.tsx b/auto-analyst-frontend/components/chat/AgentSuggestions.tsx index 
20dd1705..f1af1018 100644 --- a/auto-analyst-frontend/components/chat/AgentSuggestions.tsx +++ b/auto-analyst-frontend/components/chat/AgentSuggestions.tsx @@ -12,6 +12,7 @@ interface AgentSuggestion { description: string isCustom?: boolean isTemplate?: boolean + isPremium?: boolean } interface AgentSuggestionsProps { @@ -20,6 +21,7 @@ interface AgentSuggestionsProps { onSuggestionSelect: (agentName: string) => void isVisible: boolean userId?: number | null + onStateChange?: (hasSelection: boolean) => void } export default function AgentSuggestions({ @@ -27,7 +29,8 @@ export default function AgentSuggestions({ cursorPosition, onSuggestionSelect, isVisible, - userId + userId, + onStateChange }: AgentSuggestionsProps) { const [agents, setAgents] = useState([]) const [filteredAgents, setFilteredAgents] = useState([]) @@ -79,15 +82,15 @@ export default function AgentSuggestions({ const data = await response.json() const allAgents: AgentSuggestion[] = [] - // Add standard agents - if (data.standard_agents) { - data.standard_agents.forEach((agentName: string) => { - const standardAgent = standardAgents.find(agent => agent.name === agentName) - if (standardAgent) { - allAgents.push(standardAgent) - } - }) - } + // // Add standard agents + // if (data.standard_agents) { + // data.standard_agents.forEach((agentName: string) => { + // const standardAgent = standardAgents.find(agent => agent.name === agentName) + // if (standardAgent) { + // allAgents.push(standardAgent) + // } + // }) + // } // Add template agents (only for users with custom agents access) if (data.template_agents && data.template_agents.length > 0 && customAgentsAccess.hasAccess) { @@ -119,14 +122,15 @@ export default function AgentSuggestions({ if (response.ok) { const templateCategories = await response.json() const allTemplates: AgentSuggestion[] = [] - + console.log("templateCategories", templateCategories) // Flatten all templates from all categories templateCategories.forEach((category: any) => { if 
(category.templates) { const mappedTemplates = category.templates.map((template: any) => ({ name: template.agent_name, description: template.description, - isTemplate: true + // isTemplate: true, + isPremium: template.is_premium_only })) allTemplates.push(...mappedTemplates) } @@ -188,29 +192,35 @@ export default function AgentSuggestions({ ? message.slice(activeAtPos + 1, activeAtPos + 1 + spaceIndex) : textAfterAt - // Show suggestions if we're actively typing an agent name or just typed @ - if (!typedText.includes(' ')) { - // If no text after @, show all agents - if (typedText === '') { - setFilteredAgents(agents) - setSelectedIndex(agents.length > 0 ? 0 : -1) + // Show suggestions if we're actively typing an agent name or just typed @ + if (!typedText.includes(' ')) { + // If no text after @, show all agents + if (typedText === '') { + setFilteredAgents(agents) + setSelectedIndex(agents.length > 0 ? 0 : -1) + return + } + + // If there's text after @, filter agents that START WITH the typed text (autocomplete-style) + const filtered = agents.filter(agent => + agent.name.toLowerCase().startsWith(typedText.toLowerCase()) + ) + setFilteredAgents(filtered) + setSelectedIndex(filtered.length > 0 ? 0 : -1) return } - - // If there's text after @, filter agents based on that text - const filtered = agents.filter(agent => - agent.name.toLowerCase().includes(typedText.toLowerCase()) - ) - setFilteredAgents(filtered) - setSelectedIndex(filtered.length > 0 ? 
0 : -1) - return - } } setFilteredAgents([]) setSelectedIndex(-1) }, [message, cursorPosition, agents, isVisible]) + // Report state changes to parent component + useEffect(() => { + const hasValidSelection = filteredAgents.length > 0 && selectedIndex >= 0 && selectedIndex < filteredAgents.length + onStateChange?.(hasValidSelection) + }, [filteredAgents, selectedIndex, onStateChange]) + // Handle keyboard navigation useEffect(() => { const handleKeyDown = (e: KeyboardEvent) => { @@ -219,25 +229,30 @@ export default function AgentSuggestions({ switch (e.key) { case 'ArrowDown': e.preventDefault() + e.stopPropagation() setSelectedIndex(prev => prev < filteredAgents.length - 1 ? prev + 1 : 0 ) break case 'ArrowUp': e.preventDefault() + e.stopPropagation() setSelectedIndex(prev => prev > 0 ? prev - 1 : filteredAgents.length - 1 ) break case 'Enter': - e.preventDefault() - e.stopPropagation() + // Only handle Enter if there's a valid selection if (selectedIndex >= 0 && selectedIndex < filteredAgents.length) { + e.preventDefault() + e.stopPropagation() onSuggestionSelect(filteredAgents[selectedIndex].name) } + // If no valid selection, let the event bubble up to ChatInput break case 'Escape': e.preventDefault() + e.stopPropagation() setFilteredAgents([]) setSelectedIndex(-1) break @@ -245,8 +260,8 @@ export default function AgentSuggestions({ } // Add event listener to document to capture keyboard events - document.addEventListener('keydown', handleKeyDown) - return () => document.removeEventListener('keydown', handleKeyDown) + document.addEventListener('keydown', handleKeyDown, true) // Use capture phase + return () => document.removeEventListener('keydown', handleKeyDown, true) }, [isVisible, filteredAgents, selectedIndex, onSuggestionSelect]) // Scroll selected item into view @@ -288,7 +303,7 @@ export default function AgentSuggestions({
{agent.name}
- {agent.isTemplate && ( + {agent.isPremium && ( Template diff --git a/auto-analyst-frontend/components/chat/ChatInput.tsx b/auto-analyst-frontend/components/chat/ChatInput.tsx index 5ff61148..0ef71b42 100644 --- a/auto-analyst-frontend/components/chat/ChatInput.tsx +++ b/auto-analyst-frontend/components/chat/ChatInput.tsx @@ -228,6 +228,9 @@ const ChatInput = forwardRef< const [showCommandSuggestions, setShowCommandSuggestions] = useState(false) const [commandQuery, setCommandQuery] = useState('') + // Agent suggestions state + const [agentSuggestionsHasSelection, setAgentSuggestionsHasSelection] = useState(false) + // Get subscription from store instead of manual construction const { subscription } = useUserSubscriptionStore() const deepAnalysisAccess = useFeatureAccess('DEEP_ANALYSIS', subscription) @@ -1951,13 +1954,16 @@ const ChatInput = forwardRef< onChange={handleInputChange} onKeyDown={(e) => { if (e.key === 'Enter' && !e.shiftKey) { - // Check if agent suggestions are visible and should handle the Enter key + // Check if agent suggestions are visible and have a selection const isAgentSuggestionsVisible = !showCommandSuggestions && message.includes('@'); - if (isAgentSuggestionsVisible) { - // Don't handle Enter here, let AgentSuggestions component handle it - // The AgentSuggestions component will preventDefault if it handles the event + + if (isAgentSuggestionsVisible && agentSuggestionsHasSelection) { + // AgentSuggestions will handle Enter key since it has a selection + // Don't preventDefault here - let AgentSuggestions handle it return; } + + // If no agent suggestions selection, handle normally e.preventDefault() handleSubmit(e) } @@ -1990,6 +1996,7 @@ const ChatInput = forwardRef< onSuggestionSelect={handleAgentSelect} isVisible={!showCommandSuggestions && message.includes('@')} userId={userId} + onStateChange={setAgentSuggestionsHasSelection} />
diff --git a/auto-analyst-frontend/components/custom-templates/TemplatesModal.tsx b/auto-analyst-frontend/components/custom-templates/TemplatesModal.tsx index 9770e0f6..89fab668 100644 --- a/auto-analyst-frontend/components/custom-templates/TemplatesModal.tsx +++ b/auto-analyst-frontend/components/custom-templates/TemplatesModal.tsx @@ -64,21 +64,39 @@ export default function TemplatesModal({ const loadTemplatesForFreeUsers = async () => { setLoading(true) try { + console.log('Loading templates for free users...', { API_URL }) + // Fetch all templates (no user-specific data needed) - const response = await fetch(`${API_URL}/templates/`) + const response = await fetch(`${API_URL}/templates/`).catch(err => { + console.error('Free user templates fetch error:', err) + throw new Error(`Templates endpoint failed: ${err.message}`) + }) - if (response.ok) { - const templatesData = await response.json() - setTemplates(templatesData) - setPreferences([]) // No preferences for free users + console.log('Free user templates response:', { status: response.status }) + + if (!response.ok) { + const errorText = await response.text() + console.error('Free user templates response error:', { status: response.status, errorText }) + throw new Error(`Failed to load templates: ${response.status} ${response.statusText} - ${errorText}`) } + const templatesData = await response.json().catch(err => { + console.error('Free user templates JSON parse error:', err) + throw new Error(`Failed to parse templates response: ${err.message}`) + }) + + console.log('Free user templates data parsed successfully:', { + templatesCount: templatesData.length + }) + + setTemplates(templatesData) + setPreferences([]) // No preferences for free users setChanges({}) } catch (error) { console.error('Error loading templates:', error) toast({ title: "Error", - description: "Failed to load agents", + description: error instanceof Error ? 
error.message : "Failed to load agents", variant: "destructive", }) } finally { @@ -90,59 +108,93 @@ export default function TemplatesModal({ const loadData = async () => { setLoading(true) try { + console.log('Loading templates data for modal...', { API_URL, userId }) + // Fetch global template data with global usage counts const [templatesResponse, preferencesResponse] = await Promise.all([ - fetch(`${API_URL}/templates/`), // Global templates with global usage counts - fetch(`${API_URL}/templates/user/${userId}`) // User preferences with per-user usage + fetch(`${API_URL}/templates/`).catch(err => { + console.error('Templates fetch error:', err) + throw new Error(`Templates endpoint failed: ${err.message}`) + }), // Global templates with global usage counts + fetch(`${API_URL}/templates/user/${userId}`).catch(err => { + console.error('Preferences fetch error:', err) + throw new Error(`Preferences endpoint failed: ${err.message}`) + }) // User preferences with per-user usage ]) - if (templatesResponse.ok) { - // Global templates with global usage counts - const globalTemplatesData = await templatesResponse.json() - - // Convert to TemplateAgent format with global usage counts - const templatesData = globalTemplatesData.map((item: any) => ({ - template_id: item.template_id, - template_name: item.template_name, - display_name: item.display_name, - description: item.description, - prompt_template: item.prompt_template, - template_category: item.template_category, - icon_url: item.icon_url, - is_premium_only: item.is_premium_only, - is_active: item.is_active, - usage_count: item.usage_count, // Global usage count from /templates/ endpoint - created_at: item.created_at - })) - setTemplates(templatesData) + console.log('Modal responses received:', { + templatesStatus: templatesResponse.status, + preferencesStatus: preferencesResponse.status + }) + + // Check templates response + if (!templatesResponse.ok) { + const errorText = await templatesResponse.text() + 
console.error('Templates response error:', { status: templatesResponse.status, errorText }) + throw new Error(`Failed to load templates: ${templatesResponse.status} ${templatesResponse.statusText} - ${errorText}`) } - if (preferencesResponse.ok) { - // User preferences (enabled/disabled status and per-user usage) - const userPreferencesData = await preferencesResponse.json() - - const preferencesData = userPreferencesData.map((item: any) => ({ - template_id: item.template_id, - template_name: item.template_name, - display_name: item.display_name, - description: item.description, - template_category: item.template_category, - icon_url: item.icon_url, - is_premium_only: item.is_premium_only, - is_enabled: item.is_enabled, - usage_count: item.usage_count, // Keep user-specific usage for preferences if needed - last_used_at: item.last_used_at - })) - setPreferences(preferencesData) + // Check preferences response + if (!preferencesResponse.ok) { + const errorText = await preferencesResponse.text() + console.error('Preferences response error:', { status: preferencesResponse.status, errorText }) + throw new Error(`Failed to load preferences: ${preferencesResponse.status} ${preferencesResponse.statusText} - ${errorText}`) } + // Parse templates response + const globalTemplatesData = await templatesResponse.json().catch(err => { + console.error('Templates JSON parse error:', err) + throw new Error(`Failed to parse templates response: ${err.message}`) + }) + + // Convert to TemplateAgent format with global usage counts + const templatesData = globalTemplatesData.map((item: any) => ({ + template_id: item.template_id, + template_name: item.template_name, + display_name: item.display_name, + description: item.description, + prompt_template: item.prompt_template, + template_category: item.template_category, + icon_url: item.icon_url, + is_premium_only: item.is_premium_only, + is_active: item.is_active, + usage_count: item.usage_count, // Global usage count from /templates/ 
endpoint + created_at: item.created_at + })) + setTemplates(templatesData) + + // Parse preferences response + const userPreferencesData = await preferencesResponse.json().catch(err => { + console.error('Preferences JSON parse error:', err) + throw new Error(`Failed to parse preferences response: ${err.message}`) + }) + + const preferencesData = userPreferencesData.map((item: any) => ({ + template_id: item.template_id, + template_name: item.template_name, + display_name: item.display_name, + description: item.description, + template_category: item.template_category, + icon_url: item.icon_url, + is_premium_only: item.is_premium_only, + is_enabled: item.is_enabled, + usage_count: item.usage_count, // Keep user-specific usage for preferences if needed + last_used_at: item.last_used_at + })) + setPreferences(preferencesData) + + console.log('Modal data parsed successfully:', { + templatesCount: templatesData.length, + preferencesCount: preferencesData.length + }) + // Reset changes when loading data setChanges({}) } catch (error) { console.error('Error loading data:', error) toast({ title: "Error", - description: "Failed to load agents", + description: error instanceof Error ? 
error.message : "Failed to load agents", variant: "destructive", }) } finally { @@ -165,6 +217,27 @@ export default function TemplatesModal({ return preferences.find(p => p.template_id === templateId) } + // Helper function to determine if a template should be enabled by default + const isDefaultEnabledTemplate = (templateName: string) => { + const defaultAgentNames = [ + "preprocessing_agent", + "statistical_analytics_agent", + "sk_learn_agent", + "data_viz_agent" + ] + return defaultAgentNames.includes(templateName) + } + + // Helper function to get the effective enabled state for a template + const getTemplateEnabledState = (template: TemplateAgent) => { + const preference = getTemplatePreference(template.template_id) + const defaultEnabled = isDefaultEnabledTemplate(template.template_name) + + return changes[template.template_id] !== undefined + ? changes[template.template_id] + : preference?.is_enabled ?? defaultEnabled + } + // Filter templates based on search, category, and status const filteredTemplates = useMemo(() => { let filtered = templates @@ -183,11 +256,7 @@ export default function TemplatesModal({ if (statusFilter !== 'all') { filtered = filtered.filter(template => { - const preference = getTemplatePreference(template.template_id) - const isEnabled = changes[template.template_id] !== undefined - ? changes[template.template_id] - : preference?.is_enabled || false - + const isEnabled = getTemplateEnabledState(template) return statusFilter === 'enabled' ? isEnabled : !isEnabled }) } @@ -294,18 +363,13 @@ export default function TemplatesModal({ // Get template data for rendering const getTemplateData = (template: TemplateAgent) => { const preference = getTemplatePreference(template.template_id) - const isEnabled = changes[template.template_id] !== undefined - ? changes[template.template_id] - : preference?.is_enabled || false + const isEnabled = getTemplateEnabledState(template) return { preference, isEnabled } } const enabledCount = hasAccess - ? 
preferences.filter(p => { - const hasChanges = changes[p.template_id] !== undefined - return hasChanges ? changes[p.template_id] : p.is_enabled - }).length + ? templates.filter(template => getTemplateEnabledState(template)).length : 0 return ( diff --git a/auto-analyst-frontend/components/custom-templates/useTemplates.ts b/auto-analyst-frontend/components/custom-templates/useTemplates.ts index c0ee739a..f958d116 100644 --- a/auto-analyst-frontend/components/custom-templates/useTemplates.ts +++ b/auto-analyst-frontend/components/custom-templates/useTemplates.ts @@ -34,24 +34,57 @@ export function useTemplates({ userId, enabled = true }: UseTemplatesProps): Use setError(null) try { + console.log('Loading templates data...', { API_URL, userId }) + const [templatesResponse, preferencesResponse] = await Promise.all([ - fetch(`${API_URL}/templates`), - fetch(`${API_URL}/templates/user/${userId}`) + fetch(`${API_URL}/templates/`).catch(err => { + console.error('Templates fetch error:', err) + throw new Error(`Templates endpoint failed: ${err.message}`) + }), + fetch(`${API_URL}/templates/user/${userId}`).catch(err => { + console.error('Preferences fetch error:', err) + throw new Error(`Preferences endpoint failed: ${err.message}`) + }) ]) - if (templatesResponse.ok) { - const templatesData = await templatesResponse.json() - setTemplates(templatesData) - } else { - throw new Error('Failed to load templates') + console.log('Responses received:', { + templatesStatus: templatesResponse.status, + preferencesStatus: preferencesResponse.status + }) + + // Check templates response + if (!templatesResponse.ok) { + const errorText = await templatesResponse.text() + console.error('Templates response error:', { status: templatesResponse.status, errorText }) + throw new Error(`Failed to load templates: ${templatesResponse.status} ${templatesResponse.statusText} - ${errorText}`) } - if (preferencesResponse.ok) { - const preferencesData = await preferencesResponse.json() - 
setPreferences(preferencesData) - } else { - throw new Error('Failed to load preferences') + // Check preferences response + if (!preferencesResponse.ok) { + const errorText = await preferencesResponse.text() + console.error('Preferences response error:', { status: preferencesResponse.status, errorText }) + throw new Error(`Failed to load preferences: ${preferencesResponse.status} ${preferencesResponse.statusText} - ${errorText}`) } + + // Parse responses + const templatesData = await templatesResponse.json().catch(err => { + console.error('Templates JSON parse error:', err) + throw new Error(`Failed to parse templates response: ${err.message}`) + }) + + const preferencesData = await preferencesResponse.json().catch(err => { + console.error('Preferences JSON parse error:', err) + throw new Error(`Failed to parse preferences response: ${err.message}`) + }) + + console.log('Data parsed successfully:', { + templatesCount: templatesData.length, + preferencesCount: preferencesData.length + }) + + setTemplates(templatesData) + setPreferences(preferencesData) + } catch (err) { const errorMessage = err instanceof Error ? 
err.message : 'Failed to load data' setError(errorMessage) diff --git a/auto-analyst-frontend/public/icons/templates/data_viz_agent.svg b/auto-analyst-frontend/public/icons/templates/data_viz_agent.svg new file mode 100644 index 00000000..e7d9cd53 --- /dev/null +++ b/auto-analyst-frontend/public/icons/templates/data_viz_agent.svg @@ -0,0 +1 @@ + diff --git a/auto-analyst-frontend/public/icons/templates/matplotlib_agent.png b/auto-analyst-frontend/public/icons/templates/matplotlib_agent.png new file mode 100644 index 00000000..2d515bce Binary files /dev/null and b/auto-analyst-frontend/public/icons/templates/matplotlib_agent.png differ diff --git a/auto-analyst-frontend/public/icons/templates/polars_agent.svg b/auto-analyst-frontend/public/icons/templates/polars_agent.svg new file mode 100644 index 00000000..5a99c887 --- /dev/null +++ b/auto-analyst-frontend/public/icons/templates/polars_agent.svg @@ -0,0 +1,83 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/auto-analyst-frontend/public/icons/templates/preprocessing_agent.svg b/auto-analyst-frontend/public/icons/templates/preprocessing_agent.svg new file mode 100644 index 00000000..26c18c46 --- /dev/null +++ b/auto-analyst-frontend/public/icons/templates/preprocessing_agent.svg @@ -0,0 +1 @@ + diff --git a/auto-analyst-frontend/public/icons/templates/sk_learn_agent.svg b/auto-analyst-frontend/public/icons/templates/sk_learn_agent.svg new file mode 100644 index 00000000..5a32f797 --- /dev/null +++ b/auto-analyst-frontend/public/icons/templates/sk_learn_agent.svg @@ -0,0 +1,111 @@ + + + +image/svg+xml + + + + + + + + + + + + + + +scikit + + + + + + + \ No newline at end of file diff --git a/test_default_agents.py b/test_default_agents.py new file mode 100644 index 00000000..0519ecba --- /dev/null +++ b/test_default_agents.py @@ -0,0 +1 @@ + \ No newline at end of file