diff --git a/chuck_data/agent/manager.py b/chuck_data/agent/manager.py index 7ec47e8..74cfd5e 100644 --- a/chuck_data/agent/manager.py +++ b/chuck_data/agent/manager.py @@ -210,29 +210,36 @@ def process_with_tools(self, tools, max_iterations: int = 20): if response_message.tool_calls: # Add the assistant's response (requesting tool calls) to history # Convert ChatCompletionMessage to dict format for consistency + tool_calls_list = [] + for tc in response_message.tool_calls: + func = getattr(tc, "function", None) + if func is not None: + tool_calls_list.append( + { + "id": tc.id, + "type": getattr(tc, "type", "function"), + "function": { + "name": getattr(func, "name", ""), + "arguments": getattr(func, "arguments", "{}"), + }, + } + ) assistant_msg = { "role": "assistant", "content": response_message.content, - "tool_calls": [ - { - "id": tc.id, - "type": getattr(tc, "type", "function"), - "function": { - "name": tc.function.name, - "arguments": tc.function.arguments, - }, - } - for tc in response_message.tool_calls - ], + "tool_calls": tool_calls_list, } self.conversation_history.append(assistant_msg) # Execute each tool call for tool_call in response_message.tool_calls: - tool_name = tool_call.function.name + func = getattr(tool_call, "function", None) + if func is None: + continue + tool_name = getattr(func, "name", "") tool_id = tool_call.id try: - tool_args = json.loads(tool_call.function.arguments) + tool_args = json.loads(getattr(func, "arguments", "{}")) tool_result = execute_tool( self.api_client, tool_name, @@ -276,7 +283,7 @@ def process_with_tools(self, tools, max_iterations: int = 20): continue else: # No tool calls, this is the final response - final_content = response_message.content + final_content = response_message.content or "" # remove all lines with any tags final_content = "\n".join( line diff --git a/chuck_data/agent/tool_executor.py b/chuck_data/agent/tool_executor.py index 181bf60..5774aa1 100644 --- a/chuck_data/agent/tool_executor.py +++ b/chuck_data/agent/tool_executor.py @@ -23,7 +23,8 @@ from chuck_data.clients.databricks import ( DatabricksAPIClient, ) # For type hinting api_client -from typing import Dict, Any, Optional, List +from typing import Dict, Any, Optional, List, Callable +from jsonschema.exceptions import ValidationError # The display_to_user utility and individual tool implementation functions @@ -48,7 +49,7 @@ def execute_tool( api_client: Optional[DatabricksAPIClient], tool_name: str, tool_args: Dict[str, Any], - output_callback: Optional[callable] = None, + output_callback: Optional[Callable[..., Any]] = None, ) -> Dict[str, Any]: """Execute a tool (command) by its name with the provided arguments. @@ -87,7 +88,7 @@ def execute_tool( try: jsonschema.validate(instance=tool_args, schema=schema_to_validate) logging.debug(f"Tool arguments for '{tool_name}' validated successfully.") - except jsonschema.exceptions.ValidationError as ve: + except ValidationError as ve: logging.error( f"Validation error for tool '{tool_name}' args {tool_args}: {ve.message}" ) diff --git a/chuck_data/api_client.py b/chuck_data/api_client.py index cd9a51a..989caaf 100644 --- a/chuck_data/api_client.py +++ b/chuck_data/api_client.py @@ -135,6 +135,8 @@ def upload_file(self, path, file_path=None, content=None, overwrite=False): binary_data = f.read() else: # Convert string content to bytes + # content is guaranteed non-None by the validation above + assert content is not None binary_data = content.encode("utf-8") try: diff --git a/chuck_data/clients/amperity.py b/chuck_data/clients/amperity.py index 3c12c91..9874e80 100644 --- a/chuck_data/clients/amperity.py +++ b/chuck_data/clients/amperity.py @@ -8,6 +8,7 @@ import webbrowser import readchar import json +from typing import Optional from rich.console import Console from chuck_data.config import set_amperity_token @@ -113,7 +114,7 @@ def get_auth_status(self) -> dict: return {"state": self.state, "nonce": self.nonce, "has_token": bool(self.token)} def wait_for_auth_completion( - self, poll_interval: int = 1, timeout: int = None + self, poll_interval: int = 1, timeout: Optional[int] = None ) -> tuple[bool, str]: """Wait for authentication to complete in a blocking manner.""" if not self.nonce: diff --git a/chuck_data/clients/databricks.py b/chuck_data/clients/databricks.py index c8e1b4c..2e63f6f 100644 --- a/chuck_data/clients/databricks.py +++ b/chuck_data/clients/databricks.py @@ -724,6 +724,8 @@ def upload_file(self, path, file_path=None, content=None, overwrite=False): binary_data = f.read() else: # Convert string content to bytes + # content is guaranteed non-None by the validation above + assert content is not None binary_data = content.encode("utf-8") try: diff --git a/chuck_data/command_output.py b/chuck_data/command_output.py index 573f0a2..7f1abd2 100644 --- a/chuck_data/command_output.py +++ b/chuck_data/command_output.py @@ -10,7 +10,7 @@ from chuck_data.ui.table_formatter import display_table -from chuck_data.command_result import CommandResult +from chuck_data.commands.base import CommandResult from chuck_data.ui.theme import ( SUCCESS, WARNING, @@ -132,7 +132,7 @@ def format_for_agent(result: CommandResult) -> Dict[str, Any]: } # Start with a base response - response = {"success": True} + response: Dict[str, Any] = {"success": True} # Add the message if available if result.message: diff --git a/chuck_data/commands/catalog_selection.py b/chuck_data/commands/catalog_selection.py index 0adc29c..ff7451c 100644 --- a/chuck_data/commands/catalog_selection.py +++ b/chuck_data/commands/catalog_selection.py @@ -64,7 +64,7 @@ def handle_command(client: Optional[DatabricksAPIClient], **kwargs) -> CommandRe client: API client instance **kwargs: catalog (str) - catalog name, tool_output_callback (optional) """ - catalog: str = kwargs.get("catalog") + catalog = kwargs.get("catalog") tool_output_callback = kwargs.get("tool_output_callback") if not catalog: diff --git a/chuck_data/commands/job_status.py b/chuck_data/commands/job_status.py index 5484d20..1498e02 100644 --- a/chuck_data/commands/job_status.py +++ b/chuck_data/commands/job_status.py @@ -639,7 +639,7 @@ def handle_list_jobs(client=None, **kwargs) -> CommandResult: cached_job_data = job_entry.get("job_data") # If we have cached data for a terminal state, use it - if cached_job_data: + if cached_job_data and isinstance(cached_job_data, dict): state = (cached_job_data.get("state") or "").lower().replace(":", "") # Only use cache for terminal states (succeeded, failed, unknown) if state in ["succeeded", "success", "failed", "error", "unknown"]: diff --git a/chuck_data/commands/jobs.py b/chuck_data/commands/jobs.py index a282303..6870504 100644 --- a/chuck_data/commands/jobs.py +++ b/chuck_data/commands/jobs.py @@ -13,8 +13,8 @@ def handle_launch_job(client: Optional[DatabricksAPIClient], **kwargs) -> Comman client: API client instance **kwargs: config_path (str), init_script_path (str), run_name (str, optional), tool_output_callback (callable, optional) """ - config_path: str = kwargs.get("config_path") - init_script_path: str = kwargs.get("init_script_path") + config_path = kwargs.get("config_path") + init_script_path = kwargs.get("init_script_path") run_name: Optional[str] = kwargs.get("run_name") tool_output_callback = kwargs.get("tool_output_callback") policy_id: Optional[str] = kwargs.get("policy_id") diff --git a/chuck_data/commands/model_selection.py b/chuck_data/commands/model_selection.py index 4aab81a..ddb4040 100644 --- a/chuck_data/commands/model_selection.py +++ b/chuck_data/commands/model_selection.py @@ -24,7 +24,7 @@ def handle_command(client: Optional[DatabricksAPIClient], **kwargs) -> CommandRe client: API client instance (used for Databricks provider) **kwargs: model_name (str) """ - model_name: str = kwargs.get("model_name") + model_name = kwargs.get("model_name") if not model_name: return CommandResult(False, message="model_name parameter is required.") @@ -45,7 +45,7 @@ def handle_command(client: Optional[DatabricksAPIClient], **kwargs) -> CommandRe models_list = provider.list_models(tool_calling_only=False) # Extract model IDs (field name varies by provider) - model_ids = [m.get("model_id") or m.get("name") for m in models_list] + model_ids = [m.get("model_id") or m.get("name") or "" for m in models_list] # Validate model exists if model_name not in model_ids: diff --git a/chuck_data/commands/pii_tools.py b/chuck_data/commands/pii_tools.py index 5395892..65a5c9f 100644 --- a/chuck_data/commands/pii_tools.py +++ b/chuck_data/commands/pii_tools.py @@ -26,17 +26,16 @@ def _helper_tag_pii_columns_logic( response_content_for_error = "" try: # Resolve full table name using APIs directly instead of handler - table_details_kwargs = {"full_name": table_name_param} + resolved_table_name = table_name_param if catalog_name_context and schema_name_context and "." not in table_name_param: # Only a table name was provided, construct full name - full_name = ( + resolved_table_name = ( f"{catalog_name_context}.{schema_name_context}.{table_name_param}" ) - table_details_kwargs = {"full_name": full_name} try: # Use direct API call instead of handle_table - table_info = databricks_client.get_table(**table_details_kwargs) + table_info = databricks_client.get_table(full_name=resolved_table_name) if not table_info: error_msg = f"Failed to retrieve table details for PII tagging: {table_name_param}" return { @@ -88,6 +87,8 @@ def _helper_tag_pii_columns_logic( "and assign a PII semantic tag to each column if applicable. Use ONLY the following PII semantic tags: " "address, address2, birthdate, city, country, create-dt, email, full-name, gender, generational-suffix, " "given-name, phone, postal, state, surname, title, update-dt. If a column does not contain PII, assign null. " + "IMPORTANT: Do NOT assign semantic tags to numeric columns (types: LONG, BIGINT, INT, INTEGER, SMALLINT, " + "TINYINT, DOUBLE, FLOAT, DECIMAL, NUMERIC). Always assign null to numeric columns. " "Respond ONLY with a valid JSON list of objects, where each object represents a column and has the following structure: " '{"name": "column_name", "semantic": "pii_tag_or_null"}. ' "Maintain original order. No explanations or introductory text." @@ -100,9 +101,9 @@ def _helper_tag_pii_columns_logic( {"role": "user", "content": user_prompt}, ] ) - response_content_for_error = llm_response_obj.choices[ - 0 - ].message.content # Store for potential error reporting + response_content_for_error = ( + llm_response_obj.choices[0].message.content or "" + ) # Store for potential error reporting response_content_clean = response_content_for_error.strip() if response_content_clean.startswith("```json"): response_content_clean = response_content_clean[7:-3].strip() diff --git a/chuck_data/commands/schema_selection.py b/chuck_data/commands/schema_selection.py index 9b63c7c..9f938e5 100644 --- a/chuck_data/commands/schema_selection.py +++ b/chuck_data/commands/schema_selection.py @@ -65,7 +65,7 @@ def handle_command(client: Optional[DatabricksAPIClient], **kwargs) -> CommandRe client: API client instance **kwargs: schema (str) - schema name, tool_output_callback (optional) """ - schema: str = kwargs.get("schema") + schema = kwargs.get("schema") tool_output_callback = kwargs.get("tool_output_callback") if not schema: diff --git a/chuck_data/commands/setup_stitch.py b/chuck_data/commands/setup_stitch.py index 0e1b18a..31f6f1b 100644 --- a/chuck_data/commands/setup_stitch.py +++ b/chuck_data/commands/setup_stitch.py @@ -736,7 +736,7 @@ def _build_post_launch_guidance_message(launch_result, metadata, client=None): ) # Get workspace URL for constructing browser links - workspace_url = get_workspace_url() + workspace_url = get_workspace_url() or "" # If workspace_url is already a full URL, normalize it to get just the workspace ID # If it's just the workspace ID, this will return it as-is workspace_id = normalize_workspace_url(workspace_url) diff --git a/chuck_data/commands/setup_wizard.py b/chuck_data/commands/setup_wizard.py index 8224399..ccf5cab 100644 --- a/chuck_data/commands/setup_wizard.py +++ b/chuck_data/commands/setup_wizard.py @@ -248,7 +248,9 @@ def _clear_context(self): def handle_command( - client: Optional[DatabricksAPIClient], interactive_input: str = None, **kwargs: Any + client: Optional[DatabricksAPIClient], + interactive_input: Optional[str] = None, + **kwargs: Any, ) -> CommandResult: """ Setup wizard command handler using the new architecture. diff --git a/chuck_data/commands/sql_external_data.py b/chuck_data/commands/sql_external_data.py index be9655d..fab66b3 100644 --- a/chuck_data/commands/sql_external_data.py +++ b/chuck_data/commands/sql_external_data.py @@ -125,7 +125,7 @@ def get_paginated_rows( if start_row < chunk_end and current_row < start_row + num_rows: # We need some data from this chunk try: - chunk_data = fetch_chunk_data([link], link.get("chunk_index")) + chunk_data = fetch_chunk_data([link], link.get("chunk_index", 0)) if chunk_data: # Calculate which rows from this chunk we need local_start = max(0, start_row - chunk_start) diff --git a/chuck_data/commands/stitch_tools.py b/chuck_data/commands/stitch_tools.py index 00aa427..139dc48 100644 --- a/chuck_data/commands/stitch_tools.py +++ b/chuck_data/commands/stitch_tools.py @@ -28,6 +28,20 @@ "GEOMETRY", ] +# Numeric types that don't support semantic tags in Stitch +NUMERIC_TYPES = [ + "LONG", + "BIGINT", + "INT", + "INTEGER", + "SMALLINT", + "TINYINT", + "DOUBLE", + "FLOAT", + "DECIMAL", + "NUMERIC", +] + def validate_multi_location_access( client: DatabricksAPIClient, locations: List[Dict[str, str]] @@ -201,7 +215,11 @@ def _helper_prepare_stitch_config( "type": col_data["type"], "semantics": [], } - if col_data.get("semantic"): # Only add non-null/empty semantics + # Only add semantics for non-numeric types (Stitch doesn't support semantics on LONG, etc.) + if ( + col_data.get("semantic") + and col_data["type"].upper() not in NUMERIC_TYPES + ): field_cfg["semantics"].append(col_data["semantic"]) table_cfg["fields"].append(field_cfg) else: diff --git a/chuck_data/commands/tag_pii.py b/chuck_data/commands/tag_pii.py index a3356b7..e1643dc 100644 --- a/chuck_data/commands/tag_pii.py +++ b/chuck_data/commands/tag_pii.py @@ -31,7 +31,7 @@ def handle_command(client: Optional[DatabricksAPIClient], **kwargs) -> CommandRe table_name (str): Name of the table to tag pii_columns (list): List of columns with PII semantic info """ - table_name: str = kwargs.get("table_name") + table_name = kwargs.get("table_name") pii_columns: List[Dict[str, Any]] = kwargs.get("pii_columns", []) if not table_name: diff --git a/chuck_data/commands/upload_file.py b/chuck_data/commands/upload_file.py index d375342..bfe1d65 100644 --- a/chuck_data/commands/upload_file.py +++ b/chuck_data/commands/upload_file.py @@ -64,6 +64,8 @@ def handle_command( path=destination_path, contents=contents, overwrite=overwrite ) else: + # local_path is guaranteed non-None by validation above + assert local_path is not None with open(local_path, "r") as file: file_contents = file.read() client.store_dbfs_file( diff --git a/chuck_data/commands/warehouse_selection.py b/chuck_data/commands/warehouse_selection.py index 7bc36f2..e967ec1 100644 --- a/chuck_data/commands/warehouse_selection.py +++ b/chuck_data/commands/warehouse_selection.py @@ -65,7 +65,7 @@ def handle_command(client: Optional[DatabricksAPIClient], **kwargs) -> CommandRe client: API client instance **kwargs: warehouse (str) - warehouse ID or name, tool_output_callback (optional) """ - warehouse: str = kwargs.get("warehouse") + warehouse = kwargs.get("warehouse") tool_output_callback = kwargs.get("tool_output_callback") # Must provide warehouse parameter @@ -125,6 +125,8 @@ def handle_command(client: Optional[DatabricksAPIClient], **kwargs) -> CommandRe _report_step(f"Found warehouse '{selected_name}'", tool_output_callback) # Set the active warehouse + # target_warehouse is guaranteed to be a dict at this point + assert isinstance(target_warehouse, dict) warehouse_id_to_set = target_warehouse.get("id") warehouse_display_name = target_warehouse.get("name", "Unknown") warehouse_state = target_warehouse.get("state", "Unknown") diff --git a/chuck_data/commands/wizard/renderer.py b/chuck_data/commands/wizard/renderer.py index 1597161..3fae6cc 100644 --- a/chuck_data/commands/wizard/renderer.py +++ b/chuck_data/commands/wizard/renderer.py @@ -5,11 +5,12 @@ import platform import subprocess import logging -from typing import List, Dict, Any +from typing import List from rich.console import Console from rich.table import Table from rich import box +from chuck_data.llm.provider import ModelInfo from .state import WizardState, WizardStep from .steps import SetupStep @@ -105,7 +106,7 @@ def render_completion(self): self.console.print("You are now ready to use Chuck with all features enabled.") self.console.print("Type /help to see available commands.") - def _render_models_list(self, models: List[Dict[str, Any]]): + def _render_models_list(self, models: List[ModelInfo]): """Render the list of available models.""" if not models: self.render_warning("No models available.") diff --git a/chuck_data/commands/wizard/state.py b/chuck_data/commands/wizard/state.py index 4e09d1a..df2f607 100644 --- a/chuck_data/commands/wizard/state.py +++ b/chuck_data/commands/wizard/state.py @@ -6,6 +6,8 @@ from enum import Enum from typing import Dict, List, Optional, Any +from chuck_data.llm.provider import ModelInfo + class WizardStep(Enum): """Steps in the setup wizard.""" @@ -38,7 +40,7 @@ class WizardState: workspace_url: Optional[str] = None token: Optional[str] = None llm_provider: Optional[str] = None - models: List[Dict[str, Any]] = field(default_factory=list) + models: List[ModelInfo] = field(default_factory=list) selected_model: Optional[str] = None usage_consent: Optional[bool] = None error_message: Optional[str] = None diff --git a/chuck_data/commands/workspace_selection.py b/chuck_data/commands/workspace_selection.py index b7f0567..7aa3e44 100644 --- a/chuck_data/commands/workspace_selection.py +++ b/chuck_data/commands/workspace_selection.py @@ -22,7 +22,7 @@ def handle_command(client: Optional[DatabricksAPIClient], **kwargs) -> CommandRe client: API client instance (not used by this handler) **kwargs: workspace_url (str) """ - workspace_url: str = kwargs.get("workspace_url") + workspace_url = kwargs.get("workspace_url") if not workspace_url: return CommandResult(False, message="workspace_url parameter is required.") diff --git a/chuck_data/interactive_context.py b/chuck_data/interactive_context.py index 1da2e26..926e9a7 100644 --- a/chuck_data/interactive_context.py +++ b/chuck_data/interactive_context.py @@ -15,9 +15,11 @@ class InteractiveContext: Helps coordinate between command handlers and the TUI. """ - _instance = None + _instance: Optional["InteractiveContext"] = None + _active_contexts: Dict[str, Any] + _current_command: Optional[str] - def __new__(cls): + def __new__(cls) -> "InteractiveContext": if cls._instance is None: cls._instance = super(InteractiveContext, cls).__new__(cls) cls._instance._active_contexts = {} diff --git a/chuck_data/llm/factory.py b/chuck_data/llm/factory.py index 2f6a3a1..991c883 100644 --- a/chuck_data/llm/factory.py +++ b/chuck_data/llm/factory.py @@ -104,7 +104,7 @@ def _instantiate_provider(provider_name: str, config: dict) -> LLMProvider: elif provider_name == "openai": try: - from chuck_data.llm.providers.openai import OpenAIProvider + from chuck_data.llm.providers.openai import OpenAIProvider # type: ignore[reportMissingImports] return OpenAIProvider(**config) except ImportError as e: @@ -114,7 +114,7 @@ def _instantiate_provider(provider_name: str, config: dict) -> LLMProvider: elif provider_name == "anthropic": try: - from chuck_data.llm.providers.anthropic import AnthropicProvider + from chuck_data.llm.providers.anthropic import AnthropicProvider # type: ignore[reportMissingImports] return AnthropicProvider(**config) except ImportError as e: @@ -123,7 +123,7 @@ def _instantiate_provider(provider_name: str, config: dict) -> LLMProvider: ) from e elif provider_name == "mock": - from chuck_data.llm.providers.mock import MockProvider + from chuck_data.llm.providers.mock import MockProvider # type: ignore[reportMissingImports] return MockProvider(**config) diff --git a/chuck_data/llm/provider.py b/chuck_data/llm/provider.py index 89d137a..363a51e 100644 --- a/chuck_data/llm/provider.py +++ b/chuck_data/llm/provider.py @@ -4,13 +4,19 @@ from openai.types.chat import ChatCompletion -class ModelInfo(TypedDict, total=False): +class _ModelInfoRequired(TypedDict): + """Required fields for ModelInfo.""" + + model_id: str # Provider-specific model identifier + + +class ModelInfo(_ModelInfoRequired, total=False): """Unified model information across LLM providers. All providers must return model information in this format. + Required field: model_id """ - model_id: str # Provider-specific model identifier model_name: str # Human-readable model name provider_name: str # Provider name (e.g., "databricks", "aws_bedrock") supports_tool_use: bool # Whether model supports function calling @@ -44,9 +50,13 @@ def chat( """ ... - def list_models(self) -> List[ModelInfo]: + def list_models(self, tool_calling_only: bool = True) -> List[ModelInfo]: """List available models from this provider. + Args: + tool_calling_only: If True, only return models that support tool calling. + Defaults to True since tool calling is required for agent workflows. + Returns: List of ModelInfo dicts containing model metadata """ diff --git a/chuck_data/llm/providers/aws_bedrock.py b/chuck_data/llm/providers/aws_bedrock.py index 1d8fbd7..fb24c2e 100644 --- a/chuck_data/llm/providers/aws_bedrock.py +++ b/chuck_data/llm/providers/aws_bedrock.py @@ -20,6 +20,8 @@ from typing import Any, Dict, List, Literal, Optional from openai.types.chat.chat_completion import ChatCompletion, Choice + +from chuck_data.llm.provider import ModelInfo from openai.types.chat.chat_completion_message import ChatCompletionMessage from openai.types.completion_usage import CompletionUsage from openai.types.chat.chat_completion_message_tool_call import ( @@ -591,7 +593,7 @@ def _supports_tool_calling(model_id: str, provider: str) -> bool: return False - def list_models(self, tool_calling_only: bool = True) -> List[Dict[str, Any]]: + def list_models(self, tool_calling_only: bool = True) -> List[ModelInfo]: """List available Bedrock foundation models. Similar to Databricks list_models() but for Bedrock model catalog. @@ -601,15 +603,14 @@ def list_models(self, tool_calling_only: bool = True) -> List[Dict[str, Any]]: Defaults to True since tool calling is required for agent workflows. Returns: - List of model dicts with metadata: + List of ModelInfo dicts with metadata: [ { "model_id": "anthropic.claude-3-5-sonnet-20241022-v2:0", "model_name": "Claude 3.5 Sonnet", - "provider": "Anthropic", + "provider_name": "Anthropic", "supports_tool_use": True, - "input_modalities": ["TEXT"], - "output_modalities": ["TEXT"] + ... }, ... ] @@ -633,18 +634,13 @@ def list_models(self, tool_calling_only: bool = True) -> List[Dict[str, Any]]: continue models.append( - { - "model_id": model_id, - "model_name": model.get("modelName"), - "provider": provider, - "supports_tool_use": supports_tool_use, - "state": "READY", # AWS Bedrock models are always ready if returned - "input_modalities": model.get("inputModalities", []), - "output_modalities": model.get("outputModalities", []), - "response_streaming_supported": model.get( - "responseStreamingSupported", False - ), - } + ModelInfo( + model_id=model_id, + model_name=model.get("modelName"), + provider_name=provider, + supports_tool_use=supports_tool_use, + state="READY", # AWS Bedrock models are always ready if returned + ) ) return models diff --git a/chuck_data/llm/providers/databricks.py b/chuck_data/llm/providers/databricks.py index a62e812..18503a8 100644 --- a/chuck_data/llm/providers/databricks.py +++ b/chuck_data/llm/providers/databricks.py @@ -80,19 +80,24 @@ def chat( base_url=f"{self.workspace_url}/serving-endpoints", ) - # Make request + # Ensure we have a model - raise if none available + if not resolved_model: + raise ValueError("No model specified and no active model configured") + + # Make request - using type: ignore for OpenAI SDK strict typing + # The runtime behavior is correct as OpenAI accepts these formats if tools: response = client.chat.completions.create( model=resolved_model, - messages=messages, - tools=tools, + messages=messages, # type: ignore[arg-type] + tools=tools, # type: ignore[arg-type] stream=stream, - tool_choice=tool_choice, + tool_choice=tool_choice, # type: ignore[arg-type] ) else: response = client.chat.completions.create( model=resolved_model, - messages=messages, + messages=messages, # type: ignore[arg-type] stream=stream, ) diff --git a/chuck_data/metrics_collector.py b/chuck_data/metrics_collector.py index d15ecad..e835ef0 100644 --- a/chuck_data/metrics_collector.py +++ b/chuck_data/metrics_collector.py @@ -29,7 +29,7 @@ def _should_track(self) -> bool: Returns: bool: True if user has provided consent, False otherwise. """ - return self.config_manager.get_config().usage_tracking_consent + return self.config_manager.get_config().usage_tracking_consent or False def _get_chuck_configuration_for_metric(self) -> Dict[str, Any]: """ diff --git a/chuck_data/service.py b/chuck_data/service.py index 0386a32..e4b3d61 100644 --- a/chuck_data/service.py +++ b/chuck_data/service.py @@ -6,8 +6,9 @@ import json import logging import jsonschema +from jsonschema.exceptions import ValidationError import traceback -from typing import Dict, Optional, Any, Tuple +from typing import Dict, Optional, Any, Tuple, Callable from chuck_data.clients.databricks import DatabricksAPIClient from chuck_data.commands.base import CommandResult @@ -355,7 +356,7 @@ def _parse_and_validate_tui_args( jsonschema.validate( instance=final_args_for_validation, schema=full_schema_for_validation ) - except jsonschema.exceptions.ValidationError as ve: + except ValidationError as ve: usage = ( command_def.usage_hint or f"Use '/help' for details on '{command_def.name}'." @@ -373,7 +374,7 @@ def execute_command( command_name_from_ui: str, *raw_args: str, interactive_input: Optional[str] = None, - tool_output_callback: Optional[callable] = None, + tool_output_callback: Optional[Callable[..., Any]] = None, **raw_kwargs: Any, # For future TUI use, e.g. /cmd --named_arg value ) -> CommandResult: """ @@ -409,7 +410,7 @@ def execute_command( ) return CommandResult(False, message=f"Not authenticated. {error_msg}") - parsed_args_dict: Dict[str, Any] + parsed_args_dict: Optional[Dict[str, Any]] args_for_handler: Dict[str, Any] # Interactive Mode Handling @@ -439,6 +440,8 @@ def execute_command( return CommandResult( False, message="Internal error during argument parsing." ) + # Type is narrowed to Dict[str, Any] after None check + assert parsed_args_dict is not None args_for_handler = parsed_args_dict # Pass tool output callback for agent commands diff --git a/chuck_data/ui/table_formatter.py b/chuck_data/ui/table_formatter.py index 9bcaa41..3e7aa10 100644 --- a/chuck_data/ui/table_formatter.py +++ b/chuck_data/ui/table_formatter.py @@ -147,13 +147,14 @@ def format_cell(value: Any, style: Any = None, none_display: str = "N/A") -> Tex logging.error(f"Error applying style function: {e}") applied_style = None - return Text(value, style=applied_style) + # Cast applied_style to the expected type - it can be str, Style, or None + return Text(value, style=cast(Any, applied_style)) def add_row_with_styles( table: Table, row_data: List[Any], - styles: Optional[List[str]] = None, + styles: Optional[List[Optional[str]]] = None, ) -> None: """ Add a row to a table with optional styling per cell. @@ -161,11 +162,11 @@ def add_row_with_styles( Args: table: The Rich Table to add the row to row_data: List of cell values - styles: Optional list of styles to apply to each cell + styles: Optional list of styles to apply to each cell (can contain None) """ # Initialize defaults if not provided if styles is None: - styles = [None] * len(row_data) + styles = cast(List[Optional[str]], [None] * len(row_data)) # Format each cell and add to table formatted_cells = [ diff --git a/chuck_data/ui/tui.py b/chuck_data/ui/tui.py index 3d59243..e8ecb08 100644 --- a/chuck_data/ui/tui.py +++ b/chuck_data/ui/tui.py @@ -228,6 +228,9 @@ def _(event): # Check if we're in interactive mode if interactive_context.is_in_interactive_mode(): current_cmd = interactive_context.current_command + if current_cmd is None: + # Should not happen if is_in_interactive_mode() returns True + continue # Use prompt toolkit with interactive styling prompt_message = HTML( @@ -321,6 +324,12 @@ def _(event): self.console.print( f"\n[{WARNING_STYLE}]Interrupted by user. Type 'exit' to quit.[/{WARNING_STYLE}]" ) + except EOFError: + # Handle Ctrl+D + self.console.print( + f"\n[{WARNING_STYLE}]Thank you for using chuck![/{WARNING_STYLE}]" + ) + break except Exception as e: # Import at the top of the method to avoid scoping issues from chuck_data.exceptions import PaginationCancelled @@ -329,21 +338,13 @@ def _(event): # Handle pagination cancellation silently - just return to prompt pass else: - raise # Re-raise other exceptions - except EOFError: - # Handle Ctrl+D - self.console.print( - f"\n[{WARNING_STYLE}]Thank you for using chuck![/{WARNING_STYLE}]" - ) - break - except Exception as e: - # Handle other exceptions - self.console.print( - f"[{ERROR_STYLE}]Unexpected Error: {str(e)}[/{ERROR_STYLE}]" - ) - # Print stack trace in debug mode - if self.debug: - self.console.print("[dim]" + traceback.format_exc() + "[/dim]") + # Handle other exceptions + self.console.print( + f"[{ERROR_STYLE}]Unexpected Error: {str(e)}[/{ERROR_STYLE}]" + ) + # Print stack trace in debug mode + if self.debug: + self.console.print("[dim]" + traceback.format_exc() + "[/dim]") def _needs_shlex_parsing(self, command: str) -> bool: """Determine if command needs shlex parsing (has quotes or flags).""" @@ -504,7 +505,7 @@ def _process_command_result(self, cmd, result): elif cmd == "/usage": # For the usage command, we just display the message if result.message: - self._display_usage(result.message) + self.console.print(result.message) elif ( cmd.startswith("/help") and isinstance(result.data, dict) @@ -677,8 +678,11 @@ def _display_full_tool_output( elif tool_name in ["detailed-models", "list-models", "list_models", "models"]: if "models" in tool_result: self._display_models_consolidated(tool_result) - else: + elif isinstance(tool_result, list): self._display_models(tool_result) + else: + # Fallback: display as consolidated if it's a dict without "models" key + self._display_models_consolidated(tool_result) elif tool_name in ["list-warehouses", "list_warehouses", "warehouses"]: self._display_warehouses(tool_result) elif tool_name in ["list-volumes", "list_volumes", "volumes"]: @@ -1122,7 +1126,7 @@ def _display_models_consolidated(self, data: Dict[str, Any]) -> None: f"[{WARNING_STYLE}]No models found in workspace.[/{WARNING_STYLE}]" ) if data.get("message"): - self.console.print("\n" + data.get("message")) + self.console.print("\n" + (data.get("message") or "")) # Raise PaginationCancelled to return to chuck > prompt immediately raise PaginationCancelled() diff --git a/tests/unit/llm/providers/test_aws_bedrock.py b/tests/unit/llm/providers/test_aws_bedrock.py index 651e2e9..e3a560b 100644 --- a/tests/unit/llm/providers/test_aws_bedrock.py +++ b/tests/unit/llm/providers/test_aws_bedrock.py @@ -409,7 +409,7 @@ def test_list_models_returns_model_catalog(self, mock_boto3): assert len(models) == 1 assert models[0]["model_id"] == "anthropic.claude-3-5-sonnet-20241022-v2:0" assert models[0]["model_name"] == "Claude 3.5 Sonnet" - assert models[0]["provider"] == "Anthropic" + assert models[0]["provider_name"] == "Anthropic" assert models[0]["supports_tool_use"] is True # Test with show_all: all models (tool_calling_only=False) @@ -423,7 +423,7 @@ def test_list_models_returns_model_catalog(self, mock_boto3): # Verify second model (Llama 3 70B - does NOT support tool calling, not 3.1+) assert all_models[1]["model_id"] == "meta.llama3-70b-instruct-v1:0" - assert all_models[1]["provider"] == "Meta" + assert all_models[1]["provider_name"] == "Meta" assert all_models[1]["supports_tool_use"] is False # Verify third model (Claude 2 - does NOT support tool calling)