diff --git a/auto-analyst-backend/app.py b/auto-analyst-backend/app.py index 929d1c2d..6ecee146 100644 --- a/auto-analyst-backend/app.py +++ b/auto-analyst-backend/app.py @@ -42,7 +42,7 @@ from src.routes.feedback_routes import router as feedback_router from src.routes.session_routes import router as session_router, get_session_id_dependency from src.routes.deep_analysis_routes import router as deep_analysis_router -from src.routes.custom_agents_routes import router as custom_agents_router +from src.routes.templates_routes import router as templates_router from src.schemas.query_schemas import QueryRequest from src.utils.logger import Logger @@ -400,7 +400,7 @@ async def verify_origin_middleware(request: Request, call_next): RESPONSE_ERROR_INVALID_QUERY = "Please provide a valid query..." RESPONSE_ERROR_NO_DATASET = "No dataset is currently loaded. Please link a dataset before proceeding with your analysis." DEFAULT_TOKEN_RATIO = 1.5 -REQUEST_TIMEOUT_SECONDS = 60 # Timeout for LLM requests +REQUEST_TIMEOUT_SECONDS = 120 # Timeout for LLM requests MAX_RECENT_MESSAGES = 3 DB_BATCH_SIZE = 10 # For future batch DB operations @@ -428,19 +428,20 @@ async def chat_with_agent( # Get chat context and prepare query enhanced_query = _prepare_query_with_context(request.query, session_state) - logger.log_message(f"Enhanced query: {enhanced_query}", level=logging.INFO) - # Initialize agent - handle both standard and custom agents + # Initialize agent - handle standard, template, and custom agents if "," in agent_name: # Multiple agents case agent_list = [agent.strip() for agent in agent_name.split(",")] - # Check if any are custom agents - has_custom_agents = any(not _is_standard_agent(agent) for agent in agent_list) - logger.log_message(f"Has custom agents: {has_custom_agents}", level=logging.INFO) + # Categorize agents + standard_agents = [agent for agent in agent_list if _is_standard_agent(agent)] + template_agents = [agent for agent in agent_list if _is_template_agent(agent)] + custom_agents = [agent for agent in agent_list if not _is_standard_agent(agent) and not _is_template_agent(agent)] - if has_custom_agents: - # Use session AI system for mixed or custom agent execution + + if custom_agents: + # If any custom agents, use session AI system for all ai_system = session_state["ai_system"] session_lm = get_session_lm(session_state) with dspy.context(lm=session_lm): @@ -449,32 +450,61 @@ async def chat_with_agent( timeout=REQUEST_TIMEOUT_SECONDS ) else: - # All standard agents - use auto_analyst_ind - standard_agent_sigs = [AVAILABLE_AGENTS[agent] for agent in agent_list] + # All standard/template agents - use auto_analyst_ind + standard_agent_sigs = [AVAILABLE_AGENTS[agent] for agent in standard_agents] user_id = session_state.get("user_id") - agent = auto_analyst_ind(agents=standard_agent_sigs, retrievers=session_state["retrievers"], user_id=user_id) - session_lm = get_session_lm(session_state) - with dspy.context(lm=session_lm): - response = await asyncio.wait_for( - agent.forward(enhanced_query, ",".join(agent_list)), - timeout=REQUEST_TIMEOUT_SECONDS - ) + + # Create database session for template loading + from src.db.init_db import session_factory + db_session = session_factory() + try: + agent = auto_analyst_ind(agents=standard_agent_sigs, retrievers=session_state["retrievers"], user_id=user_id, db_session=db_session) + session_lm = get_session_lm(session_state) + with dspy.context(lm=session_lm): + response = await asyncio.wait_for( + agent.forward(enhanced_query, ",".join(agent_list)), + timeout=REQUEST_TIMEOUT_SECONDS + ) + finally: + db_session.close() else: # Single agent case - logger.log_message(f"Single agent case: {agent_name}", level=logging.INFO) if _is_standard_agent(agent_name): # Standard agent - use auto_analyst_ind user_id = session_state.get("user_id") - agent = auto_analyst_ind(agents=[AVAILABLE_AGENTS[agent_name]], retrievers=session_state["retrievers"], user_id=user_id) - session_lm = get_session_lm(session_state) - with dspy.context(lm=session_lm): - response = await asyncio.wait_for( - agent.forward(enhanced_query, agent_name), - timeout=REQUEST_TIMEOUT_SECONDS - ) + + # Create database session for template loading + from src.db.init_db import session_factory + db_session = session_factory() + try: + agent = auto_analyst_ind(agents=[AVAILABLE_AGENTS[agent_name]], retrievers=session_state["retrievers"], user_id=user_id, db_session=db_session) + session_lm = get_session_lm(session_state) + with dspy.context(lm=session_lm): + response = await asyncio.wait_for( + agent.forward(enhanced_query, agent_name), + timeout=REQUEST_TIMEOUT_SECONDS + ) + finally: + db_session.close() + elif _is_template_agent(agent_name): + # Template agent - use auto_analyst_ind with empty agents list (templates loaded in init) + user_id = session_state.get("user_id") + + # Create database session for template loading + from src.db.init_db import session_factory + db_session = session_factory() + try: + agent = auto_analyst_ind(agents=[], retrievers=session_state["retrievers"], user_id=user_id, db_session=db_session) + session_lm = get_session_lm(session_state) + with dspy.context(lm=session_lm): + response = await asyncio.wait_for( + agent.forward(enhanced_query, agent_name), + timeout=REQUEST_TIMEOUT_SECONDS + ) + finally: + db_session.close() else: # Custom agent - use session AI system - logger.log_message(f"Custom agent case: {agent_name}", level=logging.INFO) ai_system = session_state["ai_system"] session_lm = get_session_lm(session_state) with dspy.context(lm=session_lm): @@ -484,7 +514,6 @@ async def chat_with_agent( ) formatted_response = format_response_to_markdown(response, agent_name, session_state["current_df"]) - logger.log_message(f"Formatted response: {formatted_response}", level=logging.INFO) if formatted_response == RESPONSE_ERROR_INVALID_QUERY: return { @@ -513,10 +542,8 @@ async def chat_with_agent( # Re-raise HTTP exceptions to preserve status codes raise except asyncio.TimeoutError: - logger.log_message(f"Agent execution timed out for {agent_name}", level=logging.WARNING) raise HTTPException(status_code=504, detail="Request timed out. Please try a simpler query.") except Exception as e: - logger.log_message(f"Unexpected error in chat_with_agent: {str(e)}", level=logging.ERROR) raise HTTPException(status_code=500, detail="An unexpected error occurred. Please try again later.") @@ -558,7 +585,6 @@ async def chat_with_all( # Re-raise HTTP exceptions to preserve status codes raise except Exception as e: - logger.log_message(f"Unexpected error in chat_with_all: {str(e)}", level=logging.ERROR) raise HTTPException(status_code=500, detail="An unexpected error occurred. Please try again later.") @@ -606,11 +632,29 @@ def _validate_agent_name(agent_name: str, session_state: dict = None): ) def _is_agent_available(agent_name: str, session_state: dict = None) -> bool: - """Check if agent is available in either standard agents or user's custom agents""" + """Check if agent is available in either standard agents, template agents, or user's custom agents""" # Check standard agents if agent_name in AVAILABLE_AGENTS: return True + # Check template agents + try: + from src.db.init_db import session_factory + from src.db.schemas.models import AgentTemplate + + db_session = session_factory() + try: + template = db_session.query(AgentTemplate).filter( + AgentTemplate.template_name == agent_name, + AgentTemplate.is_active == True + ).first() + if template: + return True + finally: + db_session.close() + except Exception as e: + logger.log_message(f"Error checking template availability for {agent_name}: {str(e)}", level=logging.ERROR) + # Check custom agents if session has an AI system with custom agents if session_state and "ai_system" in session_state: ai_system = session_state["ai_system"] @@ -634,9 +678,28 @@ def _get_available_agents_list(session_state: dict = None) -> list: return available def _is_standard_agent(agent_name: str) -> bool: - """Check if agent is a standard agent (not custom)""" + """Check if agent is a standard agent (not custom or template)""" return agent_name in AVAILABLE_AGENTS +def _is_template_agent(agent_name: str) -> bool: + """Check if agent is a template agent""" + try: + from src.db.init_db import session_factory + from src.db.schemas.models import AgentTemplate + + db_session = session_factory() + try: + template = db_session.query(AgentTemplate).filter( + AgentTemplate.template_name == agent_name, + AgentTemplate.is_active == True + ).first() + return template is not None + finally: + db_session.close() + except Exception as e: + logger.log_message(f"Error checking if {agent_name} is template: {str(e)}", level=logging.ERROR) + return False + async def _execute_custom_agents(ai_system, agent_names: list, query: str): """Execute custom agents using the session's AI system""" try: @@ -644,8 +707,6 @@ async def _execute_custom_agents(ai_system, agent_names: list, query: str): if len(agent_names) == 1: # Single custom agent agent_name = agent_names[0] - logger.log_message(f"Executing custom agent: {agent_name}", level=logging.INFO) - # Prepare inputs for the custom agent (similar to standard agents like data_viz_agent) dict_ = {} dict_['dataset'] = ai_system.dataset.retrieve(query)[0].text @@ -657,14 +718,11 @@ async def _execute_custom_agents(ai_system, agent_names: list, query: str): if agent_name in ai_system.agent_inputs: inputs = {x: dict_[x] for x in ai_system.agent_inputs[agent_name] if x in dict_} - logger.log_message(f"Inputs for {agent_name}: {list(inputs.keys())}", level=logging.INFO) - # Execute the custom agent agent_name_result, result_dict = await ai_system.execute_agent(agent_name, inputs) - logger.log_message(f"Custom agent result: {agent_name_result}, has keys: {list(result_dict.keys()) if isinstance(result_dict, dict) else 'not dict'}", level=logging.INFO) return {agent_name_result: result_dict} else: - logger.log_message(f"Agent '{agent_name}' not found in ai_system.agent_inputs. Available: {list(ai_system.agent_inputs.keys())}", level=logging.ERROR) + logger.log_message(f"Agent '{agent_name}' not found in ai_system.agent_inputs", level=logging.ERROR) return {"error": f"Agent '{agent_name}' input configuration not found"} else: # Multiple agents - execute sequentially @@ -952,7 +1010,6 @@ async def list_agents(request: Request, session_id: str = Depends(get_session_id app.state.set_session_user(session_id, user_id) # Refresh session state after user association session_state = app.state.get_session_state(session_id) - logger.log_message(f"Associated session {session_id} with user {user_id} for agent listing", level=logging.INFO) except (ValueError, TypeError): logger.log_message(f"Invalid user_id in agents endpoint: {user_id_param}", level=logging.WARNING) @@ -965,18 +1022,16 @@ async def list_agents(request: Request, session_id: str = Depends(get_session_id template_agents = [] try: from src.db.init_db import session_factory - from src.db.schemas.models import CustomAgent + from src.db.schemas.models import AgentTemplate db_session = session_factory() try: - templates = db_session.query(CustomAgent).filter( - CustomAgent.is_template == True, - CustomAgent.is_active == True, - CustomAgent.user_id == None # System templates + templates = db_session.query(AgentTemplate).filter( + AgentTemplate.is_active == True ).all() - template_agents = [template.agent_name for template in templates] - logger.log_message(f"Found {len(template_agents)} template agents", level=logging.INFO) + template_agents = [template.template_name for template in templates] + logger.log_message(f"Found {len(template_agents)} template agents", level=logging.DEBUG) finally: db_session.close() @@ -1195,7 +1250,6 @@ async def update_report_in_db(status, progress, step=None, content=None): if step == "completed": if content: report.html_report = content - logger.log_message(f"Storing HTML report in database, length: {len(content)}", level=logging.INFO) else: logger.log_message("No HTML content provided for completed step", level=logging.WARNING) @@ -1286,9 +1340,7 @@ async def update_report_in_db(status, progress, step=None, content=None): # Generate HTML report using the original final_result with Figure objects html_report = None try: - logger.log_message("Generating HTML report...", level=logging.INFO) html_report = generate_html_report(final_result) - logger.log_message(f"HTML report generated successfully, length: {len(html_report) if html_report else 0}", level=logging.INFO) except Exception as e: logger.log_message(f"Error generating HTML report: {str(e)}", level=logging.ERROR) # Continue even if HTML generation fails @@ -1432,10 +1484,8 @@ async def download_html_report( report.html_report = html_report report.updated_at = datetime.now(UTC) db_session.commit() - logger.log_message(f"Updated HTML report in database for UUID {report_uuid}", level=logging.INFO) except Exception as e: db_session.rollback() - logger.log_message(f"Error storing HTML report in database: {str(e)}", level=logging.ERROR) finally: db_session.close() except Exception as e: @@ -1467,7 +1517,7 @@ async def download_html_report( app.include_router(session_router) app.include_router(feedback_router) app.include_router(deep_analysis_router) -app.include_router(custom_agents_router) +app.include_router(templates_router) if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file diff --git a/auto-analyst-backend/chat_database.db b/auto-analyst-backend/chat_database.db index 169301de..022a896b 100644 --- a/auto-analyst-backend/chat_database.db +++ b/auto-analyst-backend/chat_database.db @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:eba37409fcfd29b12a840c1a16aae4616c9a3a25edb89dc1746d23740a172607 -size 73728 +oid sha256:b74d10d11586da55532cda065ffa649d4cde09dbf9874c90ac275f1921d685da +size 110592 diff --git a/auto-analyst-backend/cleaned_property_data.csv b/auto-analyst-backend/cleaned_property_data.csv new file mode 100644 index 00000000..494c9709 --- /dev/null +++ b/auto-analyst-backend/cleaned_property_data.csv @@ -0,0 +1,3 @@ +price,area,bedrooms,bathrooms,stories,mainroad,guestroom,basement,hotwaterheating,airconditioning,parking,prefarea,furnishingstatus_semi-furnished,normalized_price,normalized_area +13300000,7420,4,2,3,1,1,0,0,1,2,0,True,0.7071067811865476,-0.7071067811865476 +12250000,8960,4,4,4,1,0,1,1,1,3,1,False,-0.7071067811865476,0.7071067811865476 diff --git a/auto-analyst-backend/property_price_vs_area.png b/auto-analyst-backend/property_price_vs_area.png new file mode 100644 index 00000000..c0188730 --- /dev/null +++ b/auto-analyst-backend/property_price_vs_area.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:14c495d58d3fddbe3e6623dbcd569a3b96c467a57b9ac98a3d4e22439abb153d +size 427665 diff --git a/auto-analyst-backend/requirements.txt b/auto-analyst-backend/requirements.txt index 2e774f49..0cd43eed 100644 --- a/auto-analyst-backend/requirements.txt +++ b/auto-analyst-backend/requirements.txt @@ -31,6 +31,7 @@ openpyxl==3.1.2 xlrd==2.0.1 openai==1.61.0 pandas==2.2.3 +polars==1.30.0 pillow==11.1.0 plotly==5.24.1 psycopg2==2.9.10 diff --git a/auto-analyst-backend/scripts/populate_agent_templates.py b/auto-analyst-backend/scripts/populate_agent_templates.py index 332e9be6..b0d9789d 100644 --- a/auto-analyst-backend/scripts/populate_agent_templates.py +++ b/auto-analyst-backend/scripts/populate_agent_templates.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 """ -Script to populate custom agent templates. +Script to populate agent templates. These templates are available to all users but usable only by paid users. """ @@ -12,21 +12,22 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from src.db.init_db import session_factory -from src.db.schemas.models import CustomAgent +from src.db.schemas.models import AgentTemplate from sqlalchemy.exc import IntegrityError # Template agent definitions AGENT_TEMPLATES = { "Visualization": [ { - "agent_name": "matplotlib_agent", + "template_name": "matplotlib_agent", "display_name": "Matplotlib Visualization Agent", "description": "Creates static publication-quality plots using matplotlib and seaborn", + "icon_url": "https://cdn.jsdelivr.net/gh/devicons/devicon/icons/matplotlib/matplotlib-original.svg", "prompt_template": """ You are a matplotlib/seaborn visualization expert. Your task is to create high-quality static visualizations using matplotlib and seaborn libraries. IMPORTANT Instructions: -- You must only use matplotlib, seaborn, and numpy/pandas for visualizations +- You must only use matplotlib, seaborn, and numpy/polars for visualizations - Always use plt.style.use('seaborn-v0_8') or a clean style for better aesthetics - Include proper titles, axis labels, and legends - Use appropriate color palettes and consider accessibility @@ -38,135 +39,56 @@ """ }, { - "agent_name": "seaborn_agent", + "template_name": "seaborn_agent", "display_name": "Seaborn Statistical Plots Agent", - "description": "Creates statistical visualizations and plots using seaborn library", + "description": "Creates statistical visualizations and data exploration plots using seaborn", + "icon_url": "https://seaborn.pydata.org/_images/logo-mark-lightbg.svg", "prompt_template": """ -You are a seaborn statistical visualization expert. Create insightful statistical plots using seaborn. +You are a seaborn statistical visualization expert. Your task is to create statistical plots and exploratory data visualizations. IMPORTANT Instructions: -- Specialize in seaborn's statistical plotting capabilities -- Use seaborn's built-in statistical functions (regplot, distplot, boxplot, violin, etc.) -- Apply appropriate statistical themes and color palettes -- Include confidence intervals and statistical annotations where relevant -- Sample large datasets: if len(df) > 50000: df = df.sample(50000, random_state=42) -- Use plt.figure(figsize=(10, 6)) for appropriate sizing -- Always include proper statistical context in titles and labels +- Focus on seaborn for statistical plotting (distributions, relationships, categorical data) +- Use matplotlib as the backend for customization +- Create informative statistical plots: histograms, box plots, violin plots, pair plots, heatmaps +- Apply proper statistical annotations and significance testing where relevant +- Use seaborn's built-in themes and color palettes for professional appearance +- Include statistical summaries and insights in plot annotations +- Handle categorical and numerical data appropriately +- Always include proper legends, titles, and axis labels -Focus on revealing statistical relationships and distributions in the data. +Focus on revealing statistical patterns and relationships in data through visualization. """ }, - { - "agent_name": "plotly_advanced_agent", - "display_name": "Advanced Plotly Agent", - "description": "Creates sophisticated interactive visualizations with advanced Plotly features", - "prompt_template": """ -You are an advanced Plotly visualization expert. Create sophisticated interactive visualizations with advanced features. - -IMPORTANT Instructions: -- Use advanced Plotly features: subplots, animations, 3D plots, statistical charts -- Implement interactive features: hover data, clickable legends, zoom, pan -- Use plotly.graph_objects for fine control and plotly.express for rapid prototyping -- Add annotations, shapes, and custom styling -- Sample data if len(df) > 50000: df = df.sample(50000, random_state=42) -- Use fig.update_layout() for professional styling -- Return fig.to_html(full_html=False) for embedding - -Focus on creating publication-quality interactive visualizations with advanced features. -""" - } - ], - "Modelling": [ - { - "agent_name": "xgboost_agent", - "display_name": "XGBoost Machine Learning Agent", - "description": "Builds and optimizes XGBoost models for classification and regression tasks", - "prompt_template": """ -You are an XGBoost machine learning expert. Build, tune, and evaluate XGBoost models. - -IMPORTANT Instructions: -- Use XGBoost for both classification and regression tasks -- Implement proper train/validation/test splits -- Perform hyperparameter tuning using GridSearchCV or RandomizedSearchCV -- Handle categorical variables with proper encoding -- Include feature importance analysis and visualization -- Evaluate models with appropriate metrics (accuracy, precision, recall, F1, RMSE, MAE, etc.) -- Use cross-validation for robust model evaluation -- Plot training curves and validation curves -- Provide model interpretation and feature importance insights - -Focus on building production-ready XGBoost models with proper evaluation and interpretation. -""" - }, - { - "agent_name": "neural_network_agent", - "display_name": "Neural Network Agent", - "description": "Builds and trains neural networks using TensorFlow/Keras", - "prompt_template": """ -You are a neural network expert using TensorFlow/Keras. Build and train neural networks for various tasks. - -IMPORTANT Instructions: -- Design appropriate network architectures for the task (classification, regression, etc.) -- Implement proper data preprocessing and normalization -- Use appropriate activation functions, optimizers, and loss functions -- Implement callbacks: EarlyStopping, ReduceLROnPlateau, ModelCheckpoint -- Plot training history (loss and metrics over epochs) -- Evaluate model performance with appropriate metrics -- Include model summary and architecture visualization -- Handle overfitting with dropout, regularization, or data augmentation -- Use train/validation/test splits properly - -Focus on building effective neural networks with proper training procedures and evaluation. -""" - }, - { - "agent_name": "time_series_agent", - "display_name": "Time Series Forecasting Agent", - "description": "Specialized in time series analysis and forecasting using ARIMA, Prophet, LSTM", - "prompt_template": """ -You are a time series forecasting expert. Analyze temporal data and create forecasting models. - -IMPORTANT Instructions: -- Perform exploratory time series analysis (trend, seasonality, stationarity) -- Use appropriate models: ARIMA, SARIMA, Prophet, LSTM, or ensemble methods -- Test for stationarity using ADF test and apply differencing if needed -- Decompose time series into trend, seasonal, and residual components -- Create forecasts with confidence intervals -- Evaluate forecasts using MAE, RMSE, MAPE metrics -- Plot actual vs predicted values and residuals -- Handle missing values and outliers appropriately -- Consider multiple seasonalities and external factors - -Focus on accurate time series forecasting with proper validation and uncertainty quantification. -""" - } ], "Data Manipulation": [ { - "agent_name": "pandas_expert_agent", - "display_name": "Pandas Data Expert Agent", - "description": "Advanced pandas operations for complex data manipulation and analysis", + "template_name": "polars_agent", + "display_name": "Polars Data Processing Agent", + "description": "High-performance data manipulation and analysis using Polars", + "icon_url": "https://raw.githubusercontent.com/pola-rs/polars-static/master/logos/polars-logo-dark.svg", "prompt_template": """ -You are a pandas expert specializing in advanced data manipulation and analysis. +You are a Polars data processing expert. Perform high-performance data manipulation and analysis using Polars. IMPORTANT Instructions: -- Use advanced pandas operations: groupby, pivot, merge, concat, apply, transform -- Implement efficient data cleaning and preprocessing workflows -- Handle missing data with multiple strategies (imputation, dropping, flagging) -- Perform advanced aggregations and window functions -- Use vectorized operations for performance -- Handle large datasets efficiently with chunking if needed -- Create custom functions for complex transformations -- Use proper indexing and data types for optimization -- Include data quality checks and validation +- Use Polars for fast, memory-efficient data processing +- Leverage lazy evaluation with pl.scan_csv() and .lazy() for large datasets +- Implement efficient data transformations using Polars expressions +- Use Polars-specific methods for groupby, aggregations, and window functions +- Handle various data types and perform type conversions appropriately +- Optimize queries for performance using lazy evaluation and query optimization +- Implement complex data reshaping (pivots, melts, joins) +- Use Polars datetime functionality for time-based operations +- Convert to pandas only when necessary for visualization or other libraries +- Focus on performance and memory efficiency -Focus on efficient and robust data manipulation that prepares data for analysis or modeling. +Focus on leveraging Polars' speed and efficiency for data processing tasks. """ }, { - "agent_name": "data_cleaning_agent", + "template_name": "data_cleaning_agent", "display_name": "Data Cleaning Specialist Agent", "description": "Specialized in comprehensive data cleaning and quality assessment", + "icon_url": "https://cdn-icons-png.flaticon.com/512/2103/2103633.png", "prompt_template": """ You are a data cleaning specialist. Perform comprehensive data quality assessment and cleaning. @@ -186,9 +108,10 @@ """ }, { - "agent_name": "feature_engineering_agent", + "template_name": "feature_engineering_agent", "display_name": "Feature Engineering Agent", "description": "Creates and transforms features for machine learning models", + "icon_url": "https://cdn-icons-png.flaticon.com/512/2103/2103658.png", "prompt_template": """ You are a feature engineering expert. Create, transform, and select features for machine learning. @@ -223,37 +146,34 @@ def populate_templates(): print(f"\n--- Processing {category} Templates ---") for template_data in templates: - agent_name = template_data["agent_name"] + template_name = template_data["template_name"] # Check if template already exists - existing = session.query(CustomAgent).filter( - CustomAgent.agent_name == agent_name, - CustomAgent.is_template == True + existing = session.query(AgentTemplate).filter( + AgentTemplate.template_name == template_name ).first() if existing: - print(f"⏭️ Skipping {agent_name} (already exists)") + print(f"⏭️ Skipping {template_name} (already exists)") skipped_count += 1 continue # Create new template - template = CustomAgent( - user_id=None, # Templates don't belong to specific users - agent_name=agent_name, + template = AgentTemplate( + template_name=template_name, display_name=template_data["display_name"], description=template_data["description"], + icon_url=template_data["icon_url"], prompt_template=template_data["prompt_template"], - is_template=True, - template_category=category, + category=category, is_premium_only=True, # All templates require premium is_active=True, - usage_count=0, created_at=datetime.now(UTC), updated_at=datetime.now(UTC) ) session.add(template) - print(f"✅ Created template: {agent_name}") + print(f"✅ Created template: {template_name}") created_count += 1 # Commit all changes @@ -276,9 +196,7 @@ def list_templates(): session = session_factory() try: - templates = session.query(CustomAgent).filter( - CustomAgent.is_template == True - ).order_by(CustomAgent.template_category, CustomAgent.agent_name).all() + templates = session.query(AgentTemplate).order_by(AgentTemplate.category, AgentTemplate.template_name).all() if not templates: print("No templates found in database.") @@ -288,15 +206,14 @@ def list_templates(): current_category = None for template in templates: - if template.template_category != current_category: - current_category = template.template_category + if template.category != current_category: + current_category = template.category print(f"\n{current_category}:") status = "🔒 Premium" if template.is_premium_only else "🆓 Free" active = "✅ Active" if template.is_active else "❌ Inactive" - print(f" • {template.agent_name} ({template.display_name}) - {status} - {active}") + print(f" • {template.template_name} ({template.display_name}) - {status} - {active}") print(f" {template.description}") - print(f" Usage: {template.usage_count} times") except Exception as e: print(f"❌ Error listing templates: {str(e)}") @@ -308,9 +225,7 @@ def remove_all_templates(): session = session_factory() try: - deleted_count = session.query(CustomAgent).filter( - CustomAgent.is_template == True - ).delete() + deleted_count = session.query(AgentTemplate).delete() session.commit() print(f"🗑️ Removed {deleted_count} templates") @@ -324,7 +239,7 @@ def remove_all_templates(): if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description="Manage custom agent templates") + parser = argparse.ArgumentParser(description="Manage agent templates") parser.add_argument("action", choices=["populate", "list", "remove-all"], help="Action to perform") diff --git a/auto-analyst-backend/src/agents/agents.py b/auto-analyst-backend/src/agents/agents.py index b42fcefe..16ef0a66 100644 --- a/auto-analyst-backend/src/agents/agents.py +++ b/auto-analyst-backend/src/agents/agents.py @@ -36,117 +36,229 @@ def create_custom_agent_signature(agent_name, description, prompt_template): CustomAgentSignature = type(agent_name, (dspy.Signature,), class_attributes) return CustomAgentSignature -def load_custom_agents_from_db(user_id, db_session, include_templates=True): +def load_user_enabled_templates_from_db(user_id, db_session): """ - Load custom agents for a specific user from the database, optionally including templates. + Load template agents that are enabled for a specific user from the database. + All templates are enabled by default unless explicitly disabled by user preference. Args: user_id: ID of the user db_session: Database session - include_templates: Whether to include template agents that are available to all users Returns: - Dict of custom agent signatures keyed by agent name + Dict of template agent signatures keyed by template name """ try: - from src.db.schemas.models import CustomAgent + from src.db.schemas.models import AgentTemplate, UserTemplatePreference agent_signatures = {} - # Query active custom agents for the user - if user_id: - custom_agents = db_session.query(CustomAgent).filter( - CustomAgent.user_id == user_id, - CustomAgent.is_active == True, - CustomAgent.is_template == False # Only user-created agents - ).all() + if not user_id: + return agent_signatures + + # Get all active templates + all_templates = db_session.query(AgentTemplate).filter( + AgentTemplate.is_active == True + ).all() + + for template in all_templates: + # Check if user has explicitly disabled this template + preference = db_session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == user_id, + UserTemplatePreference.template_id == template.template_id + ).first() + + # Template is disabled by default unless explicitly enabled + # Only enabled if preference record exists and is_enabled=True + is_enabled = preference.is_enabled if preference else False - for agent in custom_agents: - # Create dynamic signature for each custom agent + if is_enabled: + # Create dynamic signature for each enabled template signature = create_custom_agent_signature( - agent.agent_name, - agent.description, - agent.prompt_template + template.template_name, + template.description, + template.prompt_template ) - agent_signatures[agent.agent_name] = signature + agent_signatures[template.template_name] = signature + + return agent_signatures + + except Exception as e: + logger.log_message(f"Error loading user enabled templates for user {user_id}: {str(e)}", level=logging.ERROR) + return {} + +def load_user_enabled_templates_for_planner_from_db(user_id, db_session): + """ + Load template agents that are enabled for planner use (max 10, prioritized by usage). + + Args: + user_id: ID of the user + db_session: Database session + + Returns: + Dict of template agent signatures keyed by template name (max 10) + """ + try: + from src.db.schemas.models import AgentTemplate, UserTemplatePreference + + agent_signatures = {} + + if not user_id: + return agent_signatures - # Also include template agents if requested - if include_templates: - template_agents = db_session.query(CustomAgent).filter( - CustomAgent.is_template == True, - CustomAgent.is_active == True - ).all() + # Get enabled templates ordered by usage (most used first) and limit to 10 + enabled_preferences = db_session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == user_id, + UserTemplatePreference.is_enabled == True + ).order_by( + UserTemplatePreference.usage_count.desc(), + UserTemplatePreference.last_used_at.desc() + ).limit(10).all() + + for preference in enabled_preferences: + # Get template details + template = db_session.query(AgentTemplate).filter( + AgentTemplate.template_id == preference.template_id, + AgentTemplate.is_active == True + ).first() - for template in template_agents: - # Create dynamic signature for each template agent + if template: + # Create dynamic signature for each enabled template signature = create_custom_agent_signature( - template.agent_name, + template.template_name, template.description, template.prompt_template ) - agent_signatures[template.agent_name] = signature + agent_signatures[template.template_name] = signature + logger.log_message(f"Loaded {len(agent_signatures)} templates for planner", level=logging.DEBUG) return agent_signatures except Exception as e: - logger.log_message(f"Error loading custom agents for user {user_id}: {str(e)}", level=logging.ERROR) + logger.log_message(f"Error loading planner templates for user {user_id}: {str(e)}", level=logging.ERROR) return {} -def get_custom_agent_description(agent_name, custom_agents_descriptions): +def get_all_available_templates(db_session): """ - Get description for a custom agent. + Get all available agent templates from the database. Args: - agent_name: Name of the custom agent - custom_agents_descriptions: Dict of custom agent descriptions + db_session: Database session Returns: - Description string or default message + List of agent template records """ - return custom_agents_descriptions.get(agent_name.lower(), "Custom agent - no description available") + try: + from src.db.schemas.models import AgentTemplate + + templates = db_session.query(AgentTemplate).filter( + AgentTemplate.is_active == True + ).all() + + return templates + + except Exception as e: + logger.log_message(f"Error getting all available templates: {str(e)}", level=logging.ERROR) + return [] -def save_custom_agent_to_db(user_id, agent_name, display_name, description, prompt_template, db_session): +def toggle_user_template_preference(user_id, template_id, is_enabled, db_session): """ - Save a new custom agent to the database. + Toggle a user's template preference (enable/disable). Args: - user_id: ID of the user creating the agent - agent_name: Unique name for the agent (e.g., 'pytorch_agent') - display_name: User-friendly display name - description: Short description for agent selection - prompt_template: Main prompt/instructions for agent behavior + user_id: ID of the user + template_id: ID of the template + is_enabled: Whether to enable or disable the template db_session: Database session Returns: - Tuple (success: bool, message: str, agent_id: int or None) + Tuple (success: bool, message: str) """ try: - from src.db.schemas.models import CustomAgent - from sqlalchemy.exc import IntegrityError + from src.db.schemas.models import UserTemplatePreference, AgentTemplate + from datetime import datetime, UTC + + # Verify template exists and is active + template = db_session.query(AgentTemplate).filter( + AgentTemplate.template_id == template_id, + AgentTemplate.is_active == True + ).first() - # Create new custom agent - new_agent = CustomAgent( - user_id=user_id, - agent_name=agent_name.lower().strip(), - display_name=display_name, - description=description, - prompt_template=prompt_template, - is_active=True, - usage_count=0 - ) + if not template: + return False, "Template not found or inactive" + + # Check if preference record exists + preference = db_session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == user_id, + UserTemplatePreference.template_id == template_id + ).first() + + if preference: + # Update existing preference + preference.is_enabled = is_enabled + preference.updated_at = datetime.now(UTC) + else: + # Create new preference record + preference = UserTemplatePreference( + user_id=user_id, + template_id=template_id, + is_enabled=is_enabled, + usage_count=0, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC) + ) + db_session.add(preference) - db_session.add(new_agent) db_session.commit() - return True, "Custom agent created successfully", new_agent.agent_id + action = "enabled" if is_enabled else "disabled" + return True, f"Template '{template.template_name}' {action} successfully" - except IntegrityError: - db_session.rollback() - return False, f"Agent name '{agent_name}' already exists, please choose a different name", None except Exception as e: db_session.rollback() - logger.log_message(f"Error saving custom agent: {str(e)}", level=logging.ERROR) - return False, f"Error creating custom agent: {str(e)}", None + logger.log_message(f"Error toggling template preference: {str(e)}", level=logging.ERROR) + return False, f"Error updating template preference: {str(e)}" + + + +def load_all_available_templates_from_db(db_session): + """ + Load ALL available template agents from the database for direct access. + This allows users to use any template via @template_name regardless of preferences. + + Args: + db_session: Database session + + Returns: + Dict of template agent signatures keyed by template name + """ + try: + from src.db.schemas.models import AgentTemplate + + agent_signatures = {} + + # Get all active templates + all_templates = db_session.query(AgentTemplate).filter( + AgentTemplate.is_active == True + ).all() + + for template in all_templates: + # Create dynamic signature for all active templates + signature = create_custom_agent_signature( + template.template_name, + template.description, + template.prompt_template + ) + agent_signatures[template.template_name] = signature + + logger.log_message(f"Loaded {len(agent_signatures)} templates", level=logging.INFO) + return agent_signatures + + except Exception as e: + logger.log_message(f"Error loading all available templates: {str(e)}", level=logging.ERROR) + return {} + + # === END CUSTOM AGENT FUNCTIONALITY === @@ -1100,7 +1212,7 @@ class code_edit(dspy.Signature): class auto_analyst_ind(dspy.Module): """Handles individual agent execution when explicitly specified in query""" - def __init__(self, agents, retrievers, user_id=None): + def __init__(self, agents, retrievers, user_id=None, db_session=None): # Initialize agent modules and retrievers self.agents = {} self.agent_inputs = {} @@ -1112,6 +1224,26 @@ def __init__(self, agents, retrievers, user_id=None): self.agents[name] = dspy.asyncify(dspy.ChainOfThoughtWithHint(a)) self.agent_inputs[name] = {x.strip() for x in str(agents[i].__pydantic_core_schema__['cls']).split('->')[0].split('(')[1].split(',')} self.agent_desc.append(get_agent_description(name)) + + # Load ALL available template agents for direct access (regardless of user preferences) + if db_session: + try: + template_signatures = load_all_available_templates_from_db(db_session) + + for template_name, signature in template_signatures.items(): + # Add template agent to agents dict + self.agents[template_name] = dspy.asyncify(dspy.ChainOfThoughtWithHint(signature)) + + # Extract input fields from signature - templates use standard fields + self.agent_inputs[template_name] = {'goal', 'dataset', 'styling_index', 'hint'} + + # Add description + self.agent_desc.append(f"Template: {template_name}") + + logger.log_message(f"Loaded {len(template_signatures)} templates for direct access", level=logging.DEBUG) + + except Exception as e: + logger.log_message(f"Error loading templates for direct access: {str(e)}", level=logging.ERROR) # Initialize components # self.memory_summarize_agent = dspy.ChainOfThought(m.memory_summarize_agent) @@ -1123,58 +1255,71 @@ def __init__(self, agents, retrievers, user_id=None): self.user_id = user_id async def _track_agent_usage(self, agent_name): - """Track usage for custom agents and templates""" + """Track usage for template agents""" try: # Skip tracking for standard agents if agent_name in ['preprocessing_agent', 'statistical_analytics_agent', 'sk_learn_agent', 'data_viz_agent', 'basic_qa_agent']: return - # Only track if we have user_id (custom/template agents) + # Only track if we have user_id (template agents) if not self.user_id: return from src.db.init_db import session_factory - from src.db.schemas.models import CustomAgent + from src.db.schemas.models import AgentTemplate, UserTemplatePreference from datetime import datetime, UTC # Create database session session = session_factory() try: - # First try to find as user's custom agent - agent = session.query(CustomAgent).filter( - CustomAgent.agent_name == agent_name, - CustomAgent.user_id == self.user_id, - CustomAgent.is_template == False + # Find the template + template = session.query(AgentTemplate).filter( + AgentTemplate.template_name == agent_name ).first() - # If not found, try to find as template - if not agent: - agent = session.query(CustomAgent).filter( - CustomAgent.agent_name == agent_name, - CustomAgent.is_template == True - ).first() + if not template: + logger.log_message(f"Template '{agent_name}' not found for usage tracking", level=logging.WARNING) + return - # Update usage count if agent found - if agent: - agent.usage_count += 1 - agent.updated_at = datetime.now(UTC) - session.commit() - - agent_type = "template" if agent.is_template else "custom" - logger.log_message( - f"Incremented usage count for {agent_type} agent '{agent_name}' (now: {agent.usage_count})", - level=logging.INFO + # Find or create user template preference record + preference = session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == self.user_id, + UserTemplatePreference.template_id == template.template_id + ).first() + + if not preference: + # Create new preference record (disabled by default) + preference = UserTemplatePreference( + user_id=self.user_id, + template_id=template.template_id, + is_enabled=False, # Disabled by default + usage_count=0, + last_used_at=None, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC) ) + session.add(preference) + + # Update usage tracking + preference.usage_count += 1 + preference.last_used_at = datetime.now(UTC) + preference.updated_at = datetime.now(UTC) + session.commit() + + logger.log_message( + f"Tracked usage for template '{agent_name}' (count: {preference.usage_count})", + level=logging.DEBUG + ) except Exception as e: session.rollback() - logger.log_message(f"Error tracking usage for agent {agent_name}: {str(e)}", level=logging.ERROR) + logger.log_message(f"Error tracking usage for template {agent_name}: {str(e)}", level=logging.ERROR) finally: session.close() except Exception as e: logger.log_message(f"Error in _track_agent_usage for {agent_name}: {str(e)}", level=logging.ERROR) - + async def execute_agent(self, specified_agent, inputs): """Execute agent and generate memory summary in parallel""" try: @@ -1211,6 +1356,10 @@ async def forward(self, query, specified_agent): # Execute agent result = await self.agents[specified_agent.strip()](**inputs) + + # Track usage for template agents + await self._track_agent_usage(specified_agent.strip()) + output_dict = {specified_agent.strip(): dict(result)} if "error" in output_dict: @@ -1250,6 +1399,9 @@ async def execute_multiple_agents(self, query, agent_list): agent_dict = dict(agent_result) results[agent_name] = agent_dict + # Track usage for template agents + await self._track_agent_usage(agent_name) + # Collect code for later combination if 'code' in agent_dict: code_list.append(agent_dict['code']) @@ -1269,7 +1421,6 @@ def __init__(self, agents, retrievers, user_id=None, db_session=None): self.agents = {} self.agent_inputs = {} self.agent_desc = [] - self.custom_agents_descriptions = {} # Load standard agents for i, a in enumerate(agents): @@ -1278,56 +1429,40 @@ def __init__(self, agents, retrievers, user_id=None, db_session=None): self.agent_inputs[name] = {x.strip() for x in str(agents[i].__pydantic_core_schema__['cls']).split('->')[0].split('(')[1].split(',')} self.agent_desc.append({name: get_agent_description(name)}) - # Load custom agents if user_id and db_session are provided + # Load user-enabled template agents if user_id and db_session are provided if user_id and db_session: try: - custom_agent_signatures = load_custom_agents_from_db(user_id, db_session, include_templates=True) + template_signatures = load_user_enabled_templates_for_planner_from_db(user_id, db_session) - for agent_name, signature in custom_agent_signatures.items(): - # Add custom agent to agents dict - self.agents[agent_name] = dspy.asyncify(dspy.ChainOfThought(signature)) + for template_name, signature in template_signatures.items(): + # Add template agent to agents dict + self.agents[template_name] = dspy.asyncify(dspy.ChainOfThought(signature)) - # Extract input fields from signature - custom agents use standard fields like data_viz_agent - self.agent_inputs[agent_name] = {'goal', 'dataset', 'styling_index'} + # Extract input fields from signature - templates use standard fields like data_viz_agent + self.agent_inputs[template_name] = {'goal', 'dataset', 'styling_index'} - # Store custom agent description + # Store template agent description try: - from src.db.schemas.models import CustomAgent + from src.db.schemas.models import AgentTemplate - # First try to find as user agent - agent_record = db_session.query(CustomAgent).filter( - CustomAgent.agent_name == agent_name, - CustomAgent.user_id == user_id, - CustomAgent.is_template == False + # Find template record + template_record = db_session.query(AgentTemplate).filter( + AgentTemplate.template_name == template_name ).first() - # If not found, try to find as template - if not agent_record: - agent_record = db_session.query(CustomAgent).filter( - CustomAgent.agent_name == agent_name, - CustomAgent.is_template == True - ).first() - - if agent_record: - description = agent_record.description - if agent_record.is_template: - # Add prefix for template agents - description = f"Template: {description}" - - self.custom_agents_descriptions[agent_name] = description - self.agent_desc.append({agent_name: description}) + if template_record: + description = f"Template: {template_record.description}" + self.agent_desc.append({template_name: description}) else: - self.custom_agents_descriptions[agent_name] = f"Custom agent: {agent_name}" - self.agent_desc.append({agent_name: f"Custom agent: {agent_name}"}) + self.agent_desc.append({template_name: f"Template: {template_name}"}) except Exception as desc_error: - logger.log_message(f"Error getting description for custom agent {agent_name}: {str(desc_error)}", level=logging.WARNING) - self.custom_agents_descriptions[agent_name] = f"Custom agent: {agent_name}" - self.agent_desc.append({agent_name: f"Custom agent: {agent_name}"}) + logger.log_message(f"Error getting description for template {template_name}: {str(desc_error)}", level=logging.WARNING) + self.agent_desc.append({template_name: f"Template: {template_name}"}) - logger.log_message(f"Loaded {len(custom_agent_signatures)} custom agents (including templates) for user {user_id}", level=logging.INFO) + logger.log_message(f"Loaded {len(template_signatures)} enabled templates for planner", level=logging.DEBUG) except Exception as e: - logger.log_message(f"Error loading custom agents for user {user_id}: {str(e)}", level=logging.ERROR) + logger.log_message(f"Error loading template agents for user {user_id}: {str(e)}", level=logging.ERROR) self.agents['basic_qa_agent'] = dspy.asyncify(dspy.Predict("goal->answer")) self.agent_inputs['basic_qa_agent'] = {"goal"} @@ -1349,52 +1484,67 @@ def __init__(self, agents, retrievers, user_id=None, db_session=None): self.user_id = user_id async def _track_agent_usage(self, agent_name): - """Track usage for custom agents and templates""" + """Track usage for template agents""" try: # Skip tracking for standard agents if agent_name in ['preprocessing_agent', 'statistical_analytics_agent', 'sk_learn_agent', 'data_viz_agent', 'basic_qa_agent']: return - # Only track if we have user_id (custom/template agents) + # Only track if we have user_id (template agents) if not self.user_id: return from src.db.init_db import session_factory - from src.db.schemas.models import CustomAgent + from src.db.schemas.models import AgentTemplate, UserTemplatePreference from datetime import datetime, UTC # Create database session session = session_factory() try: - # First try to find as user's custom agent - agent = session.query(CustomAgent).filter( - CustomAgent.agent_name == agent_name, - CustomAgent.user_id == self.user_id, - CustomAgent.is_template == False + # Find the template + template = session.query(AgentTemplate).filter( + AgentTemplate.template_name == agent_name ).first() - # If not found, try to find as template - if not agent: - agent = session.query(CustomAgent).filter( - CustomAgent.agent_name == agent_name, - CustomAgent.is_template == True - ).first() + if not template: + logger.log_message(f"Template '{agent_name}' not found", level=logging.WARNING) + return - # Update usage count if agent found - if agent: - agent.usage_count += 1 - agent.updated_at = datetime.now(UTC) - session.commit() - - agent_type = "template" if agent.is_template else "custom" - logger.log_message( - f"Incremented usage count for {agent_type} agent '{agent_name}' (now: {agent.usage_count})", - level=logging.INFO + # Find or create user template preference record + preference = session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == self.user_id, + UserTemplatePreference.template_id == template.template_id + ).first() + + if preference: + # Update existing preference + preference.usage_count += 1 + preference.last_used_at = datetime.now(UTC) + preference.updated_at = datetime.now(UTC) + else: + # Create new preference record when template is used directly (via @mention) + # Direct usage doesn't auto-enable for planner but tracks usage + preference = UserTemplatePreference( + user_id=self.user_id, + template_id=template.template_id, + is_enabled=False, # Default disabled for planner + usage_count=1, + last_used_at=datetime.now(UTC), + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC) ) + session.add(preference) + + session.commit() + + logger.log_message( + f"Tracked usage for template '{agent_name}' for user {self.user_id} (count: {preference.usage_count})", + level=logging.DEBUG + ) except Exception as e: session.rollback() - logger.log_message(f"Error tracking usage for agent {agent_name}: {str(e)}", level=logging.ERROR) + logger.log_message(f"Error tracking template usage for {agent_name}: {str(e)}", level=logging.ERROR) finally: session.close() diff --git a/auto-analyst-backend/src/db/schemas/models.py b/auto-analyst-backend/src/db/schemas/models.py index de48e080..5fbfa64d 100644 --- a/auto-analyst-backend/src/db/schemas/models.py +++ b/auto-analyst-backend/src/db/schemas/models.py @@ -18,7 +18,7 @@ class User(Base): chats = relationship("Chat", back_populates="user", cascade="all, delete-orphan") usage_records = relationship("ModelUsage", back_populates="user") deep_analysis_reports = relationship("DeepAnalysisReport", back_populates="user", cascade="all, delete-orphan") - custom_agents = relationship("CustomAgent", back_populates="user", cascade="all, delete-orphan") + template_preferences = relationship("UserTemplatePreference", back_populates="user", cascade="all, delete-orphan") # Define the Chats table class Chat(Base): @@ -173,38 +173,61 @@ class DeepAnalysisReport(Base): # Relationships user = relationship("User", back_populates="deep_analysis_reports") -class CustomAgent(Base): - """Stores custom agents created by premium users and system templates.""" - __tablename__ = 'custom_agents' +class AgentTemplate(Base): + """Stores predefined agent templates that users can enable/disable.""" + __tablename__ = 'agent_templates' - agent_id = Column(Integer, primary_key=True, autoincrement=True) - user_id = Column(Integer, ForeignKey('users.user_id', ondelete="CASCADE"), nullable=True) # Nullable for templates + template_id = Column(Integer, primary_key=True, autoincrement=True) - # Agent definition - agent_name = Column(String(100), nullable=False) # e.g., 'pytorch_agent', 'deep_learning_agent' + # Template definition + template_name = Column(String(100), nullable=False, unique=True) # e.g., 'pytorch_specialist', 'data_cleaning_expert' display_name = Column(String(200), nullable=True) # User-friendly display name - description = Column(Text, nullable=False) # Short description for agent selection + description = Column(Text, nullable=False) # Short description for template selection prompt_template = Column(Text, nullable=False) # Main prompt/instructions for agent behavior - # Template fields - is_template = Column(Boolean, default=False) # True for system templates, False for user agents - template_category = Column(String(50), nullable=True) # 'Visualization', 'Modelling', 'Data Manipulation' + # Template appearance + icon_url = Column(String(500), nullable=True) # URL to template icon (CDN, data URL, or relative path) + + # Template categorization + category = Column(String(50), nullable=True) # 'Visualization', 'Modelling', 'Data Manipulation' is_premium_only = Column(Boolean, default=False) # True if template requires premium subscription # Status and metadata is_active = Column(Boolean, default=True) - usage_count = Column(Integer, default=0) # Track how many times agent has been used # Timestamps created_at = Column(DateTime, default=lambda: datetime.now(UTC)) updated_at = Column(DateTime, default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC)) # Relationships - user = relationship("User", back_populates="custom_agents") + user_preferences = relationship("UserTemplatePreference", back_populates="template", cascade="all, delete-orphan") + +class UserTemplatePreference(Base): + """Tracks user preferences and usage for agent templates.""" + __tablename__ = 'user_template_preferences' + + preference_id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(Integer, ForeignKey('users.user_id', ondelete="CASCADE"), nullable=False) + template_id = Column(Integer, ForeignKey('agent_templates.template_id', ondelete="CASCADE"), nullable=False) + + # User preferences + is_enabled = Column(Boolean, default=True) # Whether user has this template enabled + + # Usage tracking + usage_count = Column(Integer, default=0) # Track how many times user has used this template + last_used_at = Column(DateTime, nullable=True) # Last time user used this template + + # Timestamps + created_at = Column(DateTime, default=lambda: datetime.now(UTC)) + updated_at = Column(DateTime, default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC)) + + # Relationships + user = relationship("User", back_populates="template_preferences") + template = relationship("AgentTemplate", back_populates="user_preferences") - # Constraints + # Constraints - user can only have one preference record per template __table_args__ = ( - UniqueConstraint('user_id', 'agent_name', name='unique_user_agent_name'), + UniqueConstraint('user_id', 'template_id', name='unique_user_template_preference'), ) \ No newline at end of file diff --git a/auto-analyst-backend/src/managers/session_manager.py b/auto-analyst-backend/src/managers/session_manager.py index bf5eb558..013801c5 100644 --- a/auto-analyst-backend/src/managers/session_manager.py +++ b/auto-analyst-backend/src/managers/session_manager.py @@ -316,7 +316,7 @@ def create_ai_system_for_user(self, retrievers, user_id=None): user_id=user_id, db_session=db_session ) - logger.log_message(f"Created AI system with custom agents for user {user_id}", level=logging.INFO) + logger.log_message(f"Created AI system for user {user_id}", level=logging.INFO) return ai_system finally: db_session.close() @@ -368,7 +368,7 @@ def set_session_user(self, session_id: str, user_id: int, chat_id: int = None): session_retrievers = self._sessions[session_id]["retrievers"] user_ai_system = self.create_ai_system_for_user(session_retrievers, user_id) self._sessions[session_id]["ai_system"] = user_ai_system - logger.log_message(f"Updated AI system for session {session_id} with user {user_id} custom agents", level=logging.INFO) + logger.log_message(f"Updated AI system for session {session_id} with user {user_id}", level=logging.INFO) except Exception as e: logger.log_message(f"Error updating AI system for user {user_id}: {str(e)}", level=logging.ERROR) # Continue with existing AI system if update fails diff --git a/auto-analyst-backend/src/routes/custom_agents_routes.py b/auto-analyst-backend/src/routes/custom_agents_routes.py deleted file mode 100644 index 54018e28..00000000 --- a/auto-analyst-backend/src/routes/custom_agents_routes.py +++ /dev/null @@ -1,697 +0,0 @@ -import logging -import os -from fastapi import APIRouter, Depends, HTTPException, Query, Body -from pydantic import BaseModel, Field -from typing import List, Optional, Dict, Any -from datetime import datetime, UTC -from sqlalchemy import desc -from sqlalchemy.exc import IntegrityError - -from src.db.init_db import session_factory -from src.db.schemas.models import CustomAgent, User -from src.utils.logger import Logger -import dspy -from src.agents.agents import custom_agent_instruction_generator - -# Initialize logger with console logging disabled -logger = Logger("custom_agents_routes", see_time=True, console_log=False) - -# Initialize router -router = APIRouter(prefix="/custom_agents", tags=["custom_agents"]) - -# Pydantic models for request/response -class CustomAgentCreate(BaseModel): - agent_name: str = Field(..., min_length=1, max_length=100, description="Unique agent name (e.g., 'pytorch_agent')") - display_name: Optional[str] = Field(None, max_length=200, description="User-friendly display name") - description: str = Field(..., min_length=10, max_length=1000, description="Short description for agent selection") - prompt_template: str = Field(..., min_length=50, description="Main prompt/instructions for agent behavior") - -class CustomAgentUpdate(BaseModel): - display_name: Optional[str] = Field(None, max_length=200) - description: Optional[str] = Field(None, min_length=10, max_length=1000) - prompt_template: Optional[str] = Field(None, min_length=50) - is_active: Optional[bool] = None - -class CustomAgentResponse(BaseModel): - agent_id: int - agent_name: str - display_name: Optional[str] - description: str - prompt_template: str - is_active: bool - usage_count: int - created_at: datetime - updated_at: datetime - -class CustomAgentListResponse(BaseModel): - agent_id: int - agent_name: str - display_name: Optional[str] - description: str - is_active: bool - usage_count: int - created_at: datetime - -class TemplateAgentResponse(BaseModel): - agent_id: int - agent_name: str - display_name: Optional[str] - description: str - prompt_template: str - template_category: str - is_premium_only: bool - is_active: bool - usage_count: int - created_at: datetime - -class TemplatesByCategory(BaseModel): - category: str - templates: List[TemplateAgentResponse] - -class AgentInstructionRequest(BaseModel): - category: str = Field(..., description="The category of the agent: 'Visualization', 'Modelling', or 'Data Manipulation'") - user_requirements: str = Field(..., min_length=10, max_length=2000, description="User's description of what they want the agent to do") - -class AgentInstructionResponse(BaseModel): - agent_instructions: str - category: str - user_requirements: str - -# Routes -@router.post("/", response_model=CustomAgentResponse) -async def create_custom_agent(agent: CustomAgentCreate, user_id: int = Query(...)): - """Create a new custom agent for a user""" - try: - session = session_factory() - - try: - # Validate user exists - user = session.query(User).filter(User.user_id == user_id).first() - if not user: - raise HTTPException(status_code=404, detail="User not found") - - # Create new custom agent - now = datetime.now(UTC) - new_agent = CustomAgent( - user_id=user_id, - agent_name=agent.agent_name.lower().strip(), - display_name=agent.display_name, - description=agent.description, - prompt_template=agent.prompt_template, - is_active=True, - usage_count=0, - created_at=now, - updated_at=now - ) - - session.add(new_agent) - session.commit() - session.refresh(new_agent) - - logger.log_message(f"Created custom agent '{agent.agent_name}' for user {user_id}", level=logging.INFO) - - return CustomAgentResponse( - agent_id=new_agent.agent_id, - agent_name=new_agent.agent_name, - display_name=new_agent.display_name, - description=new_agent.description, - prompt_template=new_agent.prompt_template, - is_active=new_agent.is_active, - usage_count=new_agent.usage_count, - created_at=new_agent.created_at, - updated_at=new_agent.updated_at - ) - - except IntegrityError: - session.rollback() - raise HTTPException(status_code=400, detail=f"Agent name '{agent.agent_name}' already exists, please choose a different name") - except Exception as e: - session.rollback() - logger.log_message(f"Error creating custom agent: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to create custom agent: {str(e)}") - finally: - session.close() - - except HTTPException: - raise - except Exception as e: - logger.log_message(f"Error creating custom agent: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to create custom agent: {str(e)}") - -@router.get("/", response_model=List[CustomAgentListResponse]) -async def get_custom_agents( - user_id: int = Query(...), - active_only: bool = Query(False, description="Return only active agents"), - limit: int = Query(50, ge=1, le=100), - offset: int = Query(0, ge=0) -): - """Get custom agents for a user""" - try: - session = session_factory() - - try: - query = session.query(CustomAgent).filter(CustomAgent.user_id == user_id) - - if active_only: - query = query.filter(CustomAgent.is_active == True) - - # Order by most recently created first - query = query.order_by(desc(CustomAgent.created_at)) - - agents = query.limit(limit).offset(offset).all() - - return [CustomAgentListResponse( - agent_id=agent.agent_id, - agent_name=agent.agent_name, - display_name=agent.display_name, - description=agent.description, - is_active=agent.is_active, - usage_count=agent.usage_count, - created_at=agent.created_at - ) for agent in agents] - - finally: - session.close() - - except Exception as e: - logger.log_message(f"Error retrieving custom agents: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to retrieve custom agents: {str(e)}") - -@router.get("/{agent_id}", response_model=CustomAgentResponse) -async def get_custom_agent(agent_id: int, user_id: int = Query(...)): - """Get a specific custom agent by ID""" - try: - session = session_factory() - - try: - agent = session.query(CustomAgent).filter( - CustomAgent.agent_id == agent_id, - CustomAgent.user_id == user_id - ).first() - - if not agent: - raise HTTPException(status_code=404, detail=f"Custom agent with ID {agent_id} not found") - - return CustomAgentResponse( - agent_id=agent.agent_id, - agent_name=agent.agent_name, - display_name=agent.display_name, - description=agent.description, - prompt_template=agent.prompt_template, - is_active=agent.is_active, - usage_count=agent.usage_count, - created_at=agent.created_at, - updated_at=agent.updated_at - ) - - finally: - session.close() - - except HTTPException: - raise - except Exception as e: - logger.log_message(f"Error retrieving custom agent: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to retrieve custom agent: {str(e)}") - -@router.put("/{agent_id}", response_model=CustomAgentResponse) -async def update_custom_agent(agent_id: int, agent_update: CustomAgentUpdate, user_id: int = Query(...)): - """Update a custom agent""" - try: - session = session_factory() - - try: - agent = session.query(CustomAgent).filter( - CustomAgent.agent_id == agent_id, - CustomAgent.user_id == user_id - ).first() - - if not agent: - raise HTTPException(status_code=404, detail=f"Custom agent with ID {agent_id} not found") - - # Update fields if provided - if agent_update.display_name is not None: - agent.display_name = agent_update.display_name - if agent_update.description is not None: - agent.description = agent_update.description - if agent_update.prompt_template is not None: - agent.prompt_template = agent_update.prompt_template - if agent_update.is_active is not None: - agent.is_active = agent_update.is_active - - agent.updated_at = datetime.now(UTC) - session.commit() - session.refresh(agent) - - logger.log_message(f"Updated custom agent {agent_id} for user {user_id}", level=logging.INFO) - - return CustomAgentResponse( - agent_id=agent.agent_id, - agent_name=agent.agent_name, - display_name=agent.display_name, - description=agent.description, - prompt_template=agent.prompt_template, - is_active=agent.is_active, - usage_count=agent.usage_count, - created_at=agent.created_at, - updated_at=agent.updated_at - ) - - except Exception as e: - session.rollback() - logger.log_message(f"Error updating custom agent: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to update custom agent: {str(e)}") - finally: - session.close() - - except HTTPException: - raise - except Exception as e: - logger.log_message(f"Error updating custom agent: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to update custom agent: {str(e)}") - -@router.delete("/{agent_id}") -async def delete_custom_agent(agent_id: int, user_id: int = Query(...)): - """Delete a custom agent""" - try: - session = session_factory() - - try: - agent = session.query(CustomAgent).filter( - CustomAgent.agent_id == agent_id, - CustomAgent.user_id == user_id - ).first() - - if not agent: - raise HTTPException(status_code=404, detail=f"Custom agent with ID {agent_id} not found") - - session.delete(agent) - session.commit() - - logger.log_message(f"Deleted custom agent {agent_id} for user {user_id}", level=logging.INFO) - - return {"message": f"Custom agent {agent_id} deleted successfully"} - - except Exception as e: - session.rollback() - logger.log_message(f"Error deleting custom agent: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to delete custom agent: {str(e)}") - finally: - session.close() - - except HTTPException: - raise - except Exception as e: - logger.log_message(f"Error deleting custom agent: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to delete custom agent: {str(e)}") - -@router.post("/{agent_id}/increment_usage") -async def increment_usage_count(agent_id: int, user_id: int = Query(...)): - """Increment usage count for a custom agent""" - try: - session = session_factory() - - try: - agent = session.query(CustomAgent).filter( - CustomAgent.agent_id == agent_id, - CustomAgent.user_id == user_id - ).first() - - if not agent: - raise HTTPException(status_code=404, detail=f"Custom agent with ID {agent_id} not found") - - agent.usage_count += 1 - agent.updated_at = datetime.now(UTC) - session.commit() - - logger.log_message(f"Incremented usage count for agent {agent_id} (now: {agent.usage_count})", level=logging.INFO) - - return {"message": "Usage count incremented", "usage_count": agent.usage_count} - - except Exception as e: - session.rollback() - logger.log_message(f"Error incrementing usage count: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to increment usage count: {str(e)}") - finally: - session.close() - - except HTTPException: - raise - except Exception as e: - logger.log_message(f"Error incrementing usage count: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to increment usage count: {str(e)}") - -@router.post("/{agent_id}/toggle_active") -async def toggle_agent_active_status(agent_id: int, user_id: int = Query(...)): - """Toggle active status for a custom agent with 10-agent limit""" - try: - session = session_factory() - - try: - agent = session.query(CustomAgent).filter( - CustomAgent.agent_id == agent_id, - CustomAgent.user_id == user_id, - CustomAgent.is_template == False - ).first() - - if not agent: - raise HTTPException(status_code=404, detail=f"Custom agent with ID {agent_id} not found") - - # If trying to activate an agent, check the 10-agent limit - if not agent.is_active: - # Count currently active custom agents for this user - active_count = session.query(CustomAgent).filter( - CustomAgent.user_id == user_id, - CustomAgent.is_active == True, - CustomAgent.is_template == False - ).count() - - if active_count >= 10: - raise HTTPException( - status_code=400, - detail="You can have at most 10 active custom agents. Please deactivate some agents first." - ) - - # Toggle the active status - agent.is_active = not agent.is_active - agent.updated_at = datetime.now(UTC) - session.commit() - - status_text = "activated" if agent.is_active else "deactivated" - logger.log_message(f"Agent {agent_id} {status_text} for user {user_id}", level=logging.INFO) - - return { - "message": f"Agent {status_text} successfully", - "is_active": agent.is_active, - "agent_id": agent.agent_id - } - - except Exception as e: - session.rollback() - logger.log_message(f"Error toggling agent status: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to toggle agent status: {str(e)}") - finally: - session.close() - - except HTTPException: - raise - except Exception as e: - logger.log_message(f"Error toggling agent status: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to toggle agent status: {str(e)}") - -@router.get("/active_count/{user_id}") -async def get_active_agents_count(user_id: int): - """Get count of active custom agents for a user""" - try: - session = session_factory() - - try: - active_count = session.query(CustomAgent).filter( - CustomAgent.user_id == user_id, - CustomAgent.is_active == True, - CustomAgent.is_template == False - ).count() - - return { - "active_count": active_count, - "max_allowed": 10, - "can_activate_more": active_count < 10 - } - - finally: - session.close() - - except Exception as e: - logger.log_message(f"Error getting active agents count: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to get active agents count: {str(e)}") - -@router.get("/validate_name/{agent_name}") -async def validate_agent_name(agent_name: str, user_id: int = Query(...)): - """Check if an agent name is available for a user""" - try: - session = session_factory() - - try: - existing_agent = session.query(CustomAgent).filter( - CustomAgent.user_id == user_id, - CustomAgent.agent_name == agent_name.lower().strip() - ).first() - - is_available = existing_agent is None - - return { - "agent_name": agent_name, - "is_available": is_available, - "message": "Agent name is available" if is_available else "Agent name already exists" - } - - finally: - session.close() - - except Exception as e: - logger.log_message(f"Error validating agent name: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to validate agent name: {str(e)}") - -# Template endpoints -@router.get("/templates/", response_model=List[TemplatesByCategory]) -async def get_template_agents(): - """Get all template agents organized by category""" - try: - session = session_factory() - - try: - templates = session.query(CustomAgent).filter( - CustomAgent.is_template == True, - CustomAgent.is_active == True - ).order_by(CustomAgent.template_category, CustomAgent.agent_name).all() - - # Group templates by category - categories = {} - for template in templates: - category = template.template_category or "Other" - if category not in categories: - categories[category] = [] - - categories[category].append(TemplateAgentResponse( - agent_id=template.agent_id, - agent_name=template.agent_name, - display_name=template.display_name, - description=template.description, - prompt_template=template.prompt_template, - template_category=template.template_category, - is_premium_only=template.is_premium_only, - is_active=template.is_active, - usage_count=template.usage_count, - created_at=template.created_at - )) - - # Convert to list of category objects - result = [TemplatesByCategory(category=category, templates=templates) - for category, templates in categories.items()] - - return result - - finally: - session.close() - - except Exception as e: - logger.log_message(f"Error retrieving template agents: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to retrieve template agents: {str(e)}") - -@router.get("/templates/{template_id}", response_model=TemplateAgentResponse) -async def get_template_agent(template_id: int): - """Get a specific template agent by ID""" - try: - session = session_factory() - - try: - template = session.query(CustomAgent).filter( - CustomAgent.agent_id == template_id, - CustomAgent.is_template == True - ).first() - - if not template: - raise HTTPException(status_code=404, detail=f"Template agent with ID {template_id} not found") - - return TemplateAgentResponse( - agent_id=template.agent_id, - agent_name=template.agent_name, - display_name=template.display_name, - description=template.description, - prompt_template=template.prompt_template, - template_category=template.template_category, - is_premium_only=template.is_premium_only, - is_active=template.is_active, - usage_count=template.usage_count, - created_at=template.created_at - ) - - finally: - session.close() - - except HTTPException: - raise - except Exception as e: - logger.log_message(f"Error retrieving template agent: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to retrieve template agent: {str(e)}") - -@router.post("/templates/{template_id}/copy", response_model=CustomAgentResponse) -async def copy_template_to_user(template_id: int, user_id: int = Query(...), agent_name: str = Query(...)): - """Copy a template agent to a user's custom agents""" - try: - session = session_factory() - - try: - # Validate user exists - user = session.query(User).filter(User.user_id == user_id).first() - if not user: - raise HTTPException(status_code=404, detail="User not found") - - # Get template agent - template = session.query(CustomAgent).filter( - CustomAgent.agent_id == template_id, - CustomAgent.is_template == True - ).first() - - if not template: - raise HTTPException(status_code=404, detail=f"Template agent with ID {template_id} not found") - - # Check if agent name already exists for user - existing_agent = session.query(CustomAgent).filter( - CustomAgent.user_id == user_id, - CustomAgent.agent_name == agent_name.lower().strip() - ).first() - - if existing_agent: - raise HTTPException(status_code=400, detail=f"Agent name '{agent_name}' already exists, please choose a different name") - - # Create new custom agent from template - now = datetime.now(UTC) - new_agent = CustomAgent( - user_id=user_id, - agent_name=agent_name.lower().strip(), - display_name=template.display_name, - description=template.description, - prompt_template=template.prompt_template, - is_template=False, # This is a user agent, not a template - template_category=None, # User agents don't have template categories - is_premium_only=False, # User agents are not premium-restricted - is_active=True, - usage_count=0, - created_at=now, - updated_at=now - ) - - session.add(new_agent) - session.commit() - session.refresh(new_agent) - - # Increment template usage count - template.usage_count += 1 - session.commit() - - logger.log_message(f"Copied template '{template.agent_name}' to user {user_id} as '{agent_name}'", level=logging.INFO) - - return CustomAgentResponse( - agent_id=new_agent.agent_id, - agent_name=new_agent.agent_name, - display_name=new_agent.display_name, - description=new_agent.description, - prompt_template=new_agent.prompt_template, - is_active=new_agent.is_active, - usage_count=new_agent.usage_count, - created_at=new_agent.created_at, - updated_at=new_agent.updated_at - ) - - except IntegrityError: - session.rollback() - raise HTTPException(status_code=400, detail=f"Agent name '{agent_name}' already exists, please choose a different name") - except Exception as e: - session.rollback() - logger.log_message(f"Error copying template to user: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to copy template: {str(e)}") - finally: - session.close() - - except HTTPException: - raise - except Exception as e: - logger.log_message(f"Error copying template to user: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to copy template: {str(e)}") - -@router.post("/templates/{template_id}/increment_usage") -async def increment_template_usage_count(template_id: int): - """Increment usage count for a template agent""" - try: - session = session_factory() - - try: - template = session.query(CustomAgent).filter( - CustomAgent.agent_id == template_id, - CustomAgent.is_template == True - ).first() - - if not template: - raise HTTPException(status_code=404, detail=f"Template agent with ID {template_id} not found") - - template.usage_count += 1 - template.updated_at = datetime.now(UTC) - session.commit() - - logger.log_message(f"Incremented template usage count for template {template_id} (now: {template.usage_count})", level=logging.INFO) - - return {"message": "Template usage count incremented", "usage_count": template.usage_count} - - except Exception as e: - session.rollback() - logger.log_message(f"Error incrementing template usage count: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to increment template usage count: {str(e)}") - finally: - session.close() - - except HTTPException: - raise - except Exception as e: - logger.log_message(f"Error incrementing template usage count: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to increment template usage count: {str(e)}") - -@router.post("/generate_instructions", response_model=AgentInstructionResponse) -async def generate_agent_instructions(request: AgentInstructionRequest): - """Generate agent instructions using the custom_agent_instruction_generator""" - try: - # Validate category - valid_categories = ["Visualization", "Modelling", "Data Manipulation"] - if request.category not in valid_categories: - raise HTTPException( - status_code=400, - detail=f"Invalid category. Must be one of: {', '.join(valid_categories)}" - ) - - default_lm = dspy.LM( - model="openai/gpt-4o-mini", - api_key=os.getenv("OPENAI_API_KEY"), - temperature=0.7, - max_tokens=7000 - ) - - instruction_generator = dspy.ChainOfThought(custom_agent_instruction_generator) - - with dspy.context(lm=default_lm): - result = instruction_generator( - category=request.category, - user_requirements=request.user_requirements - ) - - logger.log_message(f"Generated agent instructions for category: {request.category}", level=logging.INFO) - - return AgentInstructionResponse( - agent_instructions=result.agent_instructions, - category=request.category, - user_requirements=request.user_requirements - ) - - except HTTPException: - raise - except Exception as e: - logger.log_message(f"Error generating agent instructions: {str(e)}", level=logging.ERROR) - raise HTTPException(status_code=500, detail=f"Failed to generate agent instructions: {str(e)}") \ No newline at end of file diff --git a/auto-analyst-backend/src/routes/templates_routes.py b/auto-analyst-backend/src/routes/templates_routes.py new file mode 100644 index 00000000..e7c354cf --- /dev/null +++ b/auto-analyst-backend/src/routes/templates_routes.py @@ -0,0 +1,545 @@ +import logging +import os +from fastapi import APIRouter, Depends, HTTPException, Query, Body +from pydantic import BaseModel, Field +from typing import List, Optional, Dict, Any +from datetime import datetime, UTC +from sqlalchemy import desc, func +from sqlalchemy.exc import IntegrityError + +from src.db.init_db import session_factory +from src.db.schemas.models import AgentTemplate, User, UserTemplatePreference +from src.utils.logger import Logger +from src.agents.agents import get_all_available_templates, toggle_user_template_preference + +# Initialize logger with console logging disabled +logger = Logger("templates_routes", see_time=True, console_log=False) + +# Initialize router +router = APIRouter(prefix="/templates", tags=["templates"]) + +# Pydantic models for request/response +class TemplateResponse(BaseModel): + template_id: int + template_name: str + display_name: Optional[str] + description: str + prompt_template: str + template_category: Optional[str] + icon_url: Optional[str] + is_premium_only: bool + is_active: bool + usage_count: int + created_at: datetime + updated_at: datetime + +class UserTemplatePreferenceResponse(BaseModel): + template_id: int + template_name: str + display_name: Optional[str] + description: str + template_category: Optional[str] + icon_url: Optional[str] + is_premium_only: bool + is_enabled: bool + usage_count: int + last_used_at: Optional[datetime] + +class ToggleTemplateRequest(BaseModel): + is_enabled: bool = Field(..., description="Whether to enable or disable the template") + +def get_global_usage_counts(session, template_ids: List[int] = None) -> Dict[int, int]: + """ + Calculate global usage counts for templates by summing usage_count across all users. + + Args: + session: Database session + template_ids: Optional list of template IDs to filter by. If None, gets all templates. + + Returns: + Dict mapping template_id to global usage count + """ + try: + query = session.query( + UserTemplatePreference.template_id, + func.sum(UserTemplatePreference.usage_count).label('total_usage') + ).group_by(UserTemplatePreference.template_id) + + if template_ids: + query = query.filter(UserTemplatePreference.template_id.in_(template_ids)) + + results = query.all() + + # Convert to dictionary, defaulting to 0 for templates with no usage + usage_dict = {template_id: int(total_usage or 0) for template_id, total_usage in results} + + # If specific template_ids were requested, ensure all are represented + if template_ids: + for template_id in template_ids: + if template_id not in usage_dict: + usage_dict[template_id] = 0 + + return usage_dict + + except Exception as e: + logger.log_message(f"Error calculating global usage counts: {str(e)}", level=logging.ERROR) + return {} + +# Routes +@router.get("/", response_model=List[TemplateResponse]) +async def get_all_templates(): + """Get all available agent templates with global usage statistics""" + try: + session = session_factory() + + try: + templates = get_all_available_templates(session) + + # Get template IDs for usage calculation + template_ids = [template.template_id for template in templates] + + # Calculate global usage counts + global_usage = get_global_usage_counts(session, template_ids) + + return [TemplateResponse( + template_id=template.template_id, + template_name=template.template_name, + display_name=template.display_name, + description=template.description, + prompt_template=template.prompt_template, + template_category=template.category, + icon_url=template.icon_url, + is_premium_only=template.is_premium_only, + is_active=template.is_active, + usage_count=global_usage.get(template.template_id, 0), # Global usage count + created_at=template.created_at, + updated_at=template.updated_at + ) for template in templates] + + finally: + session.close() + + except Exception as e: + logger.log_message(f"Error retrieving templates: {str(e)}", level=logging.ERROR) + raise HTTPException(status_code=500, detail=f"Failed to retrieve templates: {str(e)}") + +@router.get("/user/{user_id}", response_model=List[UserTemplatePreferenceResponse]) +async def get_user_template_preferences(user_id: int): + """Get all templates with user preferences (enabled/disabled status and usage)""" + try: + session = session_factory() + + try: + # Validate user exists + user = session.query(User).filter(User.user_id == user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + # Get all active templates + templates = session.query(AgentTemplate).filter( + AgentTemplate.is_active == True + ).all() + + result = [] + for template in templates: + # Get user preference for this template if it exists + preference = session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == user_id, + UserTemplatePreference.template_id == template.template_id + ).first() + + result.append(UserTemplatePreferenceResponse( + template_id=template.template_id, + template_name=template.template_name, + display_name=template.display_name, + description=template.description, + template_category=template.category, + icon_url=template.icon_url, + is_premium_only=template.is_premium_only, + is_enabled=preference.is_enabled if preference else False, # Default to disabled + usage_count=preference.usage_count if preference else 0, + last_used_at=preference.last_used_at if preference else None + )) + + return result + + finally: + session.close() + + except HTTPException: + raise + except Exception as e: + logger.log_message(f"Error retrieving user template preferences: {str(e)}", level=logging.ERROR) + raise HTTPException(status_code=500, detail=f"Failed to retrieve user template preferences: {str(e)}") + +@router.get("/user/{user_id}/enabled", response_model=List[UserTemplatePreferenceResponse]) +async def get_user_enabled_templates(user_id: int): + """Get only templates that are enabled for the user (all templates enabled by default)""" + try: + session = session_factory() + + try: + # Validate user exists + user = session.query(User).filter(User.user_id == user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + # Get all active templates + all_templates = session.query(AgentTemplate).filter( + AgentTemplate.is_active == True + ).all() + + result = [] + for template in all_templates: + # Check if user has a preference record for this template + preference = session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == user_id, + UserTemplatePreference.template_id == template.template_id + ).first() + + # Template is disabled by default unless explicitly enabled + is_enabled = preference.is_enabled if preference else False + + if is_enabled: + result.append(UserTemplatePreferenceResponse( + template_id=template.template_id, + template_name=template.template_name, + display_name=template.display_name, + description=template.description, + template_category=template.category, + icon_url=template.icon_url, + is_premium_only=template.is_premium_only, + is_enabled=True, + usage_count=preference.usage_count if preference else 0, + last_used_at=preference.last_used_at if preference else None + )) + + return result + + finally: + session.close() + + except HTTPException: + raise + except Exception as e: + logger.log_message(f"Error retrieving user enabled templates: {str(e)}", level=logging.ERROR) + raise HTTPException(status_code=500, detail=f"Failed to retrieve user enabled templates: {str(e)}") + +@router.get("/user/{user_id}/enabled/planner", response_model=List[UserTemplatePreferenceResponse]) +async def get_user_enabled_templates_for_planner(user_id: int): + """Get enabled templates for planner use (max 10 templates)""" + try: + session = session_factory() + + try: + # Validate user exists + user = session.query(User).filter(User.user_id == user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + # Get enabled templates ordered by usage (most used first) and limit to 10 + enabled_preferences = session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == user_id, + UserTemplatePreference.is_enabled == True + ).order_by( + UserTemplatePreference.usage_count.desc(), + UserTemplatePreference.last_used_at.desc() + ).limit(10).all() + + result = [] + for preference in enabled_preferences: + # Get template details + template = session.query(AgentTemplate).filter( + AgentTemplate.template_id == preference.template_id, + AgentTemplate.is_active == True + ).first() + + if template: + result.append(UserTemplatePreferenceResponse( + template_id=template.template_id, + template_name=template.template_name, + display_name=template.display_name, + description=template.description, + template_category=template.category, + icon_url=template.icon_url, + is_premium_only=template.is_premium_only, + is_enabled=True, + usage_count=preference.usage_count, + last_used_at=preference.last_used_at + )) + + logger.log_message(f"Retrieved {len(result)} enabled templates for planner for user {user_id}", level=logging.INFO) + return result + + finally: + session.close() + + except HTTPException: + raise + except Exception as e: + logger.log_message(f"Error retrieving planner templates for user {user_id}: {str(e)}", level=logging.ERROR) + raise HTTPException(status_code=500, detail=f"Failed to retrieve planner templates: {str(e)}") + +@router.post("/user/{user_id}/template/{template_id}/toggle") +async def toggle_template_preference(user_id: int, template_id: int, request: ToggleTemplateRequest): + """Toggle a user's template preference (enable/disable for planner use)""" + try: + session = session_factory() + + try: + # Validate user exists + user = session.query(User).filter(User.user_id == user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + success, message = toggle_user_template_preference( + user_id, template_id, request.is_enabled, session + ) + + if not success: + raise HTTPException(status_code=400, detail=message) + + logger.log_message(f"Toggled template {template_id} for user {user_id}: {message}", level=logging.INFO) + + return {"message": message} + + finally: + session.close() + + except HTTPException: + raise + except Exception as e: + logger.log_message(f"Error toggling template preference: {str(e)}", level=logging.ERROR) + raise HTTPException(status_code=500, detail=f"Failed to toggle template preference: {str(e)}") + +@router.post("/user/{user_id}/bulk-toggle") +async def bulk_toggle_template_preferences(user_id: int, request: dict): + """Bulk toggle multiple template preferences""" + try: + session = session_factory() + + try: + # Validate user exists + user = session.query(User).filter(User.user_id == user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + template_preferences = request.get("preferences", []) + if not template_preferences: + raise HTTPException(status_code=400, detail="No preferences provided") + + # Check current enabled count for limit enforcement + current_enabled_count = session.query(UserTemplatePreference).filter( + UserTemplatePreference.user_id == user_id, + UserTemplatePreference.is_enabled == True + ).count() + + # Count how many templates we're trying to enable + enabling_count = sum(1 for pref in template_preferences if pref.get("is_enabled", False)) + disabling_count = sum(1 for pref in template_preferences if not pref.get("is_enabled", False)) + + # Calculate what the new count would be + projected_enabled_count = current_enabled_count + enabling_count - disabling_count + + results = [] + for pref in template_preferences: + template_id = pref.get("template_id") + is_enabled = pref.get("is_enabled", True) + + if template_id is None: + results.append({"template_id": None, "success": False, "message": "Template ID required"}) + continue + + # Check 10-template limit for enabling + if is_enabled and projected_enabled_count > 10: + results.append({ + "template_id": template_id, + "success": False, + "message": "Cannot enable more than 10 templates for planner use", + "is_enabled": False + }) + continue + + success, message = toggle_user_template_preference( + user_id, template_id, is_enabled, session + ) + + results.append({ + "template_id": template_id, + "success": success, + "message": message, + "is_enabled": is_enabled + }) + + logger.log_message(f"Bulk toggled {len(template_preferences)} templates for user {user_id}", level=logging.INFO) + + return {"results": results} + + finally: + session.close() + + except HTTPException: + raise + except Exception as e: + logger.log_message(f"Error bulk toggling template preferences: {str(e)}", level=logging.ERROR) + raise HTTPException(status_code=500, detail=f"Failed to bulk toggle template preferences: {str(e)}") + +@router.get("/template/{template_id}", response_model=TemplateResponse) +async def get_template(template_id: int): + """Get a specific template by ID with global usage statistics""" + try: + session = session_factory() + + try: + template = session.query(AgentTemplate).filter( + AgentTemplate.template_id == template_id + ).first() + + if not template: + raise HTTPException(status_code=404, detail=f"Template with ID {template_id} not found") + + # Calculate global usage count for this template + global_usage = get_global_usage_counts(session, [template_id]) + + return TemplateResponse( + template_id=template.template_id, + template_name=template.template_name, + display_name=template.display_name, + description=template.description, + prompt_template=template.prompt_template, + template_category=template.category, + icon_url=template.icon_url, + is_premium_only=template.is_premium_only, + is_active=template.is_active, + usage_count=global_usage.get(template_id, 0), # Global usage count + created_at=template.created_at, + updated_at=template.updated_at + ) + + finally: + session.close() + + except HTTPException: + raise + except Exception as e: + logger.log_message(f"Error retrieving template: {str(e)}", level=logging.ERROR) + raise HTTPException(status_code=500, detail=f"Failed to retrieve template: {str(e)}") + +@router.get("/categories/list") +async def get_template_categories(): + """Get list of all template categories""" + try: + session = session_factory() + + try: + categories = session.query(AgentTemplate.category).filter( + AgentTemplate.is_active == True, + AgentTemplate.category.isnot(None) + ).distinct().all() + + category_list = [category[0] for category in categories if category[0]] + + return {"categories": category_list} + + finally: + session.close() + + except Exception as e: + logger.log_message(f"Error retrieving template categories: {str(e)}", level=logging.ERROR) + raise HTTPException(status_code=500, detail=f"Failed to retrieve template categories: {str(e)}") + +@router.get("/categories") +async def get_templates_by_categories(): + """Get all templates grouped by category for frontend template browser with global usage statistics""" + try: + session = session_factory() + + try: + # Get all active templates + templates = session.query(AgentTemplate).filter( + AgentTemplate.is_active == True + ).order_by(AgentTemplate.category, AgentTemplate.template_name).all() + + # Get template IDs for usage calculation + template_ids = [template.template_id for template in templates] + + # Calculate global usage counts + global_usage = get_global_usage_counts(session, template_ids) + + # Group templates by category + categories_dict = {} + for template in templates: + category = template.category or "Uncategorized" + if category not in categories_dict: + categories_dict[category] = [] + + categories_dict[category].append({ + "agent_id": template.template_id, # Use template_id as agent_id for compatibility + "agent_name": template.template_name, + "display_name": template.display_name or template.template_name, + "description": template.description, + "prompt_template": template.prompt_template, + "template_category": template.category, + "icon_url": template.icon_url, + "is_premium_only": template.is_premium_only, + "is_active": template.is_active, + "usage_count": global_usage.get(template.template_id, 0), # Global usage count + "created_at": template.created_at.isoformat() if template.created_at else None + }) + + # Convert to list format expected by frontend + result = [] + for category, templates in categories_dict.items(): + result.append({ + "category": category, + "templates": templates + }) + + return result + + finally: + session.close() + + except Exception as e: + logger.log_message(f"Error retrieving templates by categories: {str(e)}", level=logging.ERROR) + raise HTTPException(status_code=500, detail=f"Failed to retrieve templates by categories: {str(e)}") + +@router.get("/category/{category}") +async def get_templates_by_category(category: str): + """Get all templates in a specific category with global usage statistics""" + try: + session = session_factory() + + try: + templates = session.query(AgentTemplate).filter( + AgentTemplate.is_active == True, + AgentTemplate.category == category + ).all() + + # Get template IDs for usage calculation + template_ids = [template.template_id for template in templates] + + # Calculate global usage counts + global_usage = get_global_usage_counts(session, template_ids) + + return [TemplateResponse( + template_id=template.template_id, + template_name=template.template_name, + display_name=template.display_name, + description=template.description, + prompt_template=template.prompt_template, + template_category=template.category, + icon_url=template.icon_url, + is_premium_only=template.is_premium_only, + is_active=template.is_active, + usage_count=global_usage.get(template.template_id, 0), # Global usage count + created_at=template.created_at, + updated_at=template.updated_at + ) for template in templates] + + finally: + session.close() + + except Exception as e: + logger.log_message(f"Error retrieving templates by category: {str(e)}", level=logging.ERROR) + raise HTTPException(status_code=500, detail=f"Failed to retrieve templates by category: {str(e)}") \ No newline at end of file diff --git a/auto-analyst-backend/src/utils/model_registry.py b/auto-analyst-backend/src/utils/model_registry.py index 6f1186a1..1d8ac1f3 100644 --- a/auto-analyst-backend/src/utils/model_registry.py +++ b/auto-analyst-backend/src/utils/model_registry.py @@ -23,7 +23,7 @@ "o1": {"input": 0.015, "output": 0.06}, "o1-pro": {"input": 0.015, "output": 0.6}, "o1-mini": {"input": 0.00011, "output": 0.00044}, - "o3": {"input": 0.001, "output": 0.04}, + "o3": {"input": 0.002, "output": 0.008}, "o3-mini": {"input": 0.00011, "output": 0.00044}, "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015}, }, diff --git a/auto-analyst-frontend/components/chat/AgentSuggestions.tsx b/auto-analyst-frontend/components/chat/AgentSuggestions.tsx index 76682e30..e699875e 100644 --- a/auto-analyst-frontend/components/chat/AgentSuggestions.tsx +++ b/auto-analyst-frontend/components/chat/AgentSuggestions.tsx @@ -49,27 +49,22 @@ export default function AgentSuggestions({ const getUserId = (): number | null => { // Prioritize userId prop if provided if (userId !== undefined && userId !== null) { - console.log('AgentSuggestions - Using userId prop:', userId); return userId } if (session?.user?.id) { - console.log('AgentSuggestions - Using session user ID:', parseInt(session.user.id)); return parseInt(session.user.id) } const adminId = localStorage.getItem('adminUserId') if (adminId) { - console.log('AgentSuggestions - Using adminUserId from localStorage:', parseInt(adminId)); return parseInt(adminId) } - console.log('AgentSuggestions - No userId found'); return null } // Fetch agents from the main agents endpoint (includes both standard and custom) const fetchAllAgents = async (): Promise => { const currentUserId = getUserId() - console.log('fetchAllAgents - currentUserId:', currentUserId); try { // Build URL with user_id if available @@ -78,13 +73,10 @@ export default function AgentSuggestions({ agentsUrl += `?user_id=${currentUserId}` } - console.log('fetchAllAgents - agentsUrl:', agentsUrl); const response = await fetch(agentsUrl) - console.log('fetchAllAgents - response status:', response.status); if (response.ok) { const data = await response.json() - console.log('fetchAllAgents - response data:', data); const allAgents: AgentSuggestion[] = [] // Add standard agents @@ -99,22 +91,12 @@ export default function AgentSuggestions({ // Add template agents (only for users with custom agents access) if (data.template_agents && data.template_agents.length > 0 && customAgentsAccess.hasAccess) { - console.log('fetchAllAgents - found template agents, fetching details...'); const templateAgents = await fetchTemplateAgents() - console.log('fetchAllAgents - template agent details:', templateAgents); allAgents.push(...templateAgents) } - // Add custom agents (only for users with custom agents access) - if (data.custom_agents && data.custom_agents.length > 0 && customAgentsAccess.hasAccess) { - console.log('fetchAllAgents - found custom agents, fetching details...'); - // Fetch custom agent details - const customAgents = await fetchCustomAgents() - console.log('fetchAllAgents - custom agent details:', customAgents); - allAgents.push(...customAgents) - } + // Custom agents are deprecated - using templates only - console.log('fetchAllAgents - final allAgents:', allAgents); return allAgents } else { console.error('Failed to fetch agents:', response.status, await response.text()) @@ -128,44 +110,14 @@ export default function AgentSuggestions({ } } - // Fetch custom agents - const fetchCustomAgents = async (): Promise => { - const currentUserId = getUserId() - if (!currentUserId) { - return [] - } - - try { - const customAgentsUrl = `${API_URL}/custom_agents/?user_id=${currentUserId}` - const response = await fetch(customAgentsUrl) - - if (response.ok) { - const customAgents = await response.json() - const mappedAgents = customAgents.map((agent: any) => ({ - name: agent.agent_name, - description: agent.description, - isCustom: true - })) - return mappedAgents - } else { - console.error('Failed to fetch custom agents:', response.status, await response.text()) - } - } catch (error) { - console.error('Error fetching custom agents:', error) - } - return [] - } - // Fetch template agents const fetchTemplateAgents = async (): Promise => { try { - const templatesUrl = `${API_URL}/custom_agents/templates/` - console.log('fetchTemplateAgents - templatesUrl:', templatesUrl); + const templatesUrl = `${API_URL}/templates/categories` const response = await fetch(templatesUrl) if (response.ok) { const templateCategories = await response.json() - console.log('fetchTemplateAgents - template categories:', templateCategories); const allTemplates: AgentSuggestion[] = [] // Flatten all templates from all categories @@ -180,7 +132,6 @@ export default function AgentSuggestions({ } }) - console.log('fetchTemplateAgents - all templates:', allTemplates); return allTemplates } else { console.error('Failed to fetch template agents:', response.status, await response.text()) @@ -200,24 +151,6 @@ export default function AgentSuggestions({ loadAgents() }, [session, userId, customAgentsAccess.hasAccess]) - // Add event listener to refresh agents when custom agents are created/updated - useEffect(() => { - const handleRefreshAgents = () => { - const loadAgents = async () => { - const allAgents = await fetchAllAgents() - setAgents(allAgents) - } - loadAgents() - } - - // Listen for custom events that signal agent changes - window.addEventListener('custom-agents-updated', handleRefreshAgents) - - return () => { - window.removeEventListener('custom-agents-updated', handleRefreshAgents) - } - }, [session, userId, customAgentsAccess.hasAccess]) - // Filter agents based on current typing useEffect(() => { if (!isVisible || !message.includes('@')) { diff --git a/auto-analyst-frontend/components/chat/ChatInput.tsx b/auto-analyst-frontend/components/chat/ChatInput.tsx index ca1c0996..20744e21 100644 --- a/auto-analyst-frontend/components/chat/ChatInput.tsx +++ b/auto-analyst-frontend/components/chat/ChatInput.tsx @@ -35,8 +35,6 @@ import { } from "@/components/ui/select" // Deep Analysis imports import { DeepAnalysisSidebar, DeepAnalysisButton } from '../deep-analysis' -// Custom Agents imports -import { CustomAgentsSidebar, CustomAgentsButton } from '../custom-agents' import CommandSuggestions from './CommandSuggestions' import AgentSuggestions from './AgentSuggestions' import { useUserSubscriptionStore } from '@/lib/store/userSubscriptionStore' @@ -225,15 +223,15 @@ const ChatInput = forwardRef< const [shouldForceExpanded, setShouldForceExpanded] = useState(false) // Custom Agents states - const [showCustomAgentsSidebar, setShowCustomAgentsSidebar] = useState(false) - const [shouldForceExpandedCustomAgents, setShouldForceExpandedCustomAgents] = useState(false) + const [showTemplatesSidebar, setShowTemplatesSidebar] = useState(false) + const [shouldForceExpandedTemplates, setShouldForceExpandedTemplates] = useState(false) const [showCommandSuggestions, setShowCommandSuggestions] = useState(false) const [commandQuery, setCommandQuery] = useState('') // Get subscription from store instead of manual construction const { subscription } = useUserSubscriptionStore() const deepAnalysisAccess = useFeatureAccess('DEEP_ANALYSIS', subscription) - + // Expose handlePreviewDefaultDataset to parent useImperativeHandle(ref, () => ({ handlePreviewDefaultDataset, @@ -247,7 +245,6 @@ const ChatInput = forwardRef< useEffect(() => { const checkDisabledStatus = () => { const isDisabled = isInputDisabled(); - // logger.log(`[ChatInput] Input disabled on mount: ${isDisabled}, isChatBlocked: ${isChatBlocked}`); }; checkDisabledStatus(); }, []); @@ -278,7 +275,6 @@ const ChatInput = forwardRef< useEffect(() => { // When sessionId changes (switching chats), check for dataset info if (sessionId) { - logger.log('Session ID changed, checking dataset info:', sessionId); // First try to get session info to see if we have a custom dataset axios.get(`${PREVIEW_API_URL}/api/session-info`, { @@ -289,7 +285,6 @@ const ChatInput = forwardRef< .then(infoResponse => { const { is_custom_dataset, dataset_name, dataset_description } = infoResponse.data; - logger.log('Session info response:', infoResponse.data); if (is_custom_dataset) { // If we have a custom dataset, check if we have local file info @@ -325,7 +320,7 @@ const ChatInput = forwardRef< setFilePreview({ headers, rows, name, description }); setDatasetDescription({ name, description }); - logger.log('Successfully restored dataset preview data'); + }) .catch(error => { logger.error('Failed to get dataset preview:', error); @@ -411,8 +406,6 @@ const ChatInput = forwardRef< } }); - logger.log("Session info in ChatInput:", response.data); - // If we have a custom dataset on the server if (response.data && response.data.is_custom_dataset) { const customName = response.data.dataset_name || 'Custom Dataset'; @@ -438,8 +431,6 @@ const ChatInput = forwardRef< } } else if (!fileUpload && !hasLocalStorageFile) { // UI shows no custom dataset, but server has one, and no localStorage - // This is likely after a refresh - show the dataset reset popup - logger.log("UI shows no dataset, but server has custom dataset - showing reset dialog"); // Create a mock File object just for display purposes const mockFile = new File([""], `${customName}.csv`, { type: 'text/csv' }); @@ -457,7 +448,6 @@ const ChatInput = forwardRef< } else if (fileUpload && fileUpload.status === 'success') { // The UI shows a custom dataset, but the server says we're using the default // This means there's a mismatch - the session was reset on the server side - logger.log("Dataset mismatch detected: UI shows custom dataset but server uses default"); setDatasetMismatch(true); setShowDatasetResetPopup(true); } else { @@ -522,15 +512,7 @@ const ChatInput = forwardRef< setErrorNotification(null); if (errorTimeoutRef.current) { clearTimeout(errorTimeoutRef.current); - } - - // Log file details for debugging - logger.log('Selected file:', { - name: file.name, - size: file.size, - type: file.type, - lastModified: file.lastModified - }); + } // Check file type before proceeding const isCSVByExtension = file.name.toLowerCase().endsWith('.csv'); @@ -673,7 +655,6 @@ const ChatInput = forwardRef< 'X-Session-ID': sessionId, }, }); - logger.log('Session reset before new file upload'); // Reset the popup shown flags to ensure we show the popup for this new dataset state popupShownForChatIdsRef.current = new Set(); @@ -684,7 +665,6 @@ const ChatInput = forwardRef< } // Always do a fresh upload for new files - logger.log('Uploading new file and getting preview...', file.name, file.size, file.type); const formData = new FormData(); formData.append('file', file); @@ -700,15 +680,6 @@ const ChatInput = forwardRef< formData.append('name', tempName); formData.append('description', existingDescription); - logger.log('FormData prepared:', { - fileName: file.name, - fileSize: file.size, - fileType: file.type, - name: tempName, - description: existingDescription, - isNewDataset: isNewDataset - }); - // Upload the file try { const uploadResponse = await axios.post(`${PREVIEW_API_URL}/upload_dataframe`, formData, { @@ -719,7 +690,6 @@ const ChatInput = forwardRef< }, }); - logger.log('Upload response:', uploadResponse.data); const previewSessionId = uploadResponse.data.session_id || sessionId; // Capture the dataset upload ID if available @@ -737,7 +707,6 @@ const ChatInput = forwardRef< }, }); - logger.log('Preview response:', previewResponse.data); // Extract all fields including name and description const { headers, rows, name, description } = previewResponse.data; @@ -824,7 +793,6 @@ const ChatInput = forwardRef< }, 5000); } } else { - logger.log('Not a CSV file'); // Set error notification with detailed information setErrorNotification({ message: 'Invalid file format', @@ -939,11 +907,11 @@ const ChatInput = forwardRef< setTimeout(() => setShouldForceExpanded(false), 100) } else if (command.id === 'custom-agents') { // Show custom agents sidebar in expanded state - setShouldForceExpandedCustomAgents(true) - setShowCustomAgentsSidebar(true) + setShouldForceExpanded(true) + setShowTemplatesSidebar(true) setMessage('') // Reset force expanded after a brief moment - setTimeout(() => setShouldForceExpandedCustomAgents(false), 100) + setTimeout(() => setShouldForceExpanded(false), 100) } else { // For other commands, replace the "/" with the command setMessage(`${command.name} `) @@ -1013,7 +981,6 @@ const ChatInput = forwardRef< ...(sessionId && { 'X-Session-ID': sessionId }), }, }); - logger.log('Session forcefully reset to default dataset'); } catch (resetError) { console.error('Failed to reset session for default dataset:', resetError); // Continue anyway @@ -1053,7 +1020,6 @@ const ChatInput = forwardRef< setDatasetMismatch(false); setShowDatasetResetPopup(false); - logger.log("Default dataset preview loaded, upload state reset"); } catch (error) { console.error('Failed to fetch dataset preview:', error); } @@ -1083,7 +1049,6 @@ const ChatInput = forwardRef< ...(sessionId && { 'X-Session-ID': sessionId }), }, }); - logger.log('Session silently reset to default dataset'); } catch (resetError) { console.error('Failed to silently reset session for default dataset:', resetError); // Continue anyway @@ -1124,7 +1089,6 @@ const ChatInput = forwardRef< setDatasetMismatch(false); setShowDatasetResetPopup(false); - logger.log("Default dataset silently loaded, upload state reset"); } catch (error) { console.error('Failed to silently load default dataset:', error); } @@ -1151,23 +1115,12 @@ const ChatInput = forwardRef< clearTimeout(errorTimeoutRef.current); } - // Log the description we're about to use - logger.log('Using dataset description for upload:', datasetDescription.description); // Try to get the actual file from the file input ref first (most reliable source) const actualFile = fileInputRef.current?.files?.[0] || (fileUpload?.file || null); if (actualFile) { - // Log file details to console for debugging - logger.log("Upload file details:", { - name: actualFile.name, - size: actualFile.size, - type: actualFile.type, - lastModified: actualFile.lastModified, - description: datasetDescription.description, - isExcel: fileUpload?.isExcel, - selectedSheet: fileUpload?.selectedSheet - }); + // Only check for mock files in specific cases when we know it was created programmatically // This avoids incorrectly flagging legitimate small files @@ -1205,7 +1158,6 @@ const ChatInput = forwardRef< 'X-Session-ID': sessionId, }, }); - logger.log('Session reset before final upload'); // Reset the popup shown flags for the new dataset state popupShownForChatIdsRef.current = new Set(); @@ -1234,15 +1186,6 @@ const ChatInput = forwardRef< formData.append('sheet_name', fileUpload.selectedSheet); } - logger.log('Final upload with description:', { - fileName: actualFile.name, - fileSize: actualFile.size, - name: datasetDescription.name, - description: finalDescription, - isExcel: isExcelFile, - selectedSheet: fileUpload?.selectedSheet - }); - try { // Use the appropriate endpoint based on file type const endpoint = isExcelFile ? `${PREVIEW_API_URL}/upload_excel` : `${PREVIEW_API_URL}/upload_dataframe`; @@ -1388,7 +1331,6 @@ const ChatInput = forwardRef< const resetDate = new Date(creditResetDate); if (!isNaN(resetDate.getTime())) { - logger.log(`[ChatInput] Using actual reset date from Redis: ${resetDate.toISOString()}`); return resetDate.toLocaleDateString('en-US', { year: 'numeric', month: 'long', @@ -1460,7 +1402,6 @@ const ChatInput = forwardRef< // If we have a file input reference, clear it and trigger a click if (fileInputRef.current) { - logger.log("Clearing file input and requesting new selection"); fileInputRef.current.value = ""; // Close the dataset reset popup first @@ -1475,7 +1416,6 @@ const ChatInput = forwardRef< }, 100); } else { // If we can't access the file input, show the preview dialog - logger.log("Showing preview dialog for file selection"); setShowPreview(true); // Pre-fill the name from the file @@ -1491,7 +1431,6 @@ const ChatInput = forwardRef< } else { // This is a real file, we can try to show the preview directly try { - logger.log("Showing preview for existing file"); await handleFilePreview(fileUpload.file); // Close the dataset reset popup @@ -1945,18 +1884,19 @@ const ChatInput = forwardRef< )} - { - setShouldForceExpandedCustomAgents(true) - setShowCustomAgentsSidebar(true) + setShouldForceExpanded(true) + setShowTemplatesSidebar(true) // Reset force expanded after a brief moment - setTimeout(() => setShouldForceExpandedCustomAgents(false), 100) + setTimeout(() => setShouldForceExpanded(false), 100) }} userProfile={subscription} showLabel={true} size="sm" /> - + */} + { setShouldForceExpanded(true) @@ -2423,13 +2363,13 @@ const ChatInput = forwardRef< forceExpanded={shouldForceExpanded} /> - {/* Custom Agents Sidebar */} - setShowCustomAgentsSidebar(false)} + {/* Templates Sidebar */} + {/* setShowTemplatesSidebar(false)} userId={userId} - forceExpanded={shouldForceExpandedCustomAgents} - /> + forceExpanded={shouldForceExpanded} + /> */} ) }) diff --git a/auto-analyst-frontend/components/chat/ChatInterface.tsx b/auto-analyst-frontend/components/chat/ChatInterface.tsx index 55214937..99f983ae 100644 --- a/auto-analyst-frontend/components/chat/ChatInterface.tsx +++ b/auto-analyst-frontend/components/chat/ChatInterface.tsx @@ -1502,6 +1502,7 @@ const ChatInterface: React.FC = () => { onChatSelect={loadChat} isLoading={isLoadingHistory} onDeleteChat={handleChatDelete} + userId={userId || undefined} /> )} diff --git a/auto-analyst-frontend/components/chat/Sidebar.tsx b/auto-analyst-frontend/components/chat/Sidebar.tsx index efa59377..398a4617 100644 --- a/auto-analyst-frontend/components/chat/Sidebar.tsx +++ b/auto-analyst-frontend/components/chat/Sidebar.tsx @@ -14,6 +14,8 @@ import API_URL from '@/config/api' import { format } from 'date-fns' import { useModelSettings } from '@/lib/hooks/useModelSettings' import logger from '@/lib/utils/logger' +import { TemplatesButton, TemplatesModal, useTemplates } from '@/components/custom-templates' +import { useUserSubscriptionStore } from '@/lib/store/userSubscriptionStore' const PREVIEW_API_URL = API_URL; @@ -31,9 +33,10 @@ interface SidebarProps { onChatSelect: (chatId: number) => void isLoading: boolean onDeleteChat: (chatId: number) => void + userId?: number | null } -const Sidebar: React.FC = ({ isOpen, onClose, onNewChat, chatHistories = [], activeChatId, onChatSelect, isLoading, onDeleteChat }) => { +const Sidebar: React.FC = ({ isOpen, onClose, onNewChat, chatHistories = [], activeChatId, onChatSelect, isLoading, onDeleteChat, userId: userIdProp }) => { const { clearMessages } = useChatHistoryStore() const { data: session } = useSession() const router = useRouter() @@ -43,10 +46,57 @@ const Sidebar: React.FC = ({ isOpen, onClose, onNewChat, chatHisto const [isAdmin, setIsAdmin] = useState(false) const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false); const [chatToDelete, setChatToDelete] = useState(null); - + const [isTemplatesModalOpen, setIsTemplatesModalOpen] = useState(false); + const { subscription } = useUserSubscriptionStore() + + // Get current user profile for templates + const userProfile = subscription || null + + // Handle userId properly for both regular users and admin testing + const [userId, setUserId] = useState(1) // Default fallback + + const getUserId = () => { + // First check if userId prop is provided (correct user-specific ID) + if (userIdProp) { + return userIdProp + } + + // Only access localStorage on client side + if (typeof window !== 'undefined') { + // Then check for admin testing user ID + const adminUserId = localStorage.getItem('adminUserId') + if (adminUserId) { + return parseInt(adminUserId) + } + } + + // Then check session user ID + if (session?.user?.id) { + return parseInt(session.user.id) + } + + // Default fallback + return 1 // Use 1 instead of 0 since user IDs start from 1 + } + + // Update userId when dependencies change useEffect(() => { - setIsAdmin(localStorage.getItem('isAdmin') === 'true') + const newUserId = getUserId() + setUserId(newUserId) + }, [userIdProp, session?.user?.id]) + + // Also check admin status on client side only + useEffect(() => { + if (typeof window !== 'undefined') { + setIsAdmin(localStorage.getItem('isAdmin') === 'true') + } }, []) + + // Use templates hook for data management + const { templateCount, enabledCount } = useTemplates({ + userId, + enabled: isOpen && !!userId + }) const handleNewChat = async () => { if (sessionId) { @@ -98,7 +148,7 @@ const Sidebar: React.FC = ({ isOpen, onClose, onNewChat, chatHisto } const handleSignOut = async () => { - if (localStorage.getItem('isAdmin') === 'true') { + if (typeof window !== 'undefined' && localStorage.getItem('isAdmin') === 'true') { // Clear admin status localStorage.removeItem('isAdmin') // Redirect to home page @@ -247,6 +297,16 @@ const Sidebar: React.FC = ({ isOpen, onClose, onNewChat, chatHisto New Chat + {/* Templates Section */} +
+ setIsTemplatesModalOpen(true)} + userProfile={userProfile} + templateCount={templateCount} + enabledCount={enabledCount} + /> +
+ {/* Chat History - more minimal with less padding */}
@@ -368,6 +428,13 @@ const Sidebar: React.FC = ({ isOpen, onClose, onNewChat, chatHisto
)} + + {/* Templates Modal */} + setIsTemplatesModalOpen(false)} + userId={userId} + /> ) } diff --git a/auto-analyst-frontend/components/custom-agents/AgentDetailView.tsx b/auto-analyst-frontend/components/custom-agents/AgentDetailView.tsx deleted file mode 100644 index e2b08890..00000000 --- a/auto-analyst-frontend/components/custom-agents/AgentDetailView.tsx +++ /dev/null @@ -1,417 +0,0 @@ -'use client'; - -import React, { useState, useEffect } from 'react'; -import { Button } from '@/components/ui/button'; -import { Input } from '@/components/ui/input'; -import { Textarea } from '@/components/ui/textarea'; -import { Label } from '@/components/ui/label'; -import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'; -import { Badge } from '@/components/ui/badge'; -import { Switch } from '@/components/ui/switch'; -import { - ArrowLeft, - Edit, - Save, - X, - Calendar, - TrendingUp, - Settings2, - Trash2, - AlertCircle, - CheckCircle -} from 'lucide-react'; -import { CustomAgent } from './types'; -import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, - AlertDialogTrigger, -} from '@/components/ui/alert-dialog'; - -interface AgentDetailViewProps { - agent: CustomAgent | null; - onUpdateAgent: (agentId: number, updates: Partial) => Promise; - onDeleteAgent: (agentId: number) => Promise; - onToggleActive?: (agentId: number) => Promise; - onBack: () => void; -} - -export default function AgentDetailView({ - agent, - onUpdateAgent, - onDeleteAgent, - onToggleActive, - onBack -}: AgentDetailViewProps) { - const [isEditing, setIsEditing] = useState(false); - const [editedAgent, setEditedAgent] = useState>({}); - const [isUpdating, setIsUpdating] = useState(false); - const [isDeleting, setIsDeleting] = useState(false); - - useEffect(() => { - if (agent) { - setEditedAgent({ - display_name: agent.display_name, - description: agent.description, - prompt_template: agent.prompt_template, - is_active: agent.is_active - }); - } - }, [agent]); - - if (!agent) { - return ( -
- -

No Agent Selected

-

- Select an agent from the list to view and edit its details. -

-
- ); - } - - const handleEdit = () => { - setIsEditing(true); - }; - - const handleCancelEdit = () => { - setIsEditing(false); - setEditedAgent({ - display_name: agent.display_name, - description: agent.description, - prompt_template: agent.prompt_template, - is_active: agent.is_active - }); - }; - - const handleSave = async () => { - setIsUpdating(true); - try { - const success = await onUpdateAgent(agent.agent_id, editedAgent); - if (success) { - setIsEditing(false); - } - } finally { - setIsUpdating(false); - } - }; - - const handleDelete = async () => { - setIsDeleting(true); - try { - const success = await onDeleteAgent(agent.agent_id); - if (success) { - onBack(); - } - } finally { - setIsDeleting(false); - } - }; - - const handleInputChange = (field: keyof CustomAgent, value: string | boolean) => { - setEditedAgent(prev => ({ ...prev, [field]: value })); - }; - - const formatDate = (dateString: string) => { - return new Date(dateString).toLocaleDateString('en-US', { - month: 'long', - day: 'numeric', - year: 'numeric', - hour: '2-digit', - minute: '2-digit' - }); - }; - - const canSave = () => { - return ( - (editedAgent.description?.length || 0) >= 10 && - (editedAgent.prompt_template?.length || 0) >= 50 - ); - }; - - return ( -
- {/* Header */} -
-
- -
-

- {agent.display_name || agent.agent_name} -

-
- - @{agent.agent_name} - -
-
-
- {agent.is_active ? 'Active' : 'Inactive'} -
- {!isEditing && onToggleActive && ( - { - onToggleActive(agent.agent_id); - }} - className="data-[state=checked]:bg-[#FF7F7F] scale-75" - /> - )} -
-
-
-
- -
- {isEditing ? ( - <> - - - - ) : ( - <> - - - - - - - - - Delete Custom Agent - - Are you sure you want to delete "{agent.display_name || agent.agent_name}"? - This action cannot be undone and the agent will no longer be available for use. - - - - Cancel - - {isDeleting ? 'Deleting...' : 'Delete'} - - - - - - )} -
-
- - {/* Content */} -
- {/* Metadata Card */} - - - Agent Metadata - - -
-
- -
-

Created

-

{formatDate(agent.created_at)}

-
-
-
- -
-

Usage Count

-

{agent.usage_count} times

-
-
-
- - {agent.updated_at !== agent.created_at && ( -
- Last updated: {formatDate(agent.updated_at)} -
- )} -
-
- - {/* Agent Configuration */} - - - Configuration - - Customize your agent's behavior and availability - - - - {/* Active Status */} -
-
- -

- When active, this agent can be used in conversations -

-
- handleInputChange('is_active', value)} - disabled={!isEditing} - className="flex-shrink-0" - /> -
- - {/* Display Name */} -
- - {isEditing ? ( - handleInputChange('display_name', e.target.value)} - placeholder="User-friendly name for your agent" - className="mt-1 text-sm" - /> - ) : ( -
- {agent.display_name || No display name set} -
- )} -
- - {/* Description */} -
- - {isEditing ? ( - <> -