8 changes: 7 additions & 1 deletion chuck_data/clients/databricks.py
@@ -571,14 +571,17 @@ def submit_sql_statement(
# Jobs methods
#

def submit_job_run(self, config_path, init_script_path, run_name=None):
def submit_job_run(
self, config_path, init_script_path, run_name=None, policy_id=None
):
"""
Submit a one-time Databricks job run using the /runs/submit endpoint.

Args:
config_path: Path to the configuration file for the job in the Volume
init_script_path: Path to the initialization script
run_name: Optional name for the run. If None, a default name will be generated.
policy_id: Optional cluster policy ID to use for the job run.

Returns:
Dict containing the job run information (including run_id)
@@ -618,6 +621,9 @@ def submit_job_run(self, config_path, init_script_path, run_name=None):
"autoscale": {"min_workers": 10, "max_workers": 50},
}

if policy_id:
cluster_config["policy_id"] = policy_id

# Add cloud-specific attributes
cluster_config.update(self.get_cloud_attributes())

6 changes: 6 additions & 0 deletions chuck_data/commands/jobs.py
@@ -17,6 +17,7 @@ def handle_launch_job(client: Optional[DatabricksAPIClient], **kwargs) -> CommandResult:
init_script_path: str = kwargs.get("init_script_path")
run_name: Optional[str] = kwargs.get("run_name")
tool_output_callback = kwargs.get("tool_output_callback")
policy_id: Optional[str] = kwargs.get("policy_id")

if not config_path or not init_script_path:
return CommandResult(
@@ -34,6 +35,7 @@ def handle_launch_job(client: Optional[DatabricksAPIClient], **kwargs) -> CommandResult:
config_path=config_path,
init_script_path=init_script_path,
run_name=run_name,
policy_id=policy_id,
)
run_id = run_data.get("run_id")
if run_id:
@@ -126,6 +128,10 @@ def handle_job_status(client: Optional[DatabricksAPIClient], **kwargs) -> CommandResult:
"type": "string",
"description": "Path to the init script",
},
"policy_id": {
"type": "string",
"description": "Optional: cluster policy ID to use for the job run",
},
"run_name": {"type": "string", "description": "Optional name for the job run"},
},
required_params=["config_path", "init_script_path"],
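The handler side is deliberately thin: `policy_id` is read from `kwargs` and forwarded as-is, so `None` flows through untouched and the client decides what absence means. A quick usage sketch — the client object and paths are placeholders:

```python
from chuck_data.commands.jobs import handle_launch_job

result = handle_launch_job(
    client,  # a DatabricksAPIClient (or the test stub)
    config_path="/Volumes/main/chuck/config.json",   # placeholder path
    init_script_path="/Volumes/main/chuck/init.sh",  # placeholder path
    policy_id="000F957411D99C1F",  # optional; omit to run without a policy
)
if result.success:
    print(result.data["run_id"])
```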
28 changes: 27 additions & 1 deletion chuck_data/commands/setup_stitch.py
@@ -77,6 +77,7 @@ def handle_command(
client: Optional[DatabricksAPIClient],
interactive_input: Optional[str] = None,
auto_confirm: bool = False,
policy_id: Optional[str] = None,
**kwargs,
) -> CommandResult:
"""
@@ -95,6 +96,12 @@
if not client:
return CommandResult(False, message="Client is required for Stitch setup.")

# Handle auto-confirm mode
if auto_confirm:
return _handle_legacy_setup(
client, catalog_name_arg, schema_name_arg, policy_id
)

# Interactive mode - use context management
context = InteractiveContext()
console = get_console()
@@ -103,7 +110,12 @@
# Phase determination
if not interactive_input: # First call - Phase 1: Prepare config
return _phase_1_prepare_config(
client, context, console, catalog_name_arg, schema_name_arg
client,
context,
console,
catalog_name_arg,
schema_name_arg,
policy_id,
)

# Get stored context data
@@ -139,6 +151,7 @@ def _handle_legacy_setup(
client: DatabricksAPIClient,
catalog_name_arg: Optional[str],
schema_name_arg: Optional[str],
policy_id: Optional[str] = None,
) -> CommandResult:
"""Handle auto-confirm mode using the legacy direct setup approach."""
try:
@@ -183,6 +196,10 @@ def _handle_legacy_setup(

return CommandResult(False, message=prep_result["error"], data=prep_result)

# Add policy_id to metadata if provided
if policy_id:
prep_result["metadata"]["policy_id"] = policy_id

# Now we need to explicitly launch the job since _helper_setup_stitch_logic no longer does it
stitch_result_data = _helper_launch_stitch_job(
client, prep_result["stitch_config"], prep_result["metadata"]
@@ -254,6 +271,7 @@ def _phase_1_prepare_config(
console,
catalog_name_arg: Optional[str],
schema_name_arg: Optional[str],
policy_id: Optional[str] = None,
) -> CommandResult:
"""Phase 1: Prepare the Stitch configuration."""
target_catalog = catalog_name_arg or get_active_catalog()
@@ -284,6 +302,10 @@ def _phase_1_prepare_config(
context.clear_active_context("setup_stitch")
return CommandResult(False, message=prep_result["error"])

# Add policy_id to metadata if provided
if policy_id:
prep_result["metadata"]["policy_id"] = policy_id

# Store the prepared data in context (don't store llm_client object)
context.store_context_data("setup_stitch", "phase", "review")
context.store_context_data(
@@ -711,6 +733,10 @@ def _build_post_launch_guidance_message(launch_result, metadata, client=None):
"type": "boolean",
"description": "Optional: Skip interactive confirmation and launch job immediately (default: false)",
},
"policy_id": {
"type": "string",
"description": "Optional: cluster policy ID to use for the Stitch job run",
},
},
required_params=[],
tui_aliases=["/setup-stitch"],
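Condensing the hunks above: both entry paths thread `policy_id` to the same place, `prep_result["metadata"]`, which is the one structure that survives into the launch step. A sketch of the flow with bodies elided (a condensed paraphrase of the diff, not the full functions):

```python
# Condensed paraphrase of the diff; variable setup and bodies elided.
def handle_command(client, interactive_input=None, auto_confirm=False,
                   policy_id=None, **kwargs):
    if auto_confirm:                       # one-shot path
        return _handle_legacy_setup(client, catalog_name_arg,
                                    schema_name_arg, policy_id)
    if not interactive_input:              # first interactive call
        return _phase_1_prepare_config(client, context, console,
                                       catalog_name_arg, schema_name_arg,
                                       policy_id)

# Inside both helpers, the identical guard attaches the policy to the
# metadata that travels to the launch step:
if policy_id:
    prep_result["metadata"]["policy_id"] = policy_id
```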
4 changes: 4 additions & 0 deletions chuck_data/commands/stitch_tools.py
@@ -423,10 +423,14 @@ def _helper_launch_stitch_job(

# Launch the Stitch job
try:
# Extract policy_id if present
policy_id = metadata.get("policy_id")

job_run_data = client.submit_job_run(
config_path=config_file_path,
init_script_path=init_script_path,
run_name=f"Stitch Setup: {stitch_job_name}",
policy_id=policy_id,
)
run_id = job_run_data.get("run_id")
if not run_id:
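Reading the policy with `.get()` is what makes the change backward compatible end to end: metadata dicts written before this change simply lack the key, `policy_id` comes back as `None`, and `submit_job_run` skips its `if policy_id:` branch. In short:

```python
# Old metadata (no key) and new metadata both work:
metadata = {}                                   # pre-change setup output
assert metadata.get("policy_id") is None        # -> submit_job_run omits policy_id

metadata = {"policy_id": "000F957411D99C1F"}    # post-change, policy requested
assert metadata.get("policy_id") == "000F957411D99C1F"
```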
16 changes: 12 additions & 4 deletions chuck_data/service.py
@@ -414,10 +414,18 @@ def execute_command(

# Interactive Mode Handling
if command_def.supports_interactive_input:
# For interactive commands, the `interactive_input` is the primary payload.
# Always include the interactive_input parameter, even if None, for interactive commands
# This prevents dropping inputs in complex interactive flows
args_for_handler = {"interactive_input": interactive_input}
# For interactive commands, we still want to parse any initial arguments/flags if provided.
# This supports use cases like "/setup-stitch --auto-confirm --policy_id=..."
parsed_args_dict, error_result = self._parse_and_validate_tui_args(
command_def, raw_args, raw_kwargs
)

if error_result:
return error_result

args_for_handler = parsed_args_dict or {}
# Ensure interactive_input is passed (it overrides any parsed arg with same name, though unlikely)
args_for_handler["interactive_input"] = interactive_input
else:
# Standard Argument Parsing & Validation
parsed_args_dict, error_result = self._parse_and_validate_tui_args(
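The net effect at the TUI: flags typed on the first call of an interactive command now survive parsing instead of being dropped. Sketched with the example from the comment above — flag spelling is taken from that comment, and the exact shape of the parsed output is an assumption:

```python
# First call: "/setup-stitch --auto-confirm --policy_id=000F957411D99C1F"
# Previously the handler received only:
#     {"interactive_input": None}
# With this change it receives the parsed flags plus interactive_input:
args_for_handler = {
    "auto_confirm": True,
    "policy_id": "000F957411D99C1F",
    "interactive_input": None,  # set on later turns of the interaction
}
```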
1 change: 1 addition & 0 deletions tests/fixtures/databricks/client.py
@@ -65,6 +65,7 @@ def reset(self):

# Reset call tracking
self.create_stitch_notebook_calls = []
self.submit_job_run_calls = []
self.list_catalogs_calls = []
self.get_catalog_calls = []
self.list_schemas_calls = []
24 changes: 22 additions & 2 deletions tests/fixtures/databricks/job_stub.py
@@ -6,6 +6,7 @@ class JobStubMixin:

def __init__(self):
self.create_stitch_notebook_calls = []
self.submit_job_run_calls = []

def list_jobs(self, **kwargs):
"""List jobs."""
@@ -23,15 +24,34 @@ def run_job(self, job_id):
"""Run a job."""
return {"run_id": f"run_{job_id}_001", "job_id": job_id, "state": "RUNNING"}

def submit_job_run(self, config_path, init_script_path, run_name=None):
"""Submit a job run and return run_id."""
def submit_job_run(
self, config_path, init_script_path, run_name=None, policy_id=None
):
"""Submit a job run and return run_id.

Args:
config_path: Path to the job configuration file
init_script_path: Path to the init script
run_name: Optional name for the job run
policy_id: Optional cluster policy ID to use for the job run
"""
from datetime import datetime

if not run_name:
run_name = (
f"Chuck AI One-Time Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
)

# Track the call for test verification
self.submit_job_run_calls.append(
{
"config_path": config_path,
"init_script_path": init_script_path,
"run_name": run_name,
"policy_id": policy_id,
}
)

# Return a successful job submission
return {"run_id": 123456}
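With the call log in place, tests can assert on exactly what reached the stub. A minimal sketch — the concrete class name is an assumption (the new tests below use the `databricks_client_stub` fixture instead of constructing one directly):

```python
stub = DatabricksClientStub()  # hypothetical name for the fixture's class
stub.submit_job_run(
    config_path="/Volumes/t/config.json",
    init_script_path="/init.sh",
    policy_id="P123",
)
call = stub.submit_job_run_calls[0]
assert call["policy_id"] == "P123"
assert call["run_name"].startswith("Chuck AI One-Time Run")  # default name
```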

116 changes: 112 additions & 4 deletions tests/unit/commands/test_jobs.py
@@ -83,7 +83,9 @@ def test_handle_launch_job_no_run_id(databricks_client_stub, temp_config):
"""Test launching a job where response doesn't include run_id."""
with patch("chuck_data.config._config_manager", temp_config):

def submit_no_run_id(config_path, init_script_path, run_name=None):
def submit_no_run_id(
config_path, init_script_path, run_name=None, policy_id=None
):
return {} # No run_id in response

databricks_client_stub.submit_job_run = submit_no_run_id
@@ -101,7 +103,9 @@ def test_handle_launch_job_http_error(databricks_client_stub, temp_config):
"""Test launching a job with HTTP error response."""
with patch("chuck_data.config._config_manager", temp_config):

def submit_failing(config_path, init_script_path, run_name=None):
def submit_failing(
config_path, init_script_path, run_name=None, policy_id=None
):
raise Exception("Bad Request")

databricks_client_stub.submit_job_run = submit_failing
@@ -138,6 +142,73 @@ def test_handle_launch_job_missing_url(temp_config):
assert "Client required" in result.message


# --- policy_id Parameter Tests: handle_launch_job ---


def test_handle_launch_job_with_policy_id_success(databricks_client_stub, temp_config):
"""Test launching a job with policy_id passes it to the client."""
with patch("chuck_data.config._config_manager", temp_config):
result: CommandResult = handle_launch_job(
databricks_client_stub,
config_path="/Volumes/test/config.json",
init_script_path="/init/script.sh",
run_name="MyTestJobWithPolicy",
policy_id="000F957411D99C1F",
)
assert result.success is True
assert "123456" in result.message
assert result.data["run_id"] == "123456"

# Verify policy_id was passed to the client
assert len(databricks_client_stub.submit_job_run_calls) == 1
call_args = databricks_client_stub.submit_job_run_calls[0]
assert call_args["policy_id"] == "000F957411D99C1F"
assert call_args["config_path"] == "/Volumes/test/config.json"
assert call_args["init_script_path"] == "/init/script.sh"


def test_handle_launch_job_without_policy_id(databricks_client_stub, temp_config):
"""Test launching a job without policy_id passes None to the client."""
with patch("chuck_data.config._config_manager", temp_config):
result: CommandResult = handle_launch_job(
databricks_client_stub,
config_path="/Volumes/test/config.json",
init_script_path="/init/script.sh",
run_name="MyTestJobNoPolicy",
)
assert result.success is True

# Verify policy_id was passed as None
assert len(databricks_client_stub.submit_job_run_calls) == 1
call_args = databricks_client_stub.submit_job_run_calls[0]
assert call_args["policy_id"] is None


def test_agent_launch_job_with_policy_id_success(databricks_client_stub, temp_config):
"""AGENT TEST: Launching a job with policy_id passes it correctly."""
with patch("chuck_data.config._config_manager", temp_config):
progress_steps = []

def capture_progress(tool_name, data):
progress_steps.append(data.get("step", str(data)))

result = handle_launch_job(
databricks_client_stub,
config_path="/Volumes/test/config.json",
init_script_path="/init/script.sh",
run_name="AgentTestJobWithPolicy",
policy_id="POLICY123ABC",
tool_output_callback=capture_progress,
)
assert result.success is True
assert "123456" in result.message

# Verify policy_id was passed to the client
assert len(databricks_client_stub.submit_job_run_calls) == 1
call_args = databricks_client_stub.submit_job_run_calls[0]
assert call_args["policy_id"] == "POLICY123ABC"


# --- Direct Command Execution Tests: handle_job_status ---


@@ -230,7 +301,9 @@ def capture_progress(tool_name, data):
progress_steps.append(data.get("step", str(data)))

# Configure stub to return response without run_id
def submit_no_run_id(config_path, init_script_path, run_name=None):
def submit_no_run_id(
config_path, init_script_path, run_name=None, policy_id=None
):
return {} # No run_id in response

databricks_client_stub.submit_job_run = submit_no_run_id
@@ -346,7 +419,9 @@ def test_agent_tool_executor_launch_job_integration(
"""
with patch("chuck_data.config._config_manager", temp_config):
databricks_client_stub.submit_job_run = (
lambda config_path, init_script_path, run_name=None: {"run_id": "789012"}
lambda config_path, init_script_path, run_name=None, policy_id=None: {
"run_id": "789012"
}
)
tool_args = {
"config_path": "/Volumes/agent/config.json",
Expand All @@ -362,6 +437,39 @@ def test_agent_tool_executor_launch_job_integration(
assert agent_result.get("run_id") == "789012"


def test_agent_tool_executor_launch_job_with_policy_id(
databricks_client_stub, temp_config
):
"""AGENT TEST: End-to-end integration for launching a job with policy_id via execute_tool."""
with patch("chuck_data.config._config_manager", temp_config):
captured_policy_id = []

def mock_submit_job_run(
config_path, init_script_path, run_name=None, policy_id=None
):
captured_policy_id.append(policy_id)
return {"run_id": "789012"}

databricks_client_stub.submit_job_run = mock_submit_job_run

tool_args = {
"config_path": "/Volumes/agent/config.json",
"init_script_path": "/agent/init.sh",
"run_name": "AgentExecutorTestJobWithPolicy",
"policy_id": "AGENT_POLICY_ID_123",
}
agent_result = execute_tool(
api_client=databricks_client_stub,
tool_name="launch_job",
tool_args=tool_args,
)
assert agent_result is not None
assert agent_result.get("run_id") == "789012"
# Verify policy_id was passed through the agent tool executor
assert len(captured_policy_id) == 1
assert captured_policy_id[0] == "AGENT_POLICY_ID_123"


def test_agent_tool_executor_job_status_integration(
databricks_client_stub, temp_config
):