From 92db04d5178420adff8a2f92aa633a7c2246e420 Mon Sep 17 00:00:00 2001 From: Ashad Qureshi Date: Thu, 12 Jun 2025 19:46:53 +0500 Subject: [PATCH 1/7] Custom Agents are not Templates only --- auto-analyst-backend/app.py | 141 +++- auto-analyst-backend/chat_database.db | 4 +- .../scripts/populate_agent_templates.py | 131 +--- auto-analyst-backend/src/agents/agents.py | 453 +++++++---- auto-analyst-backend/src/db/schemas/models.py | 52 +- .../src/managers/session_manager.py | 4 +- .../src/routes/custom_agents_routes.py | 697 ----------------- .../src/routes/templates_routes.py | 478 ++++++++++++ .../src/utils/model_registry.py | 2 +- .../components/chat/AgentSuggestions.tsx | 37 +- .../components/chat/ChatInput.tsx | 32 +- .../custom-agents/AgentDetailView.tsx | 417 ---------- .../custom-agents/AgentListView.tsx | 282 ------- .../custom-agents/CreateAgentForm.tsx | 531 ------------- .../custom-agents/CustomAgentsSidebar.tsx | 720 ------------------ .../components/custom-agents/README.md | 84 ++ .../custom-agents/TemplateDetailView.tsx | 229 ++---- .../custom-agents/TemplateListView.tsx | 125 ++- .../custom-agents/TemplateManagementView.tsx | 384 ++++++++++ ...omAgentsButton.tsx => TemplatesButton.tsx} | 52 +- .../custom-agents/TemplatesSidebar.tsx | 247 ++++++ .../components/custom-agents/index.ts | 10 +- .../components/custom-agents/types.ts | 52 +- .../components/templates/TemplateListView.tsx | 1 + auto-analyst-frontend/lib/model-registry.ts | 2 +- 25 files changed, 1821 insertions(+), 3346 deletions(-) delete mode 100644 auto-analyst-backend/src/routes/custom_agents_routes.py create mode 100644 auto-analyst-backend/src/routes/templates_routes.py delete mode 100644 auto-analyst-frontend/components/custom-agents/AgentDetailView.tsx delete mode 100644 auto-analyst-frontend/components/custom-agents/AgentListView.tsx delete mode 100644 auto-analyst-frontend/components/custom-agents/CreateAgentForm.tsx delete mode 100644 auto-analyst-frontend/components/custom-agents/CustomAgentsSidebar.tsx create mode 100644 auto-analyst-frontend/components/custom-agents/README.md create mode 100644 auto-analyst-frontend/components/custom-agents/TemplateManagementView.tsx rename auto-analyst-frontend/components/custom-agents/{CustomAgentsButton.tsx => TemplatesButton.tsx} (67%) create mode 100644 auto-analyst-frontend/components/custom-agents/TemplatesSidebar.tsx create mode 100644 auto-analyst-frontend/components/templates/TemplateListView.tsx diff --git a/auto-analyst-backend/app.py b/auto-analyst-backend/app.py index 929d1c2d..f0068a44 100644 --- a/auto-analyst-backend/app.py +++ b/auto-analyst-backend/app.py @@ -43,6 +43,7 @@ from src.routes.session_routes import router as session_router, get_session_id_dependency from src.routes.deep_analysis_routes import router as deep_analysis_router from src.routes.custom_agents_routes import router as custom_agents_router +from src.routes.templates_routes import router as templates_router from src.schemas.query_schemas import QueryRequest from src.utils.logger import Logger @@ -400,7 +401,7 @@ async def verify_origin_middleware(request: Request, call_next): RESPONSE_ERROR_INVALID_QUERY = "Please provide a valid query..." RESPONSE_ERROR_NO_DATASET = "No dataset is currently loaded. Please link a dataset before proceeding with your analysis." DEFAULT_TOKEN_RATIO = 1.5 -REQUEST_TIMEOUT_SECONDS = 60 # Timeout for LLM requests +REQUEST_TIMEOUT_SECONDS = 120 # Timeout for LLM requests MAX_RECENT_MESSAGES = 3 DB_BATCH_SIZE = 10 # For future batch DB operations @@ -430,17 +431,20 @@ async def chat_with_agent( enhanced_query = _prepare_query_with_context(request.query, session_state) logger.log_message(f"Enhanced query: {enhanced_query}", level=logging.INFO) - # Initialize agent - handle both standard and custom agents + # Initialize agent - handle standard, template, and custom agents if "," in agent_name: # Multiple agents case agent_list = [agent.strip() for agent in agent_name.split(",")] - # Check if any are custom agents - has_custom_agents = any(not _is_standard_agent(agent) for agent in agent_list) - logger.log_message(f"Has custom agents: {has_custom_agents}", level=logging.INFO) + # Categorize agents + standard_agents = [agent for agent in agent_list if _is_standard_agent(agent)] + template_agents = [agent for agent in agent_list if _is_template_agent(agent)] + custom_agents = [agent for agent in agent_list if not _is_standard_agent(agent) and not _is_template_agent(agent)] - if has_custom_agents: - # Use session AI system for mixed or custom agent execution + logger.log_message(f"Agent routing - Standard: {len(standard_agents)}, Template: {len(template_agents)}, Custom: {len(custom_agents)}", level=logging.INFO) + + if custom_agents: + # If any custom agents, use session AI system for all ai_system = session_state["ai_system"] session_lm = get_session_lm(session_state) with dspy.context(lm=session_lm): @@ -449,29 +453,61 @@ async def chat_with_agent( timeout=REQUEST_TIMEOUT_SECONDS ) else: - # All standard agents - use auto_analyst_ind - standard_agent_sigs = [AVAILABLE_AGENTS[agent] for agent in agent_list] + # All standard/template agents - use auto_analyst_ind + standard_agent_sigs = [AVAILABLE_AGENTS[agent] for agent in standard_agents] user_id = session_state.get("user_id") - agent = auto_analyst_ind(agents=standard_agent_sigs, retrievers=session_state["retrievers"], user_id=user_id) - session_lm = get_session_lm(session_state) - with dspy.context(lm=session_lm): - response = await asyncio.wait_for( - agent.forward(enhanced_query, ",".join(agent_list)), - timeout=REQUEST_TIMEOUT_SECONDS - ) + + # Create database session for template loading + from src.db.init_db import session_factory + db_session = session_factory() + try: + agent = auto_analyst_ind(agents=standard_agent_sigs, retrievers=session_state["retrievers"], user_id=user_id, db_session=db_session) + session_lm = get_session_lm(session_state) + with dspy.context(lm=session_lm): + response = await asyncio.wait_for( + agent.forward(enhanced_query, ",".join(agent_list)), + timeout=REQUEST_TIMEOUT_SECONDS + ) + finally: + db_session.close() else: # Single agent case logger.log_message(f"Single agent case: {agent_name}", level=logging.INFO) if _is_standard_agent(agent_name): # Standard agent - use auto_analyst_ind user_id = session_state.get("user_id") - agent = auto_analyst_ind(agents=[AVAILABLE_AGENTS[agent_name]], retrievers=session_state["retrievers"], user_id=user_id) - session_lm = get_session_lm(session_state) - with dspy.context(lm=session_lm): - response = await asyncio.wait_for( - agent.forward(enhanced_query, agent_name), - timeout=REQUEST_TIMEOUT_SECONDS - ) + + # Create database session for template loading + from src.db.init_db import session_factory + db_session = session_factory() + try: + agent = auto_analyst_ind(agents=[AVAILABLE_AGENTS[agent_name]], retrievers=session_state["retrievers"], user_id=user_id, db_session=db_session) + session_lm = get_session_lm(session_state) + with dspy.context(lm=session_lm): + response = await asyncio.wait_for( + agent.forward(enhanced_query, agent_name), + timeout=REQUEST_TIMEOUT_SECONDS + ) + finally: + db_session.close() + elif _is_template_agent(agent_name): + # Template agent - use auto_analyst_ind with empty agents list (templates loaded in init) + logger.log_message(f"Template agent case: {agent_name}", level=logging.INFO) + user_id = session_state.get("user_id") + + # Create database session for template loading + from src.db.init_db import session_factory + db_session = session_factory() + try: + agent = auto_analyst_ind(agents=[], retrievers=session_state["retrievers"], user_id=user_id, db_session=db_session) + session_lm = get_session_lm(session_state) + with dspy.context(lm=session_lm): + response = await asyncio.wait_for( + agent.forward(enhanced_query, agent_name), + timeout=REQUEST_TIMEOUT_SECONDS + ) + finally: + db_session.close() else: # Custom agent - use session AI system logger.log_message(f"Custom agent case: {agent_name}", level=logging.INFO) @@ -606,11 +642,29 @@ def _validate_agent_name(agent_name: str, session_state: dict = None): ) def _is_agent_available(agent_name: str, session_state: dict = None) -> bool: - """Check if agent is available in either standard agents or user's custom agents""" + """Check if agent is available in either standard agents, template agents, or user's custom agents""" # Check standard agents if agent_name in AVAILABLE_AGENTS: return True + # Check template agents + try: + from src.db.init_db import session_factory + from src.db.schemas.models import AgentTemplate + + db_session = session_factory() + try: + template = db_session.query(AgentTemplate).filter( + AgentTemplate.template_name == agent_name, + AgentTemplate.is_active == True + ).first() + if template: + return True + finally: + db_session.close() + except Exception as e: + logger.log_message(f"Error checking template availability for {agent_name}: {str(e)}", level=logging.ERROR) + # Check custom agents if session has an AI system with custom agents if session_state and "ai_system" in session_state: ai_system = session_state["ai_system"] @@ -634,9 +688,28 @@ def _get_available_agents_list(session_state: dict = None) -> list: return available def _is_standard_agent(agent_name: str) -> bool: - """Check if agent is a standard agent (not custom)""" + """Check if agent is a standard agent (not custom or template)""" return agent_name in AVAILABLE_AGENTS +def _is_template_agent(agent_name: str) -> bool: + """Check if agent is a template agent""" + try: + from src.db.init_db import session_factory + from src.db.schemas.models import AgentTemplate + + db_session = session_factory() + try: + template = db_session.query(AgentTemplate).filter( + AgentTemplate.template_name == agent_name, + AgentTemplate.is_active == True + ).first() + return template is not None + finally: + db_session.close() + except Exception as e: + logger.log_message(f"Error checking if {agent_name} is template: {str(e)}", level=logging.ERROR) + return False + async def _execute_custom_agents(ai_system, agent_names: list, query: str): """Execute custom agents using the session's AI system""" try: @@ -644,8 +717,6 @@ async def _execute_custom_agents(ai_system, agent_names: list, query: str): if len(agent_names) == 1: # Single custom agent agent_name = agent_names[0] - logger.log_message(f"Executing custom agent: {agent_name}", level=logging.INFO) - # Prepare inputs for the custom agent (similar to standard agents like data_viz_agent) dict_ = {} dict_['dataset'] = ai_system.dataset.retrieve(query)[0].text @@ -657,14 +728,11 @@ async def _execute_custom_agents(ai_system, agent_names: list, query: str): if agent_name in ai_system.agent_inputs: inputs = {x: dict_[x] for x in ai_system.agent_inputs[agent_name] if x in dict_} - logger.log_message(f"Inputs for {agent_name}: {list(inputs.keys())}", level=logging.INFO) - # Execute the custom agent agent_name_result, result_dict = await ai_system.execute_agent(agent_name, inputs) - logger.log_message(f"Custom agent result: {agent_name_result}, has keys: {list(result_dict.keys()) if isinstance(result_dict, dict) else 'not dict'}", level=logging.INFO) return {agent_name_result: result_dict} else: - logger.log_message(f"Agent '{agent_name}' not found in ai_system.agent_inputs. Available: {list(ai_system.agent_inputs.keys())}", level=logging.ERROR) + logger.log_message(f"Agent '{agent_name}' not found in ai_system.agent_inputs", level=logging.ERROR) return {"error": f"Agent '{agent_name}' input configuration not found"} else: # Multiple agents - execute sequentially @@ -965,17 +1033,15 @@ async def list_agents(request: Request, session_id: str = Depends(get_session_id template_agents = [] try: from src.db.init_db import session_factory - from src.db.schemas.models import CustomAgent + from src.db.schemas.models import AgentTemplate db_session = session_factory() try: - templates = db_session.query(CustomAgent).filter( - CustomAgent.is_template == True, - CustomAgent.is_active == True, - CustomAgent.user_id == None # System templates + templates = db_session.query(AgentTemplate).filter( + AgentTemplate.is_active == True ).all() - template_agents = [template.agent_name for template in templates] + template_agents = [template.template_name for template in templates] logger.log_message(f"Found {len(template_agents)} template agents", level=logging.INFO) finally: @@ -1468,6 +1534,7 @@ async def download_html_report( app.include_router(feedback_router) app.include_router(deep_analysis_router) app.include_router(custom_agents_router) +app.include_router(templates_router) if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file diff --git a/auto-analyst-backend/chat_database.db b/auto-analyst-backend/chat_database.db index 169301de..87155fd6 100644 --- a/auto-analyst-backend/chat_database.db +++ b/auto-analyst-backend/chat_database.db @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:eba37409fcfd29b12a840c1a16aae4616c9a3a25edb89dc1746d23740a172607 -size 73728 +oid sha256:b6e840e8c6dc93fad9bd771e0f6699d168ff4b999fa244a091d1b28c49b8971f +size 212992 diff --git a/auto-analyst-backend/scripts/populate_agent_templates.py b/auto-analyst-backend/scripts/populate_agent_templates.py index 332e9be6..ac65caac 100644 --- a/auto-analyst-backend/scripts/populate_agent_templates.py +++ b/auto-analyst-backend/scripts/populate_agent_templates.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 """ -Script to populate custom agent templates. +Script to populate agent templates. These templates are available to all users but usable only by paid users. """ @@ -12,14 +12,14 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from src.db.init_db import session_factory -from src.db.schemas.models import CustomAgent +from src.db.schemas.models import AgentTemplate from sqlalchemy.exc import IntegrityError # Template agent definitions AGENT_TEMPLATES = { "Visualization": [ { - "agent_name": "matplotlib_agent", + "template_name": "matplotlib_agent", "display_name": "Matplotlib Visualization Agent", "description": "Creates static publication-quality plots using matplotlib and seaborn", "prompt_template": """ @@ -35,50 +35,12 @@ - Always end with plt.show() Focus on creating publication-ready static visualizations that are informative and aesthetically pleasing. -""" - }, - { - "agent_name": "seaborn_agent", - "display_name": "Seaborn Statistical Plots Agent", - "description": "Creates statistical visualizations and plots using seaborn library", - "prompt_template": """ -You are a seaborn statistical visualization expert. Create insightful statistical plots using seaborn. - -IMPORTANT Instructions: -- Specialize in seaborn's statistical plotting capabilities -- Use seaborn's built-in statistical functions (regplot, distplot, boxplot, violin, etc.) -- Apply appropriate statistical themes and color palettes -- Include confidence intervals and statistical annotations where relevant -- Sample large datasets: if len(df) > 50000: df = df.sample(50000, random_state=42) -- Use plt.figure(figsize=(10, 6)) for appropriate sizing -- Always include proper statistical context in titles and labels - -Focus on revealing statistical relationships and distributions in the data. -""" - }, - { - "agent_name": "plotly_advanced_agent", - "display_name": "Advanced Plotly Agent", - "description": "Creates sophisticated interactive visualizations with advanced Plotly features", - "prompt_template": """ -You are an advanced Plotly visualization expert. Create sophisticated interactive visualizations with advanced features. - -IMPORTANT Instructions: -- Use advanced Plotly features: subplots, animations, 3D plots, statistical charts -- Implement interactive features: hover data, clickable legends, zoom, pan -- Use plotly.graph_objects for fine control and plotly.express for rapid prototyping -- Add annotations, shapes, and custom styling -- Sample data if len(df) > 50000: df = df.sample(50000, random_state=42) -- Use fig.update_layout() for professional styling -- Return fig.to_html(full_html=False) for embedding - -Focus on creating publication-quality interactive visualizations with advanced features. """ } ], "Modelling": [ { - "agent_name": "xgboost_agent", + "template_name": "xgboost_agent", "display_name": "XGBoost Machine Learning Agent", "description": "Builds and optimizes XGBoost models for classification and regression tasks", "prompt_template": """ @@ -99,28 +61,7 @@ """ }, { - "agent_name": "neural_network_agent", - "display_name": "Neural Network Agent", - "description": "Builds and trains neural networks using TensorFlow/Keras", - "prompt_template": """ -You are a neural network expert using TensorFlow/Keras. Build and train neural networks for various tasks. - -IMPORTANT Instructions: -- Design appropriate network architectures for the task (classification, regression, etc.) -- Implement proper data preprocessing and normalization -- Use appropriate activation functions, optimizers, and loss functions -- Implement callbacks: EarlyStopping, ReduceLROnPlateau, ModelCheckpoint -- Plot training history (loss and metrics over epochs) -- Evaluate model performance with appropriate metrics -- Include model summary and architecture visualization -- Handle overfitting with dropout, regularization, or data augmentation -- Use train/validation/test splits properly - -Focus on building effective neural networks with proper training procedures and evaluation. -""" - }, - { - "agent_name": "time_series_agent", + "template_name": "time_series_agent", "display_name": "Time Series Forecasting Agent", "description": "Specialized in time series analysis and forecasting using ARIMA, Prophet, LSTM", "prompt_template": """ @@ -143,28 +84,7 @@ ], "Data Manipulation": [ { - "agent_name": "pandas_expert_agent", - "display_name": "Pandas Data Expert Agent", - "description": "Advanced pandas operations for complex data manipulation and analysis", - "prompt_template": """ -You are a pandas expert specializing in advanced data manipulation and analysis. - -IMPORTANT Instructions: -- Use advanced pandas operations: groupby, pivot, merge, concat, apply, transform -- Implement efficient data cleaning and preprocessing workflows -- Handle missing data with multiple strategies (imputation, dropping, flagging) -- Perform advanced aggregations and window functions -- Use vectorized operations for performance -- Handle large datasets efficiently with chunking if needed -- Create custom functions for complex transformations -- Use proper indexing and data types for optimization -- Include data quality checks and validation - -Focus on efficient and robust data manipulation that prepares data for analysis or modeling. -""" - }, - { - "agent_name": "data_cleaning_agent", + "template_name": "data_cleaning_agent", "display_name": "Data Cleaning Specialist Agent", "description": "Specialized in comprehensive data cleaning and quality assessment", "prompt_template": """ @@ -186,7 +106,7 @@ """ }, { - "agent_name": "feature_engineering_agent", + "template_name": "feature_engineering_agent", "display_name": "Feature Engineering Agent", "description": "Creates and transforms features for machine learning models", "prompt_template": """ @@ -223,37 +143,33 @@ def populate_templates(): print(f"\n--- Processing {category} Templates ---") for template_data in templates: - agent_name = template_data["agent_name"] + template_name = template_data["template_name"] # Check if template already exists - existing = session.query(CustomAgent).filter( - CustomAgent.agent_name == agent_name, - CustomAgent.is_template == True + existing = session.query(AgentTemplate).filter( + AgentTemplate.template_name == template_name ).first() if existing: - print(f"⏭️ Skipping {agent_name} (already exists)") + print(f"⏭️ Skipping {template_name} (already exists)") skipped_count += 1 continue # Create new template - template = CustomAgent( - user_id=None, # Templates don't belong to specific users - agent_name=agent_name, + template = AgentTemplate( + template_name=template_name, display_name=template_data["display_name"], description=template_data["description"], prompt_template=template_data["prompt_template"], - is_template=True, - template_category=category, + category=category, is_premium_only=True, # All templates require premium is_active=True, - usage_count=0, created_at=datetime.now(UTC), updated_at=datetime.now(UTC) ) session.add(template) - print(f"✅ Created template: {agent_name}") + print(f"✅ Created template: {template_name}") created_count += 1 # Commit all changes @@ -276,9 +192,7 @@ def list_templates(): session = session_factory() try: - templates = session.query(CustomAgent).filter( - CustomAgent.is_template == True - ).order_by(CustomAgent.template_category, CustomAgent.agent_name).all() + templates = session.query(AgentTemplate).order_by(AgentTemplate.category, AgentTemplate.template_name).all() if not templates: print("No templates found in database.") @@ -288,15 +202,14 @@ def list_templates(): current_category = None for template in templates: - if template.template_category != current_category: - current_category = template.template_category + if template.category != current_category: + current_category = template.category print(f"\n{current_category}:") status = "🔒 Premium" if template.is_premium_only else "🆓 Free" active = "✅ Active" if template.is_active else "❌ Inactive" - print(f" • {template.agent_name} ({template.display_name}) - {status} - {active}") + print(f" • {template.template_name} ({template.display_name}) - {status} - {active}") print(f" {template.description}") - print(f" Usage: {template.usage_count} times") except Exception as e: print(f"❌ Error listing templates: {str(e)}") @@ -308,9 +221,7 @@ def remove_all_templates(): session = session_factory() try: - deleted_count = session.query(CustomAgent).filter( - CustomAgent.is_template == True - ).delete() + deleted_count = session.query(AgentTemplate).delete() session.commit() print(f"🗑️ Removed {deleted_count} templates") @@ -324,7 +235,7 @@ def remove_all_templates(): if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description="Manage custom agent templates") + parser = argparse.ArgumentParser(description="Manage agent templates") parser.add_argument("action", choices=["populate", "list", "remove-all"], help="Action to perform") diff --git a/auto-analyst-backend/src/agents/agents.py b/auto-analyst-backend/src/agents/agents.py index b42fcefe..f65a4abe 100644 --- a/auto-analyst-backend/src/agents/agents.py +++ b/auto-analyst-backend/src/agents/agents.py @@ -36,117 +36,229 @@ def create_custom_agent_signature(agent_name, description, prompt_template): CustomAgentSignature = type(agent_name, (dspy.Signature,), class_attributes) return CustomAgentSignature -def load_custom_agents_from_db(user_id, db_session, include_templates=True): +def load_user_enabled_templates_from_db(user_id, db_session): """ - Load custom agents for a specific user from the database, optionally including templates. + Load template agents that are enabled for a specific user from the database. + All templates are enabled by default unless explicitly disabled by user preference. Args: user_id: ID of the user db_session: Database session - include_templates: Whether to include template agents that are available to all users Returns: - Dict of custom agent signatures keyed by agent name + Dict of template agent signatures keyed by template name """ try: - from src.db.schemas.models import CustomAgent + from src.db.schemas.models import AgentTemplate, UserTemplatePreference agent_signatures = {} - # Query active custom agents for the user - if user_id: - custom_agents = db_session.query(CustomAgent).filter( - CustomAgent.user_id == user_id, - CustomAgent.is_active == True, - CustomAgent.is_template == False # Only user-created agents - ).all() + if not user_id: + return agent_signatures + + # Get all active templates + all_templates = db_session.query(AgentTemplate).filter( + AgentTemplate.is_active == True + ).all() + + for template in all_templates: + # Check if user has explicitly disabled this template + preference = db_session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == user_id, + UserTemplatePreference.template_id == template.template_id + ).first() + + # Template is disabled by default unless explicitly enabled + # Only enabled if preference record exists and is_enabled=True + is_enabled = preference.is_enabled if preference else False - for agent in custom_agents: - # Create dynamic signature for each custom agent + if is_enabled: + # Create dynamic signature for each enabled template signature = create_custom_agent_signature( - agent.agent_name, - agent.description, - agent.prompt_template + template.template_name, + template.description, + template.prompt_template ) - agent_signatures[agent.agent_name] = signature + agent_signatures[template.template_name] = signature + + return agent_signatures + + except Exception as e: + logger.log_message(f"Error loading user enabled templates for user {user_id}: {str(e)}", level=logging.ERROR) + return {} + +def load_user_enabled_templates_for_planner_from_db(user_id, db_session): + """ + Load template agents that are enabled for planner use (max 10, prioritized by usage). + + Args: + user_id: ID of the user + db_session: Database session + + Returns: + Dict of template agent signatures keyed by template name (max 10) + """ + try: + from src.db.schemas.models import AgentTemplate, UserTemplatePreference + + agent_signatures = {} + + if not user_id: + return agent_signatures - # Also include template agents if requested - if include_templates: - template_agents = db_session.query(CustomAgent).filter( - CustomAgent.is_template == True, - CustomAgent.is_active == True - ).all() + # Get enabled templates ordered by usage (most used first) and limit to 10 + enabled_preferences = db_session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == user_id, + UserTemplatePreference.is_enabled == True + ).order_by( + UserTemplatePreference.usage_count.desc(), + UserTemplatePreference.last_used_at.desc() + ).limit(10).all() + + for preference in enabled_preferences: + # Get template details + template = db_session.query(AgentTemplate).filter( + AgentTemplate.template_id == preference.template_id, + AgentTemplate.is_active == True + ).first() - for template in template_agents: - # Create dynamic signature for each template agent + if template: + # Create dynamic signature for each enabled template signature = create_custom_agent_signature( - template.agent_name, + template.template_name, template.description, template.prompt_template ) - agent_signatures[template.agent_name] = signature + agent_signatures[template.template_name] = signature + logger.log_message(f"Loaded {len(agent_signatures)} templates for planner for user {user_id}", level=logging.INFO) return agent_signatures except Exception as e: - logger.log_message(f"Error loading custom agents for user {user_id}: {str(e)}", level=logging.ERROR) + logger.log_message(f"Error loading planner templates for user {user_id}: {str(e)}", level=logging.ERROR) return {} -def get_custom_agent_description(agent_name, custom_agents_descriptions): +def get_all_available_templates(db_session): """ - Get description for a custom agent. + Get all available agent templates from the database. Args: - agent_name: Name of the custom agent - custom_agents_descriptions: Dict of custom agent descriptions + db_session: Database session Returns: - Description string or default message + List of agent template records """ - return custom_agents_descriptions.get(agent_name.lower(), "Custom agent - no description available") + try: + from src.db.schemas.models import AgentTemplate + + templates = db_session.query(AgentTemplate).filter( + AgentTemplate.is_active == True + ).all() + + return templates + + except Exception as e: + logger.log_message(f"Error getting all available templates: {str(e)}", level=logging.ERROR) + return [] -def save_custom_agent_to_db(user_id, agent_name, display_name, description, prompt_template, db_session): +def toggle_user_template_preference(user_id, template_id, is_enabled, db_session): """ - Save a new custom agent to the database. + Toggle a user's template preference (enable/disable). Args: - user_id: ID of the user creating the agent - agent_name: Unique name for the agent (e.g., 'pytorch_agent') - display_name: User-friendly display name - description: Short description for agent selection - prompt_template: Main prompt/instructions for agent behavior + user_id: ID of the user + template_id: ID of the template + is_enabled: Whether to enable or disable the template db_session: Database session Returns: - Tuple (success: bool, message: str, agent_id: int or None) + Tuple (success: bool, message: str) """ try: - from src.db.schemas.models import CustomAgent - from sqlalchemy.exc import IntegrityError + from src.db.schemas.models import UserTemplatePreference, AgentTemplate + from datetime import datetime, UTC + + # Verify template exists and is active + template = db_session.query(AgentTemplate).filter( + AgentTemplate.template_id == template_id, + AgentTemplate.is_active == True + ).first() - # Create new custom agent - new_agent = CustomAgent( - user_id=user_id, - agent_name=agent_name.lower().strip(), - display_name=display_name, - description=description, - prompt_template=prompt_template, - is_active=True, - usage_count=0 - ) + if not template: + return False, "Template not found or inactive" + + # Check if preference record exists + preference = db_session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == user_id, + UserTemplatePreference.template_id == template_id + ).first() + + if preference: + # Update existing preference + preference.is_enabled = is_enabled + preference.updated_at = datetime.now(UTC) + else: + # Create new preference record + preference = UserTemplatePreference( + user_id=user_id, + template_id=template_id, + is_enabled=is_enabled, + usage_count=0, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC) + ) + db_session.add(preference) - db_session.add(new_agent) db_session.commit() - return True, "Custom agent created successfully", new_agent.agent_id + action = "enabled" if is_enabled else "disabled" + return True, f"Template '{template.template_name}' {action} successfully" - except IntegrityError: - db_session.rollback() - return False, f"Agent name '{agent_name}' already exists, please choose a different name", None except Exception as e: db_session.rollback() - logger.log_message(f"Error saving custom agent: {str(e)}", level=logging.ERROR) - return False, f"Error creating custom agent: {str(e)}", None + logger.log_message(f"Error toggling template preference: {str(e)}", level=logging.ERROR) + return False, f"Error updating template preference: {str(e)}" + + + +def load_all_available_templates_from_db(db_session): + """ + Load ALL available template agents from the database for direct access. + This allows users to use any template via @template_name regardless of preferences. + + Args: + db_session: Database session + + Returns: + Dict of template agent signatures keyed by template name + """ + try: + from src.db.schemas.models import AgentTemplate + + agent_signatures = {} + + # Get all active templates + all_templates = db_session.query(AgentTemplate).filter( + AgentTemplate.is_active == True + ).all() + + for template in all_templates: + # Create dynamic signature for all active templates + signature = create_custom_agent_signature( + template.template_name, + template.description, + template.prompt_template + ) + agent_signatures[template.template_name] = signature + + logger.log_message(f"Loaded {len(agent_signatures)} templates", level=logging.INFO) + return agent_signatures + + except Exception as e: + logger.log_message(f"Error loading all available templates: {str(e)}", level=logging.ERROR) + return {} + + # === END CUSTOM AGENT FUNCTIONALITY === @@ -1100,7 +1212,7 @@ class code_edit(dspy.Signature): class auto_analyst_ind(dspy.Module): """Handles individual agent execution when explicitly specified in query""" - def __init__(self, agents, retrievers, user_id=None): + def __init__(self, agents, retrievers, user_id=None, db_session=None): # Initialize agent modules and retrievers self.agents = {} self.agent_inputs = {} @@ -1112,6 +1224,26 @@ def __init__(self, agents, retrievers, user_id=None): self.agents[name] = dspy.asyncify(dspy.ChainOfThoughtWithHint(a)) self.agent_inputs[name] = {x.strip() for x in str(agents[i].__pydantic_core_schema__['cls']).split('->')[0].split('(')[1].split(',')} self.agent_desc.append(get_agent_description(name)) + + # Load ALL available template agents for direct access (regardless of user preferences) + if db_session: + try: + template_signatures = load_all_available_templates_from_db(db_session) + + for template_name, signature in template_signatures.items(): + # Add template agent to agents dict + self.agents[template_name] = dspy.asyncify(dspy.ChainOfThoughtWithHint(signature)) + + # Extract input fields from signature - templates use standard fields + self.agent_inputs[template_name] = {'goal', 'dataset', 'styling_index', 'hint'} + + # Add description + self.agent_desc.append(f"Template: {template_name}") + + logger.log_message(f"Loaded {len(template_signatures)} templates for direct access", level=logging.INFO) + + except Exception as e: + logger.log_message(f"Error loading templates for direct access: {str(e)}", level=logging.ERROR) # Initialize components # self.memory_summarize_agent = dspy.ChainOfThought(m.memory_summarize_agent) @@ -1123,58 +1255,71 @@ def __init__(self, agents, retrievers, user_id=None): self.user_id = user_id async def _track_agent_usage(self, agent_name): - """Track usage for custom agents and templates""" + """Track usage for template agents""" try: # Skip tracking for standard agents if agent_name in ['preprocessing_agent', 'statistical_analytics_agent', 'sk_learn_agent', 'data_viz_agent', 'basic_qa_agent']: return - # Only track if we have user_id (custom/template agents) + # Only track if we have user_id (template agents) if not self.user_id: return from src.db.init_db import session_factory - from src.db.schemas.models import CustomAgent + from src.db.schemas.models import AgentTemplate, UserTemplatePreference from datetime import datetime, UTC # Create database session session = session_factory() try: - # First try to find as user's custom agent - agent = session.query(CustomAgent).filter( - CustomAgent.agent_name == agent_name, - CustomAgent.user_id == self.user_id, - CustomAgent.is_template == False + # Find the template + template = session.query(AgentTemplate).filter( + AgentTemplate.template_name == agent_name ).first() - # If not found, try to find as template - if not agent: - agent = session.query(CustomAgent).filter( - CustomAgent.agent_name == agent_name, - CustomAgent.is_template == True - ).first() + if not template: + logger.log_message(f"Template '{agent_name}' not found for usage tracking", level=logging.WARNING) + return - # Update usage count if agent found - if agent: - agent.usage_count += 1 - agent.updated_at = datetime.now(UTC) - session.commit() - - agent_type = "template" if agent.is_template else "custom" - logger.log_message( - f"Incremented usage count for {agent_type} agent '{agent_name}' (now: {agent.usage_count})", - level=logging.INFO + # Find or create user template preference record + preference = session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == self.user_id, + UserTemplatePreference.template_id == template.template_id + ).first() + + if not preference: + # Create new preference record (disabled by default) + preference = UserTemplatePreference( + user_id=self.user_id, + template_id=template.template_id, + is_enabled=False, # Disabled by default + usage_count=0, + last_used_at=None, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC) ) + session.add(preference) + + # Update usage tracking + preference.usage_count += 1 + preference.last_used_at = datetime.now(UTC) + preference.updated_at = datetime.now(UTC) + session.commit() + + logger.log_message( + f"Tracked usage for template '{agent_name}' (count: {preference.usage_count})", + level=logging.DEBUG + ) except Exception as e: session.rollback() - logger.log_message(f"Error tracking usage for agent {agent_name}: {str(e)}", level=logging.ERROR) + logger.log_message(f"Error tracking usage for template {agent_name}: {str(e)}", level=logging.ERROR) finally: session.close() except Exception as e: logger.log_message(f"Error in _track_agent_usage for {agent_name}: {str(e)}", level=logging.ERROR) - + async def execute_agent(self, specified_agent, inputs): """Execute agent and generate memory summary in parallel""" try: @@ -1211,6 +1356,10 @@ async def forward(self, query, specified_agent): # Execute agent result = await self.agents[specified_agent.strip()](**inputs) + + # Track usage for template agents + await self._track_agent_usage(specified_agent.strip()) + output_dict = {specified_agent.strip(): dict(result)} if "error" in output_dict: @@ -1250,6 +1399,9 @@ async def execute_multiple_agents(self, query, agent_list): agent_dict = dict(agent_result) results[agent_name] = agent_dict + # Track usage for template agents + await self._track_agent_usage(agent_name) + # Collect code for later combination if 'code' in agent_dict: code_list.append(agent_dict['code']) @@ -1269,7 +1421,6 @@ def __init__(self, agents, retrievers, user_id=None, db_session=None): self.agents = {} self.agent_inputs = {} self.agent_desc = [] - self.custom_agents_descriptions = {} # Load standard agents for i, a in enumerate(agents): @@ -1278,56 +1429,40 @@ def __init__(self, agents, retrievers, user_id=None, db_session=None): self.agent_inputs[name] = {x.strip() for x in str(agents[i].__pydantic_core_schema__['cls']).split('->')[0].split('(')[1].split(',')} self.agent_desc.append({name: get_agent_description(name)}) - # Load custom agents if user_id and db_session are provided + # Load user-enabled template agents if user_id and db_session are provided if user_id and db_session: try: - custom_agent_signatures = load_custom_agents_from_db(user_id, db_session, include_templates=True) + template_signatures = load_user_enabled_templates_for_planner_from_db(user_id, db_session) - for agent_name, signature in custom_agent_signatures.items(): - # Add custom agent to agents dict - self.agents[agent_name] = dspy.asyncify(dspy.ChainOfThought(signature)) + for template_name, signature in template_signatures.items(): + # Add template agent to agents dict + self.agents[template_name] = dspy.asyncify(dspy.ChainOfThought(signature)) - # Extract input fields from signature - custom agents use standard fields like data_viz_agent - self.agent_inputs[agent_name] = {'goal', 'dataset', 'styling_index'} + # Extract input fields from signature - templates use standard fields like data_viz_agent + self.agent_inputs[template_name] = {'goal', 'dataset', 'styling_index'} - # Store custom agent description + # Store template agent description try: - from src.db.schemas.models import CustomAgent + from src.db.schemas.models import AgentTemplate - # First try to find as user agent - agent_record = db_session.query(CustomAgent).filter( - CustomAgent.agent_name == agent_name, - CustomAgent.user_id == user_id, - CustomAgent.is_template == False + # Find template record + template_record = db_session.query(AgentTemplate).filter( + AgentTemplate.template_name == template_name ).first() - # If not found, try to find as template - if not agent_record: - agent_record = db_session.query(CustomAgent).filter( - CustomAgent.agent_name == agent_name, - CustomAgent.is_template == True - ).first() - - if agent_record: - description = agent_record.description - if agent_record.is_template: - # Add prefix for template agents - description = f"Template: {description}" - - self.custom_agents_descriptions[agent_name] = description - self.agent_desc.append({agent_name: description}) + if template_record: + description = f"Template: {template_record.description}" + self.agent_desc.append({template_name: description}) else: - self.custom_agents_descriptions[agent_name] = f"Custom agent: {agent_name}" - self.agent_desc.append({agent_name: f"Custom agent: {agent_name}"}) + self.agent_desc.append({template_name: f"Template: {template_name}"}) except Exception as desc_error: - logger.log_message(f"Error getting description for custom agent {agent_name}: {str(desc_error)}", level=logging.WARNING) - self.custom_agents_descriptions[agent_name] = f"Custom agent: {agent_name}" - self.agent_desc.append({agent_name: f"Custom agent: {agent_name}"}) + logger.log_message(f"Error getting description for template {template_name}: {str(desc_error)}", level=logging.WARNING) + self.agent_desc.append({template_name: f"Template: {template_name}"}) - logger.log_message(f"Loaded {len(custom_agent_signatures)} custom agents (including templates) for user {user_id}", level=logging.INFO) + logger.log_message(f"Loaded {len(template_signatures)} enabled template agents for user {user_id}", level=logging.INFO) except Exception as e: - logger.log_message(f"Error loading custom agents for user {user_id}: {str(e)}", level=logging.ERROR) + logger.log_message(f"Error loading template agents for user {user_id}: {str(e)}", level=logging.ERROR) self.agents['basic_qa_agent'] = dspy.asyncify(dspy.Predict("goal->answer")) self.agent_inputs['basic_qa_agent'] = {"goal"} @@ -1349,52 +1484,67 @@ def __init__(self, agents, retrievers, user_id=None, db_session=None): self.user_id = user_id async def _track_agent_usage(self, agent_name): - """Track usage for custom agents and templates""" + """Track usage for template agents""" try: # Skip tracking for standard agents if agent_name in ['preprocessing_agent', 'statistical_analytics_agent', 'sk_learn_agent', 'data_viz_agent', 'basic_qa_agent']: return - # Only track if we have user_id (custom/template agents) + # Only track if we have user_id (template agents) if not self.user_id: return from src.db.init_db import session_factory - from src.db.schemas.models import CustomAgent + from src.db.schemas.models import AgentTemplate, UserTemplatePreference from datetime import datetime, UTC # Create database session session = session_factory() try: - # First try to find as user's custom agent - agent = session.query(CustomAgent).filter( - CustomAgent.agent_name == agent_name, - CustomAgent.user_id == self.user_id, - CustomAgent.is_template == False + # Find the template + template = session.query(AgentTemplate).filter( + AgentTemplate.template_name == agent_name ).first() - # If not found, try to find as template - if not agent: - agent = session.query(CustomAgent).filter( - CustomAgent.agent_name == agent_name, - CustomAgent.is_template == True - ).first() + if not template: + logger.log_message(f"Template '{agent_name}' not found", level=logging.WARNING) + return - # Update usage count if agent found - if agent: - agent.usage_count += 1 - agent.updated_at = datetime.now(UTC) - session.commit() - - agent_type = "template" if agent.is_template else "custom" - logger.log_message( - f"Incremented usage count for {agent_type} agent '{agent_name}' (now: {agent.usage_count})", - level=logging.INFO + # Find or create user template preference record + preference = session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == self.user_id, + UserTemplatePreference.template_id == template.template_id + ).first() + + if preference: + # Update existing preference + preference.usage_count += 1 + preference.last_used_at = datetime.now(UTC) + preference.updated_at = datetime.now(UTC) + else: + # Create new preference record when template is used directly (via @mention) + # Direct usage doesn't auto-enable for planner but tracks usage + preference = UserTemplatePreference( + user_id=self.user_id, + template_id=template.template_id, + is_enabled=False, # Default disabled for planner + usage_count=1, + last_used_at=datetime.now(UTC), + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC) ) + session.add(preference) + + session.commit() + + logger.log_message( + f"Tracked usage for template '{agent_name}' for user {self.user_id} (count: {preference.usage_count})", + level=logging.DEBUG + ) except Exception as e: session.rollback() - logger.log_message(f"Error tracking usage for agent {agent_name}: {str(e)}", level=logging.ERROR) + logger.log_message(f"Error tracking template usage for {agent_name}: {str(e)}", level=logging.ERROR) finally: session.close() @@ -1530,4 +1680,3 @@ async def execute_plan(self, query, plan): except Exception as e: logger.log_message(f"Error in task execution: {str(e)}", level=logging.ERROR) yield "error", {}, {"error": str(e)} - diff --git a/auto-analyst-backend/src/db/schemas/models.py b/auto-analyst-backend/src/db/schemas/models.py index de48e080..958d76a0 100644 --- a/auto-analyst-backend/src/db/schemas/models.py +++ b/auto-analyst-backend/src/db/schemas/models.py @@ -18,7 +18,7 @@ class User(Base): chats = relationship("Chat", back_populates="user", cascade="all, delete-orphan") usage_records = relationship("ModelUsage", back_populates="user") deep_analysis_reports = relationship("DeepAnalysisReport", back_populates="user", cascade="all, delete-orphan") - custom_agents = relationship("CustomAgent", back_populates="user", cascade="all, delete-orphan") + template_preferences = relationship("UserTemplatePreference", back_populates="user", cascade="all, delete-orphan") # Define the Chats table class Chat(Base): @@ -173,38 +173,58 @@ class DeepAnalysisReport(Base): # Relationships user = relationship("User", back_populates="deep_analysis_reports") -class CustomAgent(Base): - """Stores custom agents created by premium users and system templates.""" - __tablename__ = 'custom_agents' +class AgentTemplate(Base): + """Stores predefined agent templates that users can enable/disable.""" + __tablename__ = 'agent_templates' - agent_id = Column(Integer, primary_key=True, autoincrement=True) - user_id = Column(Integer, ForeignKey('users.user_id', ondelete="CASCADE"), nullable=True) # Nullable for templates + template_id = Column(Integer, primary_key=True, autoincrement=True) - # Agent definition - agent_name = Column(String(100), nullable=False) # e.g., 'pytorch_agent', 'deep_learning_agent' + # Template definition + template_name = Column(String(100), nullable=False, unique=True) # e.g., 'pytorch_specialist', 'data_cleaning_expert' display_name = Column(String(200), nullable=True) # User-friendly display name - description = Column(Text, nullable=False) # Short description for agent selection + description = Column(Text, nullable=False) # Short description for template selection prompt_template = Column(Text, nullable=False) # Main prompt/instructions for agent behavior - # Template fields - is_template = Column(Boolean, default=False) # True for system templates, False for user agents - template_category = Column(String(50), nullable=True) # 'Visualization', 'Modelling', 'Data Manipulation' + # Template categorization + category = Column(String(50), nullable=True) # 'Visualization', 'Modelling', 'Data Manipulation' is_premium_only = Column(Boolean, default=False) # True if template requires premium subscription # Status and metadata is_active = Column(Boolean, default=True) - usage_count = Column(Integer, default=0) # Track how many times agent has been used # Timestamps created_at = Column(DateTime, default=lambda: datetime.now(UTC)) updated_at = Column(DateTime, default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC)) # Relationships - user = relationship("User", back_populates="custom_agents") + user_preferences = relationship("UserTemplatePreference", back_populates="template", cascade="all, delete-orphan") + +class UserTemplatePreference(Base): + """Tracks user preferences and usage for agent templates.""" + __tablename__ = 'user_template_preferences' + + preference_id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(Integer, ForeignKey('users.user_id', ondelete="CASCADE"), nullable=False) + template_id = Column(Integer, ForeignKey('agent_templates.template_id', ondelete="CASCADE"), nullable=False) + + # User preferences + is_enabled = Column(Boolean, default=True) # Whether user has this template enabled + + # Usage tracking + usage_count = Column(Integer, default=0) # Track how many times user has used this template + last_used_at = Column(DateTime, nullable=True) # Last time user used this template + + # Timestamps + created_at = Column(DateTime, default=lambda: datetime.now(UTC)) + updated_at = Column(DateTime, default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC)) + + # Relationships + user = relationship("User", back_populates="template_preferences") + template = relationship("AgentTemplate", back_populates="user_preferences") - # Constraints + # Constraints - user can only have one preference record per template __table_args__ = ( - UniqueConstraint('user_id', 'agent_name', name='unique_user_agent_name'), + UniqueConstraint('user_id', 'template_id', name='unique_user_template_preference'), ) \ No newline at end of file diff --git a/auto-analyst-backend/src/managers/session_manager.py b/auto-analyst-backend/src/managers/session_manager.py index bf5eb558..013801c5 100644 --- a/auto-analyst-backend/src/managers/session_manager.py +++ b/auto-analyst-backend/src/managers/session_manager.py @@ -316,7 +316,7 @@ def create_ai_system_for_user(self, retrievers, user_id=None): user_id=user_id, db_session=db_session ) - logger.log_message(f"Created AI system with custom agents for user {user_id}", level=logging.INFO) + logger.log_message(f"Created AI system for user {user_id}", level=logging.INFO) return ai_system finally: db_session.close() @@ -368,7 +368,7 @@ def set_session_user(self, session_id: str, user_id: int, chat_id: int = None): session_retrievers = self._sessions[session_id]["retrievers"] user_ai_system = self.create_ai_system_for_user(session_retrievers, user_id) self._sessions[session_id]["ai_system"] = user_ai_system - logger.log_message(f"Updated AI system for session {session_id} with user {user_id} custom agents", level=logging.INFO) + logger.log_message(f"Updated AI system for session {session_id} with user {user_id}", level=logging.INFO) except Exception as e: logger.log_message(f"Error updating AI system for user {user_id}: {str(e)}", level=logging.ERROR) # Continue with existing AI system if update fails diff --git a/auto-analyst-backend/src/routes/custom_agents_routes.py b/auto-analyst-backend/src/routes/custom_agents_routes.py deleted file mode 100644 index 54018e28..00000000 --- a/auto-analyst-backend/src/routes/custom_agents_routes.py +++ /dev/null @@ -1,697 +0,0 @@ -import logging -import os -from fastapi import APIRouter, Depends, HTTPException, Query, Body -from pydantic import BaseModel, Field -from typing import List, Optional, Dict, Any -from datetime import datetime, UTC -from sqlalchemy import desc -from sqlalchemy.exc import IntegrityError - -from src.db.init_db import session_factory -from src.db.schemas.models import CustomAgent, User -from src.utils.logger import Logger -import dspy -from src.agents.agents import custom_agent_instruction_generator - -# Initialize logger with console logging disabled -logger = Logger("custom_agents_routes", see_time=True, console_log=False) - -# Initialize router -router = APIRouter(prefix="/custom_agents", tags=["custom_agents"]) - -# Pydantic models for request/response -class CustomAgentCreate(BaseModel): - agent_name: str = Field(..., min_length=1, max_length=100, description="Unique agent name (e.g., 'pytorch_agent')") - display_name: Optional[str] = Field(None, max_length=200, description="User-friendly display name") - description: str = Field(..., min_length=10, max_length=1000, description="Short description for agent selection") - prompt_template: str = Field(..., min_length=50, description="Main prompt/instructions for agent behavior") - -class CustomAgentUpdate(BaseModel): - display_name: Optional[str] = Field(None, max_length=200) - description: Optional[str] = Field(None, min_length=10, max_length=1000) - prompt_template: Optional[str] = Field(None, min_length=50) - is_active: Optional[bool] = None - -class CustomAgentResponse(BaseModel): - agent_id: int - agent_name: str - display_name: Optional[str] - description: str - prompt_template: str - is_active: bool - usage_count: int - created_at: datetime - updated_at: datetime - -class CustomAgentListResponse(BaseModel): - agent_id: int - agent_name: str - display_name: Optional[str] - description: str - is_active: bool - usage_count: int - created_at: datetime - -class TemplateAgentResponse(BaseModel): - agent_id: int - agent_name: str - display_name: Optional[str] - description: str - prompt_template: str - template_category: str - is_premium_only: bool - is_active: bool - usage_count: int - created_at: datetime - -class TemplatesByCategory(BaseModel): - category: str - templates: List[TemplateAgentResponse] - -class AgentInstructionRequest(BaseModel): - category: str = Field(..., description="The category of the agent: 'Visualization', 'Modelling', or 'Data Manipulation'") - user_requirements: str = Field(..., min_length=10, max_length=2000, description="User's description of what they want the agent to do") - -class AgentInstructionResponse(BaseModel): - agent_instructions: str - category: str - user_requirements: str - -# Routes -@router.post("/", response_model=CustomAgentResponse) -async def create_custom_agent(agent: CustomAgentCreate, user_id: int = Query(...)): - """Create a new custom agent for a user""" - try: - session = session_factory() - - try: - # Validate user exists - user = session.query(User).filter(User.user_id == user_id).first() - if not user: - raise HTTPException(status_code=404, detail="User not found") - - # Create new custom agent - now = datetime.now(UTC) - new_agent = CustomAgent( - user_id=user_id, - agent_name=agent.agent_name.lower().strip(), - display_name=agent.display_name, - description=agent.description, - prompt_template=agent.prompt_template, - is_active=True, - usage_count=0, - created_at=now, - updated_at=now - ) - - session.add(new_agent) - session.commit() - session.refresh(new_agent) - - logger.log_message(f"Created custom agent '{agent.agent_name}' for user {user_id}", level=logging.INFO) - - return CustomAgentResponse( - agent_id=new_agent.agent_id, - agent_name=new_agent.agent_name, - display_name=new_agent.display_name, - description=new_agent.description, - prompt_template=new_agent.prompt_template, - is_active=new_agent.is_active, - usage_count=new_agent.usage_count, - created_at=new_agent.created_at, - updated_at=new_agent.updated_at - ) - - except IntegrityError: - session.rollback() - raise HTTPException(status_code=400, detail=f"Agent name '{agent.agent_name}' already exists, please choose a different name") - except Exception as e: - session.rollback() - logger.log_message(f"Error creating custom agent: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to create custom agent: {str(e)}") - finally: - session.close() - - except HTTPException: - raise - except Exception as e: - logger.log_message(f"Error creating custom agent: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to create custom agent: {str(e)}") - -@router.get("/", response_model=List[CustomAgentListResponse]) -async def get_custom_agents( - user_id: int = Query(...), - active_only: bool = Query(False, description="Return only active agents"), - limit: int = Query(50, ge=1, le=100), - offset: int = Query(0, ge=0) -): - """Get custom agents for a user""" - try: - session = session_factory() - - try: - query = session.query(CustomAgent).filter(CustomAgent.user_id == user_id) - - if active_only: - query = query.filter(CustomAgent.is_active == True) - - # Order by most recently created first - query = query.order_by(desc(CustomAgent.created_at)) - - agents = query.limit(limit).offset(offset).all() - - return [CustomAgentListResponse( - agent_id=agent.agent_id, - agent_name=agent.agent_name, - display_name=agent.display_name, - description=agent.description, - is_active=agent.is_active, - usage_count=agent.usage_count, - created_at=agent.created_at - ) for agent in agents] - - finally: - session.close() - - except Exception as e: - logger.log_message(f"Error retrieving custom agents: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to retrieve custom agents: {str(e)}") - -@router.get("/{agent_id}", response_model=CustomAgentResponse) -async def get_custom_agent(agent_id: int, user_id: int = Query(...)): - """Get a specific custom agent by ID""" - try: - session = session_factory() - - try: - agent = session.query(CustomAgent).filter( - CustomAgent.agent_id == agent_id, - CustomAgent.user_id == user_id - ).first() - - if not agent: - raise HTTPException(status_code=404, detail=f"Custom agent with ID {agent_id} not found") - - return CustomAgentResponse( - agent_id=agent.agent_id, - agent_name=agent.agent_name, - display_name=agent.display_name, - description=agent.description, - prompt_template=agent.prompt_template, - is_active=agent.is_active, - usage_count=agent.usage_count, - created_at=agent.created_at, - updated_at=agent.updated_at - ) - - finally: - session.close() - - except HTTPException: - raise - except Exception as e: - logger.log_message(f"Error retrieving custom agent: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to retrieve custom agent: {str(e)}") - -@router.put("/{agent_id}", response_model=CustomAgentResponse) -async def update_custom_agent(agent_id: int, agent_update: CustomAgentUpdate, user_id: int = Query(...)): - """Update a custom agent""" - try: - session = session_factory() - - try: - agent = session.query(CustomAgent).filter( - CustomAgent.agent_id == agent_id, - CustomAgent.user_id == user_id - ).first() - - if not agent: - raise HTTPException(status_code=404, detail=f"Custom agent with ID {agent_id} not found") - - # Update fields if provided - if agent_update.display_name is not None: - agent.display_name = agent_update.display_name - if agent_update.description is not None: - agent.description = agent_update.description - if agent_update.prompt_template is not None: - agent.prompt_template = agent_update.prompt_template - if agent_update.is_active is not None: - agent.is_active = agent_update.is_active - - agent.updated_at = datetime.now(UTC) - session.commit() - session.refresh(agent) - - logger.log_message(f"Updated custom agent {agent_id} for user {user_id}", level=logging.INFO) - - return CustomAgentResponse( - agent_id=agent.agent_id, - agent_name=agent.agent_name, - display_name=agent.display_name, - description=agent.description, - prompt_template=agent.prompt_template, - is_active=agent.is_active, - usage_count=agent.usage_count, - created_at=agent.created_at, - updated_at=agent.updated_at - ) - - except Exception as e: - session.rollback() - logger.log_message(f"Error updating custom agent: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to update custom agent: {str(e)}") - finally: - session.close() - - except HTTPException: - raise - except Exception as e: - logger.log_message(f"Error updating custom agent: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to update custom agent: {str(e)}") - -@router.delete("/{agent_id}") -async def delete_custom_agent(agent_id: int, user_id: int = Query(...)): - """Delete a custom agent""" - try: - session = session_factory() - - try: - agent = session.query(CustomAgent).filter( - CustomAgent.agent_id == agent_id, - CustomAgent.user_id == user_id - ).first() - - if not agent: - raise HTTPException(status_code=404, detail=f"Custom agent with ID {agent_id} not found") - - session.delete(agent) - session.commit() - - logger.log_message(f"Deleted custom agent {agent_id} for user {user_id}", level=logging.INFO) - - return {"message": f"Custom agent {agent_id} deleted successfully"} - - except Exception as e: - session.rollback() - logger.log_message(f"Error deleting custom agent: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to delete custom agent: {str(e)}") - finally: - session.close() - - except HTTPException: - raise - except Exception as e: - logger.log_message(f"Error deleting custom agent: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to delete custom agent: {str(e)}") - -@router.post("/{agent_id}/increment_usage") -async def increment_usage_count(agent_id: int, user_id: int = Query(...)): - """Increment usage count for a custom agent""" - try: - session = session_factory() - - try: - agent = session.query(CustomAgent).filter( - CustomAgent.agent_id == agent_id, - CustomAgent.user_id == user_id - ).first() - - if not agent: - raise HTTPException(status_code=404, detail=f"Custom agent with ID {agent_id} not found") - - agent.usage_count += 1 - agent.updated_at = datetime.now(UTC) - session.commit() - - logger.log_message(f"Incremented usage count for agent {agent_id} (now: {agent.usage_count})", level=logging.INFO) - - return {"message": "Usage count incremented", "usage_count": agent.usage_count} - - except Exception as e: - session.rollback() - logger.log_message(f"Error incrementing usage count: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to increment usage count: {str(e)}") - finally: - session.close() - - except HTTPException: - raise - except Exception as e: - logger.log_message(f"Error incrementing usage count: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to increment usage count: {str(e)}") - -@router.post("/{agent_id}/toggle_active") -async def toggle_agent_active_status(agent_id: int, user_id: int = Query(...)): - """Toggle active status for a custom agent with 10-agent limit""" - try: - session = session_factory() - - try: - agent = session.query(CustomAgent).filter( - CustomAgent.agent_id == agent_id, - CustomAgent.user_id == user_id, - CustomAgent.is_template == False - ).first() - - if not agent: - raise HTTPException(status_code=404, detail=f"Custom agent with ID {agent_id} not found") - - # If trying to activate an agent, check the 10-agent limit - if not agent.is_active: - # Count currently active custom agents for this user - active_count = session.query(CustomAgent).filter( - CustomAgent.user_id == user_id, - CustomAgent.is_active == True, - CustomAgent.is_template == False - ).count() - - if active_count >= 10: - raise HTTPException( - status_code=400, - detail="You can have at most 10 active custom agents. Please deactivate some agents first." - ) - - # Toggle the active status - agent.is_active = not agent.is_active - agent.updated_at = datetime.now(UTC) - session.commit() - - status_text = "activated" if agent.is_active else "deactivated" - logger.log_message(f"Agent {agent_id} {status_text} for user {user_id}", level=logging.INFO) - - return { - "message": f"Agent {status_text} successfully", - "is_active": agent.is_active, - "agent_id": agent.agent_id - } - - except Exception as e: - session.rollback() - logger.log_message(f"Error toggling agent status: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to toggle agent status: {str(e)}") - finally: - session.close() - - except HTTPException: - raise - except Exception as e: - logger.log_message(f"Error toggling agent status: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to toggle agent status: {str(e)}") - -@router.get("/active_count/{user_id}") -async def get_active_agents_count(user_id: int): - """Get count of active custom agents for a user""" - try: - session = session_factory() - - try: - active_count = session.query(CustomAgent).filter( - CustomAgent.user_id == user_id, - CustomAgent.is_active == True, - CustomAgent.is_template == False - ).count() - - return { - "active_count": active_count, - "max_allowed": 10, - "can_activate_more": active_count < 10 - } - - finally: - session.close() - - except Exception as e: - logger.log_message(f"Error getting active agents count: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to get active agents count: {str(e)}") - -@router.get("/validate_name/{agent_name}") -async def validate_agent_name(agent_name: str, user_id: int = Query(...)): - """Check if an agent name is available for a user""" - try: - session = session_factory() - - try: - existing_agent = session.query(CustomAgent).filter( - CustomAgent.user_id == user_id, - CustomAgent.agent_name == agent_name.lower().strip() - ).first() - - is_available = existing_agent is None - - return { - "agent_name": agent_name, - "is_available": is_available, - "message": "Agent name is available" if is_available else "Agent name already exists" - } - - finally: - session.close() - - except Exception as e: - logger.log_message(f"Error validating agent name: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to validate agent name: {str(e)}") - -# Template endpoints -@router.get("/templates/", response_model=List[TemplatesByCategory]) -async def get_template_agents(): - """Get all template agents organized by category""" - try: - session = session_factory() - - try: - templates = session.query(CustomAgent).filter( - CustomAgent.is_template == True, - CustomAgent.is_active == True - ).order_by(CustomAgent.template_category, CustomAgent.agent_name).all() - - # Group templates by category - categories = {} - for template in templates: - category = template.template_category or "Other" - if category not in categories: - categories[category] = [] - - categories[category].append(TemplateAgentResponse( - agent_id=template.agent_id, - agent_name=template.agent_name, - display_name=template.display_name, - description=template.description, - prompt_template=template.prompt_template, - template_category=template.template_category, - is_premium_only=template.is_premium_only, - is_active=template.is_active, - usage_count=template.usage_count, - created_at=template.created_at - )) - - # Convert to list of category objects - result = [TemplatesByCategory(category=category, templates=templates) - for category, templates in categories.items()] - - return result - - finally: - session.close() - - except Exception as e: - logger.log_message(f"Error retrieving template agents: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to retrieve template agents: {str(e)}") - -@router.get("/templates/{template_id}", response_model=TemplateAgentResponse) -async def get_template_agent(template_id: int): - """Get a specific template agent by ID""" - try: - session = session_factory() - - try: - template = session.query(CustomAgent).filter( - CustomAgent.agent_id == template_id, - CustomAgent.is_template == True - ).first() - - if not template: - raise HTTPException(status_code=404, detail=f"Template agent with ID {template_id} not found") - - return TemplateAgentResponse( - agent_id=template.agent_id, - agent_name=template.agent_name, - display_name=template.display_name, - description=template.description, - prompt_template=template.prompt_template, - template_category=template.template_category, - is_premium_only=template.is_premium_only, - is_active=template.is_active, - usage_count=template.usage_count, - created_at=template.created_at - ) - - finally: - session.close() - - except HTTPException: - raise - except Exception as e: - logger.log_message(f"Error retrieving template agent: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to retrieve template agent: {str(e)}") - -@router.post("/templates/{template_id}/copy", response_model=CustomAgentResponse) -async def copy_template_to_user(template_id: int, user_id: int = Query(...), agent_name: str = Query(...)): - """Copy a template agent to a user's custom agents""" - try: - session = session_factory() - - try: - # Validate user exists - user = session.query(User).filter(User.user_id == user_id).first() - if not user: - raise HTTPException(status_code=404, detail="User not found") - - # Get template agent - template = session.query(CustomAgent).filter( - CustomAgent.agent_id == template_id, - CustomAgent.is_template == True - ).first() - - if not template: - raise HTTPException(status_code=404, detail=f"Template agent with ID {template_id} not found") - - # Check if agent name already exists for user - existing_agent = session.query(CustomAgent).filter( - CustomAgent.user_id == user_id, - CustomAgent.agent_name == agent_name.lower().strip() - ).first() - - if existing_agent: - raise HTTPException(status_code=400, detail=f"Agent name '{agent_name}' already exists, please choose a different name") - - # Create new custom agent from template - now = datetime.now(UTC) - new_agent = CustomAgent( - user_id=user_id, - agent_name=agent_name.lower().strip(), - display_name=template.display_name, - description=template.description, - prompt_template=template.prompt_template, - is_template=False, # This is a user agent, not a template - template_category=None, # User agents don't have template categories - is_premium_only=False, # User agents are not premium-restricted - is_active=True, - usage_count=0, - created_at=now, - updated_at=now - ) - - session.add(new_agent) - session.commit() - session.refresh(new_agent) - - # Increment template usage count - template.usage_count += 1 - session.commit() - - logger.log_message(f"Copied template '{template.agent_name}' to user {user_id} as '{agent_name}'", level=logging.INFO) - - return CustomAgentResponse( - agent_id=new_agent.agent_id, - agent_name=new_agent.agent_name, - display_name=new_agent.display_name, - description=new_agent.description, - prompt_template=new_agent.prompt_template, - is_active=new_agent.is_active, - usage_count=new_agent.usage_count, - created_at=new_agent.created_at, - updated_at=new_agent.updated_at - ) - - except IntegrityError: - session.rollback() - raise HTTPException(status_code=400, detail=f"Agent name '{agent_name}' already exists, please choose a different name") - except Exception as e: - session.rollback() - logger.log_message(f"Error copying template to user: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to copy template: {str(e)}") - finally: - session.close() - - except HTTPException: - raise - except Exception as e: - logger.log_message(f"Error copying template to user: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to copy template: {str(e)}") - -@router.post("/templates/{template_id}/increment_usage") -async def increment_template_usage_count(template_id: int): - """Increment usage count for a template agent""" - try: - session = session_factory() - - try: - template = session.query(CustomAgent).filter( - CustomAgent.agent_id == template_id, - CustomAgent.is_template == True - ).first() - - if not template: - raise HTTPException(status_code=404, detail=f"Template agent with ID {template_id} not found") - - template.usage_count += 1 - template.updated_at = datetime.now(UTC) - session.commit() - - logger.log_message(f"Incremented template usage count for template {template_id} (now: {template.usage_count})", level=logging.INFO) - - return {"message": "Template usage count incremented", "usage_count": template.usage_count} - - except Exception as e: - session.rollback() - logger.log_message(f"Error incrementing template usage count: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to increment template usage count: {str(e)}") - finally: - session.close() - - except HTTPException: - raise - except Exception as e: - logger.log_message(f"Error incrementing template usage count: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to increment template usage count: {str(e)}") - -@router.post("/generate_instructions", response_model=AgentInstructionResponse) -async def generate_agent_instructions(request: AgentInstructionRequest): - """Generate agent instructions using the custom_agent_instruction_generator""" - try: - # Validate category - valid_categories = ["Visualization", "Modelling", "Data Manipulation"] - if request.category not in valid_categories: - raise HTTPException( - status_code=400, - detail=f"Invalid category. Must be one of: {', '.join(valid_categories)}" - ) - - default_lm = dspy.LM( - model="openai/gpt-4o-mini", - api_key=os.getenv("OPENAI_API_KEY"), - temperature=0.7, - max_tokens=7000 - ) - - instruction_generator = dspy.ChainOfThought(custom_agent_instruction_generator) - - with dspy.context(lm=default_lm): - result = instruction_generator( - category=request.category, - user_requirements=request.user_requirements - ) - - logger.log_message(f"Generated agent instructions for category: {request.category}", level=logging.INFO) - - return AgentInstructionResponse( - agent_instructions=result.agent_instructions, - category=request.category, - user_requirements=request.user_requirements - ) - - except HTTPException: - raise - except Exception as e: - logger.log_message(f"Error generating agent instructions: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to generate agent instructions: {str(e)}") \ No newline at end of file diff --git a/auto-analyst-backend/src/routes/templates_routes.py b/auto-analyst-backend/src/routes/templates_routes.py new file mode 100644 index 00000000..1b5148e1 --- /dev/null +++ b/auto-analyst-backend/src/routes/templates_routes.py @@ -0,0 +1,478 @@ +import logging +import os +from fastapi import APIRouter, Depends, HTTPException, Query, Body +from pydantic import BaseModel, Field +from typing import List, Optional, Dict, Any +from datetime import datetime, UTC +from sqlalchemy import desc +from sqlalchemy.exc import IntegrityError + +from src.db.init_db import session_factory +from src.db.schemas.models import AgentTemplate, User, UserTemplatePreference +from src.utils.logger import Logger +from src.agents.agents import get_all_available_templates, toggle_user_template_preference + +# Initialize logger with console logging disabled +logger = Logger("templates_routes", see_time=True, console_log=False) + +# Initialize router +router = APIRouter(prefix="/templates", tags=["templates"]) + +# Pydantic models for request/response +class TemplateResponse(BaseModel): + template_id: int + template_name: str + display_name: Optional[str] + description: str + prompt_template: str + template_category: Optional[str] + is_premium_only: bool + is_active: bool + usage_count: int + created_at: datetime + updated_at: datetime + +class UserTemplatePreferenceResponse(BaseModel): + template_id: int + template_name: str + display_name: Optional[str] + description: str + template_category: Optional[str] + is_premium_only: bool + is_enabled: bool + usage_count: int + last_used_at: Optional[datetime] + +class ToggleTemplateRequest(BaseModel): + is_enabled: bool = Field(..., description="Whether to enable or disable the template") + +# Routes +@router.get("/", response_model=List[TemplateResponse]) +async def get_all_templates(): + """Get all available agent templates""" + try: + session = session_factory() + + try: + templates = get_all_available_templates(session) + + return [TemplateResponse( + template_id=template.template_id, + template_name=template.template_name, + display_name=template.display_name, + description=template.description, + prompt_template=template.prompt_template, + template_category=template.category, + is_premium_only=template.is_premium_only, + is_active=template.is_active, + usage_count=0, # Templates don't track global usage count anymore + created_at=template.created_at, + updated_at=template.updated_at + ) for template in templates] + + finally: + session.close() + + except Exception as e: + logger.log_message(f"Error retrieving templates: {str(e)}", level=logging.ERROR) + raise HTTPException(status_code=500, detail=f"Failed to retrieve templates: {str(e)}") + +@router.get("/user/{user_id}", response_model=List[UserTemplatePreferenceResponse]) +async def get_user_template_preferences(user_id: int): + """Get all templates with user preferences (enabled/disabled status and usage)""" + try: + session = session_factory() + + try: + # Validate user exists + user = session.query(User).filter(User.user_id == user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + # Get all active templates + templates = session.query(AgentTemplate).filter( + AgentTemplate.is_active == True + ).all() + + result = [] + for template in templates: + # Get user preference for this template if it exists + preference = session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == user_id, + UserTemplatePreference.template_id == template.template_id + ).first() + + result.append(UserTemplatePreferenceResponse( + template_id=template.template_id, + template_name=template.template_name, + display_name=template.display_name, + description=template.description, + template_category=template.category, + is_premium_only=template.is_premium_only, + is_enabled=preference.is_enabled if preference else False, # Default to disabled + usage_count=preference.usage_count if preference else 0, + last_used_at=preference.last_used_at if preference else None + )) + + return result + + finally: + session.close() + + except HTTPException: + raise + except Exception as e: + logger.log_message(f"Error retrieving user template preferences: {str(e)}", level=logging.ERROR) + raise HTTPException(status_code=500, detail=f"Failed to retrieve user template preferences: {str(e)}") + +@router.get("/user/{user_id}/enabled", response_model=List[UserTemplatePreferenceResponse]) +async def get_user_enabled_templates(user_id: int): + """Get only templates that are enabled for the user (all templates enabled by default)""" + try: + session = session_factory() + + try: + # Validate user exists + user = session.query(User).filter(User.user_id == user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + # Get all active templates + all_templates = session.query(AgentTemplate).filter( + AgentTemplate.is_active == True + ).all() + + result = [] + for template in all_templates: + # Check if user has a preference record for this template + preference = session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == user_id, + UserTemplatePreference.template_id == template.template_id + ).first() + + # Template is disabled by default unless explicitly enabled + is_enabled = preference.is_enabled if preference else False + + if is_enabled: + result.append(UserTemplatePreferenceResponse( + template_id=template.template_id, + template_name=template.template_name, + display_name=template.display_name, + description=template.description, + template_category=template.category, + is_premium_only=template.is_premium_only, + is_enabled=True, + usage_count=preference.usage_count if preference else 0, + last_used_at=preference.last_used_at if preference else None + )) + + return result + + finally: + session.close() + + except HTTPException: + raise + except Exception as e: + logger.log_message(f"Error retrieving user enabled templates: {str(e)}", level=logging.ERROR) + raise HTTPException(status_code=500, detail=f"Failed to retrieve user enabled templates: {str(e)}") + +@router.get("/user/{user_id}/enabled/planner", response_model=List[UserTemplatePreferenceResponse]) +async def get_user_enabled_templates_for_planner(user_id: int): + """Get enabled templates for planner use (max 10 templates)""" + try: + session = session_factory() + + try: + # Validate user exists + user = session.query(User).filter(User.user_id == user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + # Get enabled templates ordered by usage (most used first) and limit to 10 + enabled_preferences = session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == user_id, + UserTemplatePreference.is_enabled == True + ).order_by( + UserTemplatePreference.usage_count.desc(), + UserTemplatePreference.last_used_at.desc() + ).limit(10).all() + + result = [] + for preference in enabled_preferences: + # Get template details + template = session.query(AgentTemplate).filter( + AgentTemplate.template_id == preference.template_id, + AgentTemplate.is_active == True + ).first() + + if template: + result.append(UserTemplatePreferenceResponse( + template_id=template.template_id, + template_name=template.template_name, + display_name=template.display_name, + description=template.description, + template_category=template.category, + is_premium_only=template.is_premium_only, + is_enabled=True, + usage_count=preference.usage_count, + last_used_at=preference.last_used_at + )) + + logger.log_message(f"Retrieved {len(result)} enabled templates for planner for user {user_id}", level=logging.INFO) + return result + + finally: + session.close() + + except HTTPException: + raise + except Exception as e: + logger.log_message(f"Error retrieving planner templates for user {user_id}: {str(e)}", level=logging.ERROR) + raise HTTPException(status_code=500, detail=f"Failed to retrieve planner templates: {str(e)}") + +@router.post("/user/{user_id}/template/{template_id}/toggle") +async def toggle_template_preference(user_id: int, template_id: int, request: ToggleTemplateRequest): + """Toggle a user's template preference (enable/disable for planner use)""" + try: + session = session_factory() + + try: + # Validate user exists + user = session.query(User).filter(User.user_id == user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + success, message = toggle_user_template_preference( + user_id, template_id, request.is_enabled, session + ) + + if not success: + raise HTTPException(status_code=400, detail=message) + + logger.log_message(f"Toggled template {template_id} for user {user_id}: {message}", level=logging.INFO) + + return {"message": message} + + finally: + session.close() + + except HTTPException: + raise + except Exception as e: + logger.log_message(f"Error toggling template preference: {str(e)}", level=logging.ERROR) + raise HTTPException(status_code=500, detail=f"Failed to toggle template preference: {str(e)}") + +@router.post("/user/{user_id}/bulk-toggle") +async def bulk_toggle_template_preferences(user_id: int, request: dict): + """Bulk toggle multiple template preferences""" + try: + session = session_factory() + + try: + # Validate user exists + user = session.query(User).filter(User.user_id == user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + template_preferences = request.get("preferences", []) + if not template_preferences: + raise HTTPException(status_code=400, detail="No preferences provided") + + # Check current enabled count for limit enforcement + current_enabled_count = session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == user_id, + UserTemplatePreference.is_enabled == True + ).count() + + # Count how many templates we're trying to enable + enabling_count = sum(1 for pref in template_preferences if pref.get("is_enabled", False)) + disabling_count = sum(1 for pref in template_preferences if not pref.get("is_enabled", False)) + + # Calculate what the new count would be + projected_enabled_count = current_enabled_count + enabling_count - disabling_count + + results = [] + for pref in template_preferences: + template_id = pref.get("template_id") + is_enabled = pref.get("is_enabled", True) + + if template_id is None: + results.append({"template_id": None, "success": False, "message": "Template ID required"}) + continue + + # Check 10-template limit for enabling + if is_enabled and projected_enabled_count > 10: + results.append({ + "template_id": template_id, + "success": False, + "message": "Cannot enable more than 10 templates for planner use", + "is_enabled": False + }) + continue + + success, message = toggle_user_template_preference( + user_id, template_id, is_enabled, session + ) + + results.append({ + "template_id": template_id, + "success": success, + "message": message, + "is_enabled": is_enabled + }) + + logger.log_message(f"Bulk toggled {len(template_preferences)} templates for user {user_id}", level=logging.INFO) + + return {"results": results} + + finally: + session.close() + + except HTTPException: + raise + except Exception as e: + logger.log_message(f"Error bulk toggling template preferences: {str(e)}", level=logging.ERROR) + raise HTTPException(status_code=500, detail=f"Failed to bulk toggle template preferences: {str(e)}") + +@router.get("/template/{template_id}", response_model=TemplateResponse) +async def get_template(template_id: int): + """Get a specific template by ID""" + try: + session = session_factory() + + try: + template = session.query(AgentTemplate).filter( + AgentTemplate.template_id == template_id + ).first() + + if not template: + raise HTTPException(status_code=404, detail=f"Template with ID {template_id} not found") + + return TemplateResponse( + template_id=template.template_id, + template_name=template.template_name, + display_name=template.display_name, + description=template.description, + prompt_template=template.prompt_template, + template_category=template.category, + is_premium_only=template.is_premium_only, + is_active=template.is_active, + usage_count=0, # Templates don't track global usage count anymore + created_at=template.created_at, + updated_at=template.updated_at + ) + + finally: + session.close() + + except HTTPException: + raise + except Exception as e: + logger.log_message(f"Error retrieving template: {str(e)}", level=logging.ERROR) + raise HTTPException(status_code=500, detail=f"Failed to retrieve template: {str(e)}") + +@router.get("/categories/list") +async def get_template_categories(): + """Get list of all template categories""" + try: + session = session_factory() + + try: + categories = session.query(AgentTemplate.category).filter( + AgentTemplate.is_active == True, + AgentTemplate.category.isnot(None) + ).distinct().all() + + category_list = [category[0] for category in categories if category[0]] + + return {"categories": category_list} + + finally: + session.close() + + except Exception as e: + logger.log_message(f"Error retrieving template categories: {str(e)}", level=logging.ERROR) + raise HTTPException(status_code=500, detail=f"Failed to retrieve template categories: {str(e)}") + +@router.get("/categories") +async def get_templates_by_categories(): + """Get all templates grouped by category for frontend template browser""" + try: + session = session_factory() + + try: + # Get all active templates + templates = session.query(AgentTemplate).filter( + AgentTemplate.is_active == True + ).order_by(AgentTemplate.category, AgentTemplate.template_name).all() + + # Group templates by category + categories_dict = {} + for template in templates: + category = template.category or "Uncategorized" + if category not in categories_dict: + categories_dict[category] = [] + + categories_dict[category].append({ + "agent_id": template.template_id, # Use template_id as agent_id for compatibility + "agent_name": template.template_name, + "display_name": template.display_name or template.template_name, + "description": template.description, + "prompt_template": template.prompt_template, + "template_category": template.category, + "is_premium_only": template.is_premium_only, + "is_active": template.is_active, + "usage_count": 0, # Templates don't track global usage count anymore + "created_at": template.created_at.isoformat() if template.created_at else None + }) + + # Convert to list format expected by frontend + result = [] + for category, templates in categories_dict.items(): + result.append({ + "category": category, + "templates": templates + }) + + return result + + finally: + session.close() + + except Exception as e: + logger.log_message(f"Error retrieving templates by categories: {str(e)}", level=logging.ERROR) + raise HTTPException(status_code=500, detail=f"Failed to retrieve templates by categories: {str(e)}") + +@router.get("/category/{category}") +async def get_templates_by_category(category: str): + """Get all templates in a specific category""" + try: + session = session_factory() + + try: + templates = session.query(AgentTemplate).filter( + AgentTemplate.is_active == True, + AgentTemplate.category == category + ).all() + + return [TemplateResponse( + template_id=template.template_id, + template_name=template.template_name, + display_name=template.display_name, + description=template.description, + prompt_template=template.prompt_template, + template_category=template.category, + is_premium_only=template.is_premium_only, + is_active=template.is_active, + usage_count=0, # Templates don't track global usage count anymore + created_at=template.created_at, + updated_at=template.updated_at + ) for template in templates] + + finally: + session.close() + + except Exception as e: + logger.log_message(f"Error retrieving templates by category: {str(e)}", level=logging.ERROR) + raise HTTPException(status_code=500, detail=f"Failed to retrieve templates by category: {str(e)}") \ No newline at end of file diff --git a/auto-analyst-backend/src/utils/model_registry.py b/auto-analyst-backend/src/utils/model_registry.py index 6f1186a1..1d8ac1f3 100644 --- a/auto-analyst-backend/src/utils/model_registry.py +++ b/auto-analyst-backend/src/utils/model_registry.py @@ -23,7 +23,7 @@ "o1": {"input": 0.015, "output": 0.06}, "o1-pro": {"input": 0.015, "output": 0.6}, "o1-mini": {"input": 0.00011, "output": 0.00044}, - "o3": {"input": 0.001, "output": 0.04}, + "o3": {"input": 0.002, "output": 0.008}, "o3-mini": {"input": 0.00011, "output": 0.00044}, "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015}, }, diff --git a/auto-analyst-frontend/components/chat/AgentSuggestions.tsx b/auto-analyst-frontend/components/chat/AgentSuggestions.tsx index 76682e30..67ff7f7b 100644 --- a/auto-analyst-frontend/components/chat/AgentSuggestions.tsx +++ b/auto-analyst-frontend/components/chat/AgentSuggestions.tsx @@ -105,14 +105,7 @@ export default function AgentSuggestions({ allAgents.push(...templateAgents) } - // Add custom agents (only for users with custom agents access) - if (data.custom_agents && data.custom_agents.length > 0 && customAgentsAccess.hasAccess) { - console.log('fetchAllAgents - found custom agents, fetching details...'); - // Fetch custom agent details - const customAgents = await fetchCustomAgents() - console.log('fetchAllAgents - custom agent details:', customAgents); - allAgents.push(...customAgents) - } + // Custom agents are deprecated - using templates only console.log('fetchAllAgents - final allAgents:', allAgents); return allAgents @@ -128,38 +121,12 @@ export default function AgentSuggestions({ } } - // Fetch custom agents - const fetchCustomAgents = async (): Promise => { - const currentUserId = getUserId() - if (!currentUserId) { - return [] - } - try { - const customAgentsUrl = `${API_URL}/custom_agents/?user_id=${currentUserId}` - const response = await fetch(customAgentsUrl) - - if (response.ok) { - const customAgents = await response.json() - const mappedAgents = customAgents.map((agent: any) => ({ - name: agent.agent_name, - description: agent.description, - isCustom: true - })) - return mappedAgents - } else { - console.error('Failed to fetch custom agents:', response.status, await response.text()) - } - } catch (error) { - console.error('Error fetching custom agents:', error) - } - return [] - } // Fetch template agents const fetchTemplateAgents = async (): Promise => { try { - const templatesUrl = `${API_URL}/custom_agents/templates/` + const templatesUrl = `${API_URL}/templates/categories` console.log('fetchTemplateAgents - templatesUrl:', templatesUrl); const response = await fetch(templatesUrl) diff --git a/auto-analyst-frontend/components/chat/ChatInput.tsx b/auto-analyst-frontend/components/chat/ChatInput.tsx index ca1c0996..589bc259 100644 --- a/auto-analyst-frontend/components/chat/ChatInput.tsx +++ b/auto-analyst-frontend/components/chat/ChatInput.tsx @@ -36,7 +36,7 @@ import { // Deep Analysis imports import { DeepAnalysisSidebar, DeepAnalysisButton } from '../deep-analysis' // Custom Agents imports -import { CustomAgentsSidebar, CustomAgentsButton } from '../custom-agents' +import { TemplatesSidebar, TemplatesButton } from '../custom-agents' import CommandSuggestions from './CommandSuggestions' import AgentSuggestions from './AgentSuggestions' import { useUserSubscriptionStore } from '@/lib/store/userSubscriptionStore' @@ -225,15 +225,15 @@ const ChatInput = forwardRef< const [shouldForceExpanded, setShouldForceExpanded] = useState(false) // Custom Agents states - const [showCustomAgentsSidebar, setShowCustomAgentsSidebar] = useState(false) - const [shouldForceExpandedCustomAgents, setShouldForceExpandedCustomAgents] = useState(false) + const [showTemplatesSidebar, setShowTemplatesSidebar] = useState(false) + const [shouldForceExpandedTemplates, setShouldForceExpandedTemplates] = useState(false) const [showCommandSuggestions, setShowCommandSuggestions] = useState(false) const [commandQuery, setCommandQuery] = useState('') // Get subscription from store instead of manual construction const { subscription } = useUserSubscriptionStore() const deepAnalysisAccess = useFeatureAccess('DEEP_ANALYSIS', subscription) - + // Expose handlePreviewDefaultDataset to parent useImperativeHandle(ref, () => ({ handlePreviewDefaultDataset, @@ -939,11 +939,11 @@ const ChatInput = forwardRef< setTimeout(() => setShouldForceExpanded(false), 100) } else if (command.id === 'custom-agents') { // Show custom agents sidebar in expanded state - setShouldForceExpandedCustomAgents(true) - setShowCustomAgentsSidebar(true) + setShouldForceExpanded(true) + setShowTemplatesSidebar(true) setMessage('') // Reset force expanded after a brief moment - setTimeout(() => setShouldForceExpandedCustomAgents(false), 100) + setTimeout(() => setShouldForceExpanded(false), 100) } else { // For other commands, replace the "/" with the command setMessage(`${command.name} `) @@ -1945,12 +1945,12 @@ const ChatInput = forwardRef< )} - { - setShouldForceExpandedCustomAgents(true) - setShowCustomAgentsSidebar(true) + setShouldForceExpanded(true) + setShowTemplatesSidebar(true) // Reset force expanded after a brief moment - setTimeout(() => setShouldForceExpandedCustomAgents(false), 100) + setTimeout(() => setShouldForceExpanded(false), 100) }} userProfile={subscription} showLabel={true} @@ -2423,12 +2423,12 @@ const ChatInput = forwardRef< forceExpanded={shouldForceExpanded} /> - {/* Custom Agents Sidebar */} - setShowCustomAgentsSidebar(false)} + {/* Templates Sidebar */} + setShowTemplatesSidebar(false)} userId={userId} - forceExpanded={shouldForceExpandedCustomAgents} + forceExpanded={shouldForceExpanded} /> ) diff --git a/auto-analyst-frontend/components/custom-agents/AgentDetailView.tsx b/auto-analyst-frontend/components/custom-agents/AgentDetailView.tsx deleted file mode 100644 index e2b08890..00000000 --- a/auto-analyst-frontend/components/custom-agents/AgentDetailView.tsx +++ /dev/null @@ -1,417 +0,0 @@ -'use client'; - -import React, { useState, useEffect } from 'react'; -import { Button } from '@/components/ui/button'; -import { Input } from '@/components/ui/input'; -import { Textarea } from '@/components/ui/textarea'; -import { Label } from '@/components/ui/label'; -import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'; -import { Badge } from '@/components/ui/badge'; -import { Switch } from '@/components/ui/switch'; -import { - ArrowLeft, - Edit, - Save, - X, - Calendar, - TrendingUp, - Settings2, - Trash2, - AlertCircle, - CheckCircle -} from 'lucide-react'; -import { CustomAgent } from './types'; -import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, - AlertDialogTrigger, -} from '@/components/ui/alert-dialog'; - -interface AgentDetailViewProps { - agent: CustomAgent | null; - onUpdateAgent: (agentId: number, updates: Partial) => Promise; - onDeleteAgent: (agentId: number) => Promise; - onToggleActive?: (agentId: number) => Promise; - onBack: () => void; -} - -export default function AgentDetailView({ - agent, - onUpdateAgent, - onDeleteAgent, - onToggleActive, - onBack -}: AgentDetailViewProps) { - const [isEditing, setIsEditing] = useState(false); - const [editedAgent, setEditedAgent] = useState>({}); - const [isUpdating, setIsUpdating] = useState(false); - const [isDeleting, setIsDeleting] = useState(false); - - useEffect(() => { - if (agent) { - setEditedAgent({ - display_name: agent.display_name, - description: agent.description, - prompt_template: agent.prompt_template, - is_active: agent.is_active - }); - } - }, [agent]); - - if (!agent) { - return ( -
- -

No Agent Selected

-

- Select an agent from the list to view and edit its details. -

-
- ); - } - - const handleEdit = () => { - setIsEditing(true); - }; - - const handleCancelEdit = () => { - setIsEditing(false); - setEditedAgent({ - display_name: agent.display_name, - description: agent.description, - prompt_template: agent.prompt_template, - is_active: agent.is_active - }); - }; - - const handleSave = async () => { - setIsUpdating(true); - try { - const success = await onUpdateAgent(agent.agent_id, editedAgent); - if (success) { - setIsEditing(false); - } - } finally { - setIsUpdating(false); - } - }; - - const handleDelete = async () => { - setIsDeleting(true); - try { - const success = await onDeleteAgent(agent.agent_id); - if (success) { - onBack(); - } - } finally { - setIsDeleting(false); - } - }; - - const handleInputChange = (field: keyof CustomAgent, value: string | boolean) => { - setEditedAgent(prev => ({ ...prev, [field]: value })); - }; - - const formatDate = (dateString: string) => { - return new Date(dateString).toLocaleDateString('en-US', { - month: 'long', - day: 'numeric', - year: 'numeric', - hour: '2-digit', - minute: '2-digit' - }); - }; - - const canSave = () => { - return ( - (editedAgent.description?.length || 0) >= 10 && - (editedAgent.prompt_template?.length || 0) >= 50 - ); - }; - - return ( -
- {/* Header */} -
-
- -
-

- {agent.display_name || agent.agent_name} -

-
- - @{agent.agent_name} - -
-
-
- {agent.is_active ? 'Active' : 'Inactive'} -
- {!isEditing && onToggleActive && ( - { - onToggleActive(agent.agent_id); - }} - className="data-[state=checked]:bg-[#FF7F7F] scale-75" - /> - )} -
-
-
-
- -
- {isEditing ? ( - <> - - - - ) : ( - <> - - - - - - - - - Delete Custom Agent - - Are you sure you want to delete "{agent.display_name || agent.agent_name}"? - This action cannot be undone and the agent will no longer be available for use. - - - - Cancel - - {isDeleting ? 'Deleting...' : 'Delete'} - - - - - - )} -
-
- - {/* Content */} -
- {/* Metadata Card */} - - - Agent Metadata - - -
-
- -
-

Created

-

{formatDate(agent.created_at)}

-
-
-
- -
-

Usage Count

-

{agent.usage_count} times

-
-
-
- - {agent.updated_at !== agent.created_at && ( -
- Last updated: {formatDate(agent.updated_at)} -
- )} -
-
- - {/* Agent Configuration */} - - - Configuration - - Customize your agent's behavior and availability - - - - {/* Active Status */} -
-
- -

- When active, this agent can be used in conversations -

-
- handleInputChange('is_active', value)} - disabled={!isEditing} - className="flex-shrink-0" - /> -
- - {/* Display Name */} -
- - {isEditing ? ( - handleInputChange('display_name', e.target.value)} - placeholder="User-friendly name for your agent" - className="mt-1 text-sm" - /> - ) : ( -
- {agent.display_name || No display name set} -
- )} -
- - {/* Description */} -
- - {isEditing ? ( - <> -