From a2d4f4c95adae9ab77f8d9ce01e886b012e17096 Mon Sep 17 00:00:00 2001
From: Aviv Shafir
Date: Wed, 19 Nov 2025 11:44:49 +0200
Subject: [PATCH 1/2] Support Databricks policy ID for setup_stitch and launch_job commands

---
 chuck_data/clients/databricks.py    |  8 +++++++-
 chuck_data/commands/jobs.py         |  6 ++++++
 chuck_data/commands/setup_stitch.py | 28 +++++++++++++++++++++++++++-
 chuck_data/commands/stitch_tools.py |  4 ++++
 chuck_data/service.py               | 16 ++++++++++++----
 5 files changed, 56 insertions(+), 6 deletions(-)

diff --git a/chuck_data/clients/databricks.py b/chuck_data/clients/databricks.py
index 2362456..c8e1b4c 100644
--- a/chuck_data/clients/databricks.py
+++ b/chuck_data/clients/databricks.py
@@ -571,7 +571,9 @@ def submit_sql_statement(
     # Jobs methods
     #
 
-    def submit_job_run(self, config_path, init_script_path, run_name=None):
+    def submit_job_run(
+        self, config_path, init_script_path, run_name=None, policy_id=None
+    ):
         """
         Submit a one-time Databricks job run using the /runs/submit endpoint.
 
@@ -579,6 +581,7 @@ def submit_job_run(
             config_path: Path to the configuration file for the job in the Volume
             init_script_path: Path to the initialization script
             run_name: Optional name for the run. If None, a default name will be generated.
+            policy_id: Optional cluster policy ID to use for the job run.
 
         Returns:
             Dict containing the job run information (including run_id)
@@ -618,6 +621,9 @@ def submit_job_run(
             "autoscale": {"min_workers": 10, "max_workers": 50},
         }
 
+        if policy_id:
+            cluster_config["policy_id"] = policy_id
+
         # Add cloud-specific attributes
         cluster_config.update(self.get_cloud_attributes())
 
diff --git a/chuck_data/commands/jobs.py b/chuck_data/commands/jobs.py
index 2906ac5..cbde272 100644
--- a/chuck_data/commands/jobs.py
+++ b/chuck_data/commands/jobs.py
@@ -17,6 +17,7 @@ def handle_launch_job(client: Optional[DatabricksAPIClient], **kwargs) -> Comman
     init_script_path: str = kwargs.get("init_script_path")
     run_name: Optional[str] = kwargs.get("run_name")
     tool_output_callback = kwargs.get("tool_output_callback")
+    policy_id: Optional[str] = kwargs.get("policy_id")
 
     if not config_path or not init_script_path:
         return CommandResult(
@@ -34,6 +35,7 @@ def handle_launch_job(client: Optional[DatabricksAPIClient], **kwargs) -> Comman
             config_path=config_path,
             init_script_path=init_script_path,
             run_name=run_name,
+            policy_id=policy_id,
         )
         run_id = run_data.get("run_id")
         if run_id:
@@ -126,6 +128,10 @@ def handle_job_status(client: Optional[DatabricksAPIClient], **kwargs) -> Comman
             "type": "string",
             "description": "Path to the init script",
         },
+        "policy_id": {
+            "type": "string",
+            "description": "Optional: cluster policy ID to use for the job run",
+        },
         "run_name": {"type": "string", "description": "Optional name for the job run"},
     },
     required_params=["config_path", "init_script_path"],
diff --git a/chuck_data/commands/setup_stitch.py b/chuck_data/commands/setup_stitch.py
index 8a75a21..fa3ff47 100644
--- a/chuck_data/commands/setup_stitch.py
+++ b/chuck_data/commands/setup_stitch.py
@@ -77,6 +77,7 @@ def handle_command(
     client: Optional[DatabricksAPIClient],
     interactive_input: Optional[str] = None,
     auto_confirm: bool = False,
+    policy_id: Optional[str] = None,
     **kwargs,
 ) -> CommandResult:
     """
@@ -95,6 +96,12 @@ def handle_command(
     if not client:
         return CommandResult(False, message="Client is required for Stitch setup.")
 
+    # Handle auto-confirm mode
+    if auto_confirm:
+        return _handle_legacy_setup(
+            client, catalog_name_arg, schema_name_arg, policy_id
+        )
+
     # Interactive mode - use context management
     context = InteractiveContext()
     console = get_console()
@@ -103,7 +110,12 @@ def handle_command(
     # Phase determination
     if not interactive_input:  # First call - Phase 1: Prepare config
         return _phase_1_prepare_config(
-            client, context, console, catalog_name_arg, schema_name_arg
+            client,
+            context,
+            console,
+            catalog_name_arg,
+            schema_name_arg,
+            policy_id,
         )
 
     # Get stored context data
@@ -139,6 +151,7 @@ def _handle_legacy_setup(
     client: DatabricksAPIClient,
     catalog_name_arg: Optional[str],
     schema_name_arg: Optional[str],
+    policy_id: Optional[str] = None,
 ) -> CommandResult:
     """Handle auto-confirm mode using the legacy direct setup approach."""
     try:
@@ -183,6 +196,10 @@ def _handle_legacy_setup(
             return CommandResult(False, message=prep_result["error"], data=prep_result)
 
+        # Add policy_id to metadata if provided
+        if policy_id:
+            prep_result["metadata"]["policy_id"] = policy_id
+
         # Now we need to explicitly launch the job since _helper_setup_stitch_logic no longer does it
         stitch_result_data = _helper_launch_stitch_job(
             client, prep_result["stitch_config"], prep_result["metadata"]
         )
@@ -254,6 +271,7 @@ def _phase_1_prepare_config(
     client,
     context,
     console,
     catalog_name_arg: Optional[str],
     schema_name_arg: Optional[str],
+    policy_id: Optional[str] = None,
 ) -> CommandResult:
     """Phase 1: Prepare the Stitch configuration."""
     target_catalog = catalog_name_arg or get_active_catalog()
@@ -284,6 +302,10 @@ def _phase_1_prepare_config(
         context.clear_active_context("setup_stitch")
         return CommandResult(False, message=prep_result["error"])
 
+    # Add policy_id to metadata if provided
+    if policy_id:
+        prep_result["metadata"]["policy_id"] = policy_id
+
     # Store the prepared data in context (don't store llm_client object)
     context.store_context_data("setup_stitch", "phase", "review")
     context.store_context_data(
@@ -711,6 +733,10 @@ def _build_post_launch_guidance_message(launch_result, metadata, client=None):
             "type": "boolean",
             "description": "Optional: Skip interactive confirmation and launch job immediately (default: false)",
         },
+        "policy_id": {
+            "type": "string",
+            "description": "Optional: cluster policy ID to use for the Stitch job run",
+        },
     },
     required_params=[],
     tui_aliases=["/setup-stitch"],
diff --git a/chuck_data/commands/stitch_tools.py b/chuck_data/commands/stitch_tools.py
index 78343c9..6d5a009 100644
--- a/chuck_data/commands/stitch_tools.py
+++ b/chuck_data/commands/stitch_tools.py
@@ -423,10 +423,14 @@ def _helper_launch_stitch_job(
 
     # Launch the Stitch job
     try:
+        # Extract policy_id if present
+        policy_id = metadata.get("policy_id")
+
         job_run_data = client.submit_job_run(
             config_path=config_file_path,
             init_script_path=init_script_path,
             run_name=f"Stitch Setup: {stitch_job_name}",
+            policy_id=policy_id,
         )
         run_id = job_run_data.get("run_id")
         if not run_id:
diff --git a/chuck_data/service.py b/chuck_data/service.py
index 7d0fc00..0386a32 100644
--- a/chuck_data/service.py
+++ b/chuck_data/service.py
@@ -414,10 +414,18 @@ def execute_command(
 
         # Interactive Mode Handling
         if command_def.supports_interactive_input:
-            # For interactive commands, the `interactive_input` is the primary payload.
-            # Always include the interactive_input parameter, even if None, for interactive commands
-            # This prevents dropping inputs in complex interactive flows
-            args_for_handler = {"interactive_input": interactive_input}
+            # For interactive commands, we still want to parse any initial arguments/flags if provided.
+            # This supports use cases like "/setup-stitch --auto-confirm --policy_id=..."
+            parsed_args_dict, error_result = self._parse_and_validate_tui_args(
+                command_def, raw_args, raw_kwargs
+            )
+
+            if error_result:
+                return error_result
+
+            args_for_handler = parsed_args_dict or {}
+            # Ensure interactive_input is passed (it overrides any parsed arg with same name, though unlikely)
+            args_for_handler["interactive_input"] = interactive_input
         else:
             # Standard Argument Parsing & Validation
             parsed_args_dict, error_result = self._parse_and_validate_tui_args(

From cf943e8f412e2f97edf0d390fa9cc7924d17f4ee Mon Sep 17 00:00:00 2001
From: Aviv Shafir
Date: Wed, 26 Nov 2025 20:57:14 +0200
Subject: [PATCH 2/2] Update tests for policy_id support

---
 tests/fixtures/databricks/client.py      |   1 +
 tests/fixtures/databricks/job_stub.py    |  24 ++-
 tests/unit/commands/test_jobs.py         | 116 +++++++++++++-
 tests/unit/commands/test_setup_stitch.py | 113 ++++++++++++++
 tests/unit/commands/test_stitch_tools.py |   8 +-
 tests/unit/core/test_service.py          | 191 +++++++++++++++++++++++
 uv.lock                                  |   4 +-
 7 files changed, 445 insertions(+), 12 deletions(-)

diff --git a/tests/fixtures/databricks/client.py b/tests/fixtures/databricks/client.py
index 4b3d05b..5f01561 100644
--- a/tests/fixtures/databricks/client.py
+++ b/tests/fixtures/databricks/client.py
@@ -65,6 +65,7 @@ def reset(self):
 
         # Reset call tracking
         self.create_stitch_notebook_calls = []
+        self.submit_job_run_calls = []
         self.list_catalogs_calls = []
         self.get_catalog_calls = []
         self.list_schemas_calls = []
diff --git a/tests/fixtures/databricks/job_stub.py b/tests/fixtures/databricks/job_stub.py
index 073c5e4..f936184 100644
--- a/tests/fixtures/databricks/job_stub.py
+++ b/tests/fixtures/databricks/job_stub.py
@@ -6,6 +6,7 @@ class JobStubMixin:
 
     def __init__(self):
         self.create_stitch_notebook_calls = []
+        self.submit_job_run_calls = []
 
     def list_jobs(self, **kwargs):
         """List jobs."""
@@ -23,8 +24,17 @@ def run_job(self, job_id):
         """Run a job."""
         return {"run_id": f"run_{job_id}_001", "job_id": job_id, "state": "RUNNING"}
 
-    def submit_job_run(self, config_path, init_script_path, run_name=None):
-        """Submit a job run and return run_id."""
+    def submit_job_run(
+        self, config_path, init_script_path, run_name=None, policy_id=None
+    ):
+        """Submit a job run and return run_id.
+
+        Args:
+            config_path: Path to the job configuration file
+            init_script_path: Path to the init script
+            run_name: Optional name for the job run
+            policy_id: Optional cluster policy ID to use for the job run
+        """
         from datetime import datetime
 
         if not run_name:
@@ -32,6 +42,16 @@ def submit_job_run(
             run_name = (
                 f"Chuck AI One-Time Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
             )
 
+        # Track the call for test verification
+        self.submit_job_run_calls.append(
+            {
+                "config_path": config_path,
+                "init_script_path": init_script_path,
+                "run_name": run_name,
+                "policy_id": policy_id,
+            }
+        )
+
         # Return a successful job submission
         return {"run_id": 123456}
diff --git a/tests/unit/commands/test_jobs.py b/tests/unit/commands/test_jobs.py
index fef6002..9338f3a 100644
--- a/tests/unit/commands/test_jobs.py
+++ b/tests/unit/commands/test_jobs.py
@@ -83,7 +83,9 @@ def test_handle_launch_job_no_run_id(databricks_client_stub, temp_config):
     """Test launching a job where response doesn't include run_id."""
     with patch("chuck_data.config._config_manager", temp_config):
 
-        def submit_no_run_id(config_path, init_script_path, run_name=None):
+        def submit_no_run_id(
+            config_path, init_script_path, run_name=None, policy_id=None
+        ):
             return {}  # No run_id in response
 
         databricks_client_stub.submit_job_run = submit_no_run_id
@@ -101,7 +103,9 @@ def test_handle_launch_job_http_error(databricks_client_stub, temp_config):
     """Test launching a job with HTTP error response."""
     with patch("chuck_data.config._config_manager", temp_config):
 
-        def submit_failing(config_path, init_script_path, run_name=None):
+        def submit_failing(
+            config_path, init_script_path, run_name=None, policy_id=None
+        ):
             raise Exception("Bad Request")
 
         databricks_client_stub.submit_job_run = submit_failing
@@ -138,6 +142,73 @@ def test_handle_launch_job_missing_url(temp_config):
     assert "Client required" in result.message
 
 
+# --- policy_id Parameter Tests: handle_launch_job ---
+
+
+def test_handle_launch_job_with_policy_id_success(databricks_client_stub, temp_config):
+    """Test launching a job with policy_id passes it to the client."""
+    with patch("chuck_data.config._config_manager", temp_config):
+        result: CommandResult = handle_launch_job(
+            databricks_client_stub,
+            config_path="/Volumes/test/config.json",
+            init_script_path="/init/script.sh",
+            run_name="MyTestJobWithPolicy",
+            policy_id="000F957411D99C1F",
+        )
+        assert result.success is True
+        assert "123456" in result.message
+        assert result.data["run_id"] == "123456"
+
+        # Verify policy_id was passed to the client
+        assert len(databricks_client_stub.submit_job_run_calls) == 1
+        call_args = databricks_client_stub.submit_job_run_calls[0]
+        assert call_args["policy_id"] == "000F957411D99C1F"
+        assert call_args["config_path"] == "/Volumes/test/config.json"
+        assert call_args["init_script_path"] == "/init/script.sh"
+
+
+def test_handle_launch_job_without_policy_id(databricks_client_stub, temp_config):
+    """Test launching a job without policy_id passes None to the client."""
+    with patch("chuck_data.config._config_manager", temp_config):
+        result: CommandResult = handle_launch_job(
+            databricks_client_stub,
+            config_path="/Volumes/test/config.json",
+            init_script_path="/init/script.sh",
+            run_name="MyTestJobNoPolicy",
+        )
+        assert result.success is True
+
+        # Verify policy_id was passed as None
+        assert len(databricks_client_stub.submit_job_run_calls) == 1
+        call_args = databricks_client_stub.submit_job_run_calls[0]
+        assert call_args["policy_id"] is None
+
+
+def test_agent_launch_job_with_policy_id_success(databricks_client_stub, temp_config):
+    """AGENT TEST: Launching a job with policy_id passes it correctly."""
+    with patch("chuck_data.config._config_manager", temp_config):
+        progress_steps = []
+
+        def capture_progress(tool_name, data):
+            progress_steps.append(data.get("step", str(data)))
+
+        result = handle_launch_job(
+            databricks_client_stub,
+            config_path="/Volumes/test/config.json",
+            init_script_path="/init/script.sh",
+            run_name="AgentTestJobWithPolicy",
+            policy_id="POLICY123ABC",
+            tool_output_callback=capture_progress,
+        )
+        assert result.success is True
+        assert "123456" in result.message
+
+        # Verify policy_id was passed to the client
+        assert len(databricks_client_stub.submit_job_run_calls) == 1
+        call_args = databricks_client_stub.submit_job_run_calls[0]
+        assert call_args["policy_id"] == "POLICY123ABC"
+
+
 # --- Direct Command Execution Tests: handle_job_status ---
 
 
@@ -230,7 +301,9 @@ def capture_progress(tool_name, data):
             progress_steps.append(data.get("step", str(data)))
 
         # Configure stub to return response without run_id
-        def submit_no_run_id(config_path, init_script_path, run_name=None):
+        def submit_no_run_id(
+            config_path, init_script_path, run_name=None, policy_id=None
+        ):
             return {}  # No run_id in response
 
         databricks_client_stub.submit_job_run = submit_no_run_id
@@ -346,7 +419,9 @@ def test_agent_tool_executor_launch_job_integration(
     """
     with patch("chuck_data.config._config_manager", temp_config):
         databricks_client_stub.submit_job_run = (
-            lambda config_path, init_script_path, run_name=None: {"run_id": "789012"}
+            lambda config_path, init_script_path, run_name=None, policy_id=None: {
+                "run_id": "789012"
+            }
         )
         tool_args = {
             "config_path": "/Volumes/agent/config.json",
@@ -362,6 +437,39 @@ def test_agent_tool_executor_launch_job_integration(
     assert agent_result.get("run_id") == "789012"
 
 
+def test_agent_tool_executor_launch_job_with_policy_id(
+    databricks_client_stub, temp_config
+):
+    """AGENT TEST: End-to-end integration for launching a job with policy_id via execute_tool."""
+    with patch("chuck_data.config._config_manager", temp_config):
+        captured_policy_id = []
+
+        def mock_submit_job_run(
+            config_path, init_script_path, run_name=None, policy_id=None
+        ):
+            captured_policy_id.append(policy_id)
+            return {"run_id": "789012"}
+
+        databricks_client_stub.submit_job_run = mock_submit_job_run
+
+        tool_args = {
+            "config_path": "/Volumes/agent/config.json",
+            "init_script_path": "/agent/init.sh",
+            "run_name": "AgentExecutorTestJobWithPolicy",
+            "policy_id": "AGENT_POLICY_ID_123",
+        }
+        agent_result = execute_tool(
+            api_client=databricks_client_stub,
+            tool_name="launch_job",
+            tool_args=tool_args,
+        )
+        assert agent_result is not None
+        assert agent_result.get("run_id") == "789012"
+        # Verify policy_id was passed through the agent tool executor
+        assert len(captured_policy_id) == 1
+        assert captured_policy_id[0] == "AGENT_POLICY_ID_123"
+
+
 def test_agent_tool_executor_job_status_integration(
     databricks_client_stub, temp_config
 ):
diff --git a/tests/unit/commands/test_setup_stitch.py b/tests/unit/commands/test_setup_stitch.py
index e37d6d1..6e9b8ca 100644
--- a/tests/unit/commands/test_setup_stitch.py
+++ b/tests/unit/commands/test_setup_stitch.py
@@ -184,6 +184,78 @@ def failing_callback(tool_name, data):
     assert not result.success  # Will fail due to missing context/data
 
 
+# Auto-confirm mode tests with policy_id
+
+
+def test_auto_confirm_mode_passes_policy_id(databricks_client_stub, llm_client_stub):
+    """Auto-confirm mode passes policy_id to the job submission."""
+    # Setup test data for successful operation
+    setup_successful_stitch_test_data(databricks_client_stub, llm_client_stub)
+
+    with patch(
+        "chuck_data.commands.setup_stitch.LLMProviderFactory.create",
+        return_value=llm_client_stub,
+    ):
+        with patch(
+            "chuck_data.commands.stitch_tools.get_amperity_token",
+            return_value="test_token",
+        ):
+            with patch(
+                "chuck_data.commands.setup_stitch.get_metrics_collector",
+                return_value=MagicMock(),
+            ):
+                # Call with auto_confirm=True and policy_id
+                result = handle_command(
+                    databricks_client_stub,
+                    catalog_name="test_catalog",
+                    schema_name="test_schema",
+                    auto_confirm=True,
+                    policy_id="000F957411D99C1F",
+                )
+
+                # Verify success
+                assert result.success
+
+                # Verify policy_id was passed to submit_job_run
+                assert len(databricks_client_stub.submit_job_run_calls) == 1
+                call_args = databricks_client_stub.submit_job_run_calls[0]
+                assert call_args["policy_id"] == "000F957411D99C1F"
+
+
+def test_auto_confirm_mode_without_policy_id(databricks_client_stub, llm_client_stub):
+    """Auto-confirm mode works without policy_id (passes None)."""
+    # Setup test data for successful operation
+    setup_successful_stitch_test_data(databricks_client_stub, llm_client_stub)
+
+    with patch(
+        "chuck_data.commands.setup_stitch.LLMProviderFactory.create",
+        return_value=llm_client_stub,
+    ):
+        with patch(
+            "chuck_data.commands.stitch_tools.get_amperity_token",
+            return_value="test_token",
+        ):
+            with patch(
+                "chuck_data.commands.setup_stitch.get_metrics_collector",
+                return_value=MagicMock(),
+            ):
+                # Call with auto_confirm=True but no policy_id
+                result = handle_command(
+                    databricks_client_stub,
+                    catalog_name="test_catalog",
+                    schema_name="test_schema",
+                    auto_confirm=True,
+                )
+
+                # Verify success
+                assert result.success
+
+                # Verify policy_id was passed as None
+                assert len(databricks_client_stub.submit_job_run_calls) == 1
+                call_args = databricks_client_stub.submit_job_run_calls[0]
+                assert call_args["policy_id"] is None
+
+
 # Interactive mode tests
 def test_interactive_mode_phase_1_preparation(databricks_client_stub, llm_client_stub):
     """Interactive mode Phase 1 prepares configuration and shows preview."""
@@ -209,3 +281,44 @@ def test_interactive_mode_phase_1_preparation(databricks_client_stub, llm_client
     assert result.success
     # Interactive mode should return empty message (console output handles display)
     assert result.message == ""
+
+
+def test_interactive_mode_phase_1_stores_policy_id(
+    databricks_client_stub, llm_client_stub
+):
+    """Interactive mode Phase 1 stores policy_id in context metadata."""
+    from chuck_data.interactive_context import InteractiveContext
+
+    # Setup test data for successful operation
+    setup_successful_stitch_test_data(databricks_client_stub, llm_client_stub)
+
+    # Reset the interactive context before test
+    context = InteractiveContext()
+    context.clear_active_context("setup_stitch")
+
+    with patch(
+        "chuck_data.commands.setup_stitch.LLMProviderFactory.create",
+        return_value=llm_client_stub,
+    ):
+        with patch(
+            "chuck_data.commands.stitch_tools.get_amperity_token",
+            return_value="test_token",
+        ):
+            # Call without auto_confirm to enter interactive mode, with policy_id
+            result = handle_command(
+                databricks_client_stub,
+                catalog_name="test_catalog",
+                schema_name="test_schema",
+                policy_id="INTERACTIVE_POLICY_123",
+            )
+
+            # Verify Phase 1 behavior
+            assert result.success
+
+            # Verify policy_id was stored in context metadata
+            context_data = context.get_context_data("setup_stitch")
+            assert "metadata" in context_data
+            assert context_data["metadata"].get("policy_id") == "INTERACTIVE_POLICY_123"
+
+            # Clean up context
+            context.clear_active_context("setup_stitch")
diff --git a/tests/unit/commands/test_stitch_tools.py b/tests/unit/commands/test_stitch_tools.py
index 7521cd9..9e3ff1d 100644
--- a/tests/unit/commands/test_stitch_tools.py
+++ b/tests/unit/commands/test_stitch_tools.py
@@ -837,7 +837,7 @@ def test_launch_stitch_job_returns_job_id(
     # Mock upload and submit
     databricks_client_stub.upload_file = lambda path, content, overwrite: True
 
-    def mock_submit(config_path, init_script_path, run_name):
+    def mock_submit(config_path, init_script_path, run_name, policy_id=None):
         return {"run_id": "run-789"}
 
     databricks_client_stub.submit_job_run = mock_submit
@@ -907,7 +907,7 @@ def test_launch_stitch_job_records_submission(
     # Mock upload and submit
     databricks_client_stub.upload_file = lambda path, content, overwrite: True
 
-    def mock_submit(config_path, init_script_path, run_name):
+    def mock_submit(config_path, init_script_path, run_name, policy_id=None):
         return {"run_id": "run-456"}
 
     databricks_client_stub.submit_job_run = mock_submit
@@ -985,7 +985,7 @@ def test_launch_stitch_job_records_submission_no_job_id(
     # Mock upload and submit
     databricks_client_stub.upload_file = lambda path, content, overwrite: True
 
-    def mock_submit(config_path, init_script_path, run_name):
+    def mock_submit(config_path, init_script_path, run_name, policy_id=None):
         return {"run_id": "run-999"}
 
     databricks_client_stub.submit_job_run = mock_submit
@@ -1063,7 +1063,7 @@ def test_launch_stitch_job_records_submission_no_token(
     # Mock upload and submit
     databricks_client_stub.upload_file = lambda path, content, overwrite: True
 
-    def mock_submit(config_path, init_script_path, run_name):
+    def mock_submit(config_path, init_script_path, run_name, policy_id=None):
         return {"run_id": "run-888"}
 
     databricks_client_stub.submit_job_run = mock_submit
diff --git a/tests/unit/core/test_service.py b/tests/unit/core/test_service.py
index d79eefe..4b0c09e 100644
--- a/tests/unit/core/test_service.py
+++ b/tests/unit/core/test_service.py
@@ -235,3 +235,194 @@ def test_parameter_parsing_flag_equals_with_dashes(databricks_client_stub, temp_
     # Should parse correctly with dash-to-underscore conversion
     assert isinstance(result, CommandResult)
     assert result.success
+
+
+# --- Interactive Command Argument Parsing Tests ---
+
+
+def _setup_stitch_test_data(databricks_client_stub, llm_client_stub):
+    """Helper function to set up test data for Stitch operations."""
+    # Setup test data in client stub
+    databricks_client_stub.add_catalog("test_catalog")
+    databricks_client_stub.add_schema("test_catalog", "test_schema")
+    databricks_client_stub.add_table(
+        "test_catalog",
+        "test_schema",
+        "users",
+        columns=[
+            {"name": "email", "type": "STRING"},
+            {"name": "name", "type": "STRING"},
+            {"name": "id", "type": "BIGINT"},
+        ],
+    )
+
+    # Mock PII scan results - set up table with PII columns
+    llm_client_stub.set_pii_detection_result(
+        [
+            {"column": "email", "semantic": "email"},
+            {"column": "name", "semantic": "name"},
+        ]
+    )
+
+    # Fix API compatibility issues
+    original_create_volume = databricks_client_stub.create_volume
+
+    def mock_create_volume(catalog_name, schema_name, name, **kwargs):
+        return original_create_volume(catalog_name, schema_name, name, **kwargs)
+
+    databricks_client_stub.create_volume = mock_create_volume
+
+    # Override upload_file to match real API signature
+    def mock_upload_file(path, content=None, overwrite=False, **kwargs):
+        return True
+
+    databricks_client_stub.upload_file = mock_upload_file
+
+    # Set up other required API responses
+    databricks_client_stub.fetch_amperity_job_init_response = {
+        "cluster-init": "#!/bin/bash\necho init",
+        "job-id": "test-job-setup-123",
+    }
+    databricks_client_stub.submit_job_run_response = {"run_id": "12345"}
+    databricks_client_stub.create_stitch_notebook_response = {
+        "notebook_path": "/Workspace/test"
+    }
+
+
+def test_interactive_command_parses_policy_id_flag(
+    databricks_client_stub, llm_client_stub, temp_config
+):
+    """Test that interactive commands correctly parse --policy_id flag."""
+    from unittest.mock import patch
+    from chuck_data.interactive_context import InteractiveContext
+
+    # Setup test data
+    _setup_stitch_test_data(databricks_client_stub, llm_client_stub)
+
+    # Reset interactive context
+    context = InteractiveContext()
+    context.clear_active_context("setup_stitch")
+
+    with patch("chuck_data.config._config_manager", temp_config):
+        with patch(
+            "chuck_data.commands.setup_stitch.LLMProviderFactory.create",
+            return_value=llm_client_stub,
+        ):
+            with patch(
+                "chuck_data.commands.stitch_tools.get_amperity_token",
+                return_value="test_token",
+            ):
+                service = ChuckService(client=databricks_client_stub)
+
+                # Test --policy_id=value syntax for interactive command
+                result = service.execute_command(
+                    "/setup-stitch",
+                    "--policy_id=TEST_POLICY_123",
+                    "catalog_name=test_catalog",
+                    "schema_name=test_schema",
+                )
+
+                # Should parse correctly and store in context
+                assert isinstance(result, CommandResult)
+                assert result.success
+
+                # Verify policy_id was stored in context
+                context_data = context.get_context_data("setup_stitch")
+                assert (
+                    context_data.get("metadata", {}).get("policy_id")
+                    == "TEST_POLICY_123"
+                )
+
+                # Clean up
+                context.clear_active_context("setup_stitch")
+
+
+def test_interactive_command_parses_auto_confirm_and_policy_id(
+    databricks_client_stub, llm_client_stub, temp_config
+):
+    """Test that interactive commands correctly parse combined --auto-confirm and --policy_id flags."""
+    from unittest.mock import patch, MagicMock
+
+    # Setup test data
+    _setup_stitch_test_data(databricks_client_stub, llm_client_stub)
+
+    with patch("chuck_data.config._config_manager", temp_config):
+        with patch(
+            "chuck_data.commands.setup_stitch.LLMProviderFactory.create",
+            return_value=llm_client_stub,
+        ):
+            with patch(
+                "chuck_data.commands.stitch_tools.get_amperity_token",
+                return_value="test_token",
+            ):
+                with patch(
+                    "chuck_data.commands.setup_stitch.get_metrics_collector",
+                    return_value=MagicMock(),
+                ):
+                    service = ChuckService(client=databricks_client_stub)
+
+                    # Test combined --auto-confirm and --policy_id flags
+                    result = service.execute_command(
+                        "/setup-stitch",
+                        "--auto-confirm",
+                        "--policy_id=COMBINED_POLICY_456",
+                        "catalog_name=test_catalog",
+                        "schema_name=test_schema",
+                    )
+
+                    # Should parse correctly and execute
+                    assert isinstance(result, CommandResult)
+                    assert result.success
+
+                    # Verify policy_id was passed to submit_job_run
+                    assert len(databricks_client_stub.submit_job_run_calls) == 1
+                    call_args = databricks_client_stub.submit_job_run_calls[0]
+                    assert call_args["policy_id"] == "COMBINED_POLICY_456"
+
+
+def test_interactive_command_parses_policy_id_key_value_syntax(
+    databricks_client_stub, llm_client_stub, temp_config
+):
+    """Test that interactive commands correctly parse policy_id=value syntax."""
+    from unittest.mock import patch
+    from chuck_data.interactive_context import InteractiveContext
+
+    # Setup test data
+    _setup_stitch_test_data(databricks_client_stub, llm_client_stub)
+
+    # Reset interactive context
+    context = InteractiveContext()
+    context.clear_active_context("setup_stitch")
+
+    with patch("chuck_data.config._config_manager", temp_config):
+        with patch(
+            "chuck_data.commands.setup_stitch.LLMProviderFactory.create",
+            return_value=llm_client_stub,
+        ):
+            with patch(
+                "chuck_data.commands.stitch_tools.get_amperity_token",
+                return_value="test_token",
+            ):
+                service = ChuckService(client=databricks_client_stub)
+
+                # Test key=value syntax for policy_id (like policy_id=XYZ)
+                result = service.execute_command(
+                    "/setup-stitch",
+                    "policy_id=KEY_VALUE_POLICY_789",
+                    "catalog_name=test_catalog",
+                    "schema_name=test_schema",
+                )
+
+                # Should parse correctly and store in context
+                assert isinstance(result, CommandResult)
+                assert result.success
+
+                # Verify policy_id was stored in context
+                context_data = context.get_context_data("setup_stitch")
+                assert (
+                    context_data.get("metadata", {}).get("policy_id")
+                    == "KEY_VALUE_POLICY_789"
+                )
+
+                # Clean up
+                context.clear_active_context("setup_stitch")
diff --git a/uv.lock b/uv.lock
index 5336ae2..57f643b 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,5 +1,5 @@
 version = 1
-revision = 2
+revision = 3
 requires-python = ">=3.10"
 
 [[package]]
@@ -207,7 +207,7 @@ wheels = [
 
 [[package]]
 name = "chuck-data"
-version = "0.2.1"
+version = "0.2.2"
 source = { editable = "." }
 dependencies = [
     { name = "boto3" },
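
Usage sketch (illustrative only, not part of the patch): the paths and policy ID
below are placeholders, and `client` is assumed to be an already-authenticated
DatabricksAPIClient instance constructed elsewhere.

    # TUI: interactive commands now accept flags up front (parsed in service.py):
    #   /setup-stitch --auto-confirm --policy_id=000F957411D99C1F

    # Python: the same parameter on the client API. When set, submit_job_run()
    # copies policy_id into the cluster spec it builds for /runs/submit.
    run = client.submit_job_run(
        config_path="/Volumes/main/chuck/config.json",   # placeholder path
        init_script_path="/Volumes/main/chuck/init.sh",  # placeholder path
        run_name="Stitch Setup: example",
        policy_id="000F957411D99C1F",  # cluster policy applied to the job cluster
    )
    print(run.get("run_id"))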