Skip to content

Commit c520ed9

Browse files
simonfaltumclaude
andauthored
Improve aitools query: async polling, spinner, cancellation, --warehouse flag (#4612)
## Why The existing `databricks experimental aitools tools query` command blocks silently for up to 20 minutes with no user feedback, no way to cancel a running query, and no way to override warehouse auto-detection. This makes it frustrating for interactive use with long-running queries. ## Changes **Before:** The command used `ExecuteAndWait` which blocks silently, has no cancellation support, and always auto-detects the warehouse. **Now:** The command submits queries asynchronously and polls with a visible spinner, supports Ctrl+C cancellation that cleans up server-side, and accepts a `--warehouse` flag. Specifically: - **Async polling with spinner**: Replaced `ExecuteAndWait` with a two-step flow — `ExecuteStatement(WaitTimeout=0s)` to get the statement ID immediately, then poll with `GetStatementByStatementId`. A background ticker updates the spinner text every second showing elapsed time (`⣾ Executing query... (12s elapsed)`). - **Ctrl+C cancellation**: Signal handler catches SIGINT/SIGTERM, cancels the poll context, and calls `CancelExecution` to clean up the running statement server-side. Prints `"Query cancelled."` to stderr. - **`--warehouse` / `-w` flag**: Overrides the auto-detection chain (env var → server default → first running warehouse). - **Poll backoff**: Additive backoff 1s → 2s → 3s → 4s → 5s (capped) between status polls. - **Non-interactive debug logging**: `log.Debugf` on each poll iteration for visibility when stderr is not a TTY. - **Shell quoting hint**: When Databricks returns an `UNRESOLVED_MAP_KEY` error (common when shell double-quote processing strips inner quotes from map key access), the error message includes a hint suggesting single quotes or `--file`. ## Test plan - [x] Unit tests: 13 tests covering immediate success/failure, polling state transitions, context cancellation with `CancelExecution` assertion, warehouse flag resolution, result formatting, terminal state detection, error hints. - [x] Manual testing: ran queries against a live workspace with `--profile`, verified spinner updates every second, verified Ctrl+C cancels server-side, verified `--warehouse` override, verified `UNRESOLVED_MAP_KEY` hint. - [x] `go build ./...` clean - [x] `go test ./experimental/aitools/cmd/ -count=1` passes (5s) --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 197a17a commit c520ed9

File tree

2 files changed

+441
-26
lines changed

2 files changed

+441
-26
lines changed

experimental/aitools/cmd/query.go

Lines changed: 193 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,51 @@
11
package mcp
22

33
import (
4+
"context"
45
"encoding/json"
56
"errors"
67
"fmt"
8+
"os"
9+
"os/signal"
710
"strings"
11+
"syscall"
12+
"time"
813

914
"github.com/databricks/cli/cmd/root"
1015
"github.com/databricks/cli/experimental/aitools/lib/middlewares"
1116
"github.com/databricks/cli/experimental/aitools/lib/session"
1217
"github.com/databricks/cli/libs/cmdctx"
1318
"github.com/databricks/cli/libs/cmdio"
19+
"github.com/databricks/cli/libs/log"
1420
"github.com/databricks/databricks-sdk-go/service/sql"
1521
"github.com/spf13/cobra"
1622
)
1723

24+
const (
25+
// pollIntervalInitial is the starting interval between status polls.
26+
pollIntervalInitial = 1 * time.Second
27+
28+
// pollIntervalMax is the maximum interval between status polls.
29+
pollIntervalMax = 5 * time.Second
30+
31+
// cancelTimeout is how long to wait for server-side cancellation.
32+
cancelTimeout = 10 * time.Second
33+
)
34+
1835
func newQueryCmd() *cobra.Command {
36+
var warehouseID string
37+
1938
cmd := &cobra.Command{
2039
Use: "query SQL",
2140
Short: "Execute SQL against a Databricks warehouse",
2241
Long: `Execute a SQL statement against a Databricks SQL warehouse and return results.
2342
24-
The command auto-detects an available warehouse unless DATABRICKS_WAREHOUSE_ID is set.
43+
The command auto-detects an available warehouse unless --warehouse is set
44+
or the DATABRICKS_WAREHOUSE_ID environment variable is configured.
2545
2646
Output includes the query results as JSON and row count.`,
27-
Example: ` databricks experimental aitools query "SELECT * FROM samples.nyctaxi.trips LIMIT 5"`,
47+
Example: ` databricks experimental aitools tools query "SELECT * FROM samples.nyctaxi.trips LIMIT 5"
48+
databricks experimental aitools tools query --warehouse abc123 "SELECT 1"`,
2849
Args: cobra.ExactArgs(1),
2950
PreRunE: root.MustWorkspaceClient,
3051
RunE: func(cmd *cobra.Command, args []string) error {
@@ -36,31 +57,14 @@ Output includes the query results as JSON and row count.`,
3657
return errors.New("SQL statement is required")
3758
}
3859

39-
// set up session with client for middleware compatibility
40-
sess := session.NewSession()
41-
sess.Set(middlewares.DatabricksClientKey, w)
42-
ctx = session.WithSession(ctx, sess)
43-
44-
warehouseID, err := middlewares.GetWarehouseID(ctx, true)
60+
wID, err := resolveWarehouseID(ctx, w, warehouseID)
4561
if err != nil {
4662
return err
4763
}
4864

49-
resp, err := w.StatementExecution.ExecuteAndWait(ctx, sql.ExecuteStatementRequest{
50-
WarehouseId: warehouseID,
51-
Statement: sqlStatement,
52-
WaitTimeout: "50s",
53-
})
65+
resp, err := executeAndPoll(ctx, w.StatementExecution, wID, sqlStatement)
5466
if err != nil {
55-
return fmt.Errorf("execute statement: %w", err)
56-
}
57-
58-
if resp.Status != nil && resp.Status.State == sql.StatementStateFailed {
59-
errMsg := "query failed"
60-
if resp.Status.Error != nil {
61-
errMsg = resp.Status.Error.Message
62-
}
63-
return errors.New(errMsg)
67+
return err
6468
}
6569

6670
output, err := formatQueryResult(resp)
@@ -73,13 +77,178 @@ Output includes the query results as JSON and row count.`,
7377
},
7478
}
7579

80+
cmd.Flags().StringVarP(&warehouseID, "warehouse", "w", "", "SQL warehouse ID to use for execution")
81+
7682
return cmd
7783
}
7884

85+
// resolveWarehouseID returns the warehouse ID to use for query execution.
86+
// Priority: explicit flag > middleware auto-detection (env var > server default > first running).
87+
func resolveWarehouseID(ctx context.Context, w any, flagValue string) (string, error) {
88+
if flagValue != "" {
89+
return flagValue, nil
90+
}
91+
92+
sess := session.NewSession()
93+
sess.Set(middlewares.DatabricksClientKey, w)
94+
ctx = session.WithSession(ctx, sess)
95+
96+
return middlewares.GetWarehouseID(ctx, true)
97+
}
98+
99+
// executeAndPoll submits a SQL statement asynchronously and polls until completion.
100+
// It shows a spinner in interactive mode and supports Ctrl+C cancellation.
101+
func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, statement string) (*sql.StatementResponse, error) {
102+
// Submit asynchronously to get the statement ID immediately for cancellation.
103+
resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{
104+
WarehouseId: warehouseID,
105+
Statement: statement,
106+
WaitTimeout: "0s",
107+
})
108+
if err != nil {
109+
return nil, fmt.Errorf("execute statement: %w", err)
110+
}
111+
112+
statementID := resp.StatementId
113+
114+
// Check if it completed immediately.
115+
if isTerminalState(resp.Status) {
116+
return resp, checkFailedState(resp.Status)
117+
}
118+
119+
// Set up Ctrl+C: signal cancels the poll context, cleanup is unified below.
120+
pollCtx, pollCancel := context.WithCancel(ctx)
121+
defer pollCancel()
122+
123+
sigCh := make(chan os.Signal, 1)
124+
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
125+
defer signal.Stop(sigCh)
126+
127+
go func() {
128+
select {
129+
case <-sigCh:
130+
log.Infof(ctx, "Received interrupt, cancelling query %s", statementID)
131+
pollCancel()
132+
case <-pollCtx.Done():
133+
}
134+
}()
135+
136+
// cancelStatement performs best-effort server-side cancellation.
137+
// Called on any poll exit due to context cancellation (signal or parent).
138+
cancelStatement := func() {
139+
cancelCtx, cancel := context.WithTimeout(context.Background(), cancelTimeout)
140+
defer cancel()
141+
if err := api.CancelExecution(cancelCtx, sql.CancelExecutionRequest{
142+
StatementId: statementID,
143+
}); err != nil {
144+
log.Warnf(ctx, "Failed to cancel statement %s: %v", statementID, err)
145+
}
146+
}
147+
148+
// Spinner for interactive feedback, updated every second via ticker.
149+
sp := cmdio.NewSpinner(pollCtx)
150+
defer sp.Close()
151+
start := time.Now()
152+
sp.Update("Executing query...")
153+
154+
ticker := time.NewTicker(1 * time.Second)
155+
defer ticker.Stop()
156+
go func() {
157+
for {
158+
select {
159+
case <-pollCtx.Done():
160+
return
161+
case <-ticker.C:
162+
elapsed := time.Since(start).Truncate(time.Second)
163+
sp.Update(fmt.Sprintf("Executing query... (%s elapsed)", elapsed))
164+
}
165+
}
166+
}()
167+
168+
// Poll with additive backoff: 1s, 2s, 3s, 4s, 5s (capped).
169+
interval := pollIntervalInitial
170+
for {
171+
select {
172+
case <-pollCtx.Done():
173+
cancelStatement()
174+
cmdio.LogString(ctx, "Query cancelled.")
175+
return nil, root.ErrAlreadyPrinted
176+
case <-time.After(interval):
177+
}
178+
179+
log.Debugf(ctx, "Polling statement %s: %s elapsed", statementID, time.Since(start).Truncate(time.Second))
180+
181+
pollResp, err := api.GetStatementByStatementId(pollCtx, statementID)
182+
if err != nil {
183+
if pollCtx.Err() != nil {
184+
cancelStatement()
185+
cmdio.LogString(ctx, "Query cancelled.")
186+
return nil, root.ErrAlreadyPrinted
187+
}
188+
return nil, fmt.Errorf("poll statement status: %w", err)
189+
}
190+
191+
if isTerminalState(pollResp.Status) {
192+
sp.Close()
193+
if err := checkFailedState(pollResp.Status); err != nil {
194+
return nil, err
195+
}
196+
return &sql.StatementResponse{
197+
StatementId: pollResp.StatementId,
198+
Status: pollResp.Status,
199+
Manifest: pollResp.Manifest,
200+
Result: pollResp.Result,
201+
}, nil
202+
}
203+
204+
interval = min(interval+time.Second, pollIntervalMax)
205+
}
206+
}
207+
208+
// isTerminalState returns true if the statement has reached a final state.
209+
func isTerminalState(status *sql.StatementStatus) bool {
210+
if status == nil {
211+
return false
212+
}
213+
switch status.State {
214+
case sql.StatementStateSucceeded, sql.StatementStateFailed,
215+
sql.StatementStateCanceled, sql.StatementStateClosed:
216+
return true
217+
case sql.StatementStatePending, sql.StatementStateRunning:
218+
return false
219+
}
220+
return false
221+
}
222+
223+
// checkFailedState returns an error if the statement is in a non-success terminal state.
224+
func checkFailedState(status *sql.StatementStatus) error {
225+
if status == nil {
226+
return nil
227+
}
228+
switch status.State {
229+
case sql.StatementStateFailed:
230+
msg := "query failed"
231+
if status.Error != nil {
232+
msg = fmt.Sprintf("query failed: %s %s", status.Error.ErrorCode, status.Error.Message)
233+
if strings.Contains(status.Error.Message, "UNRESOLVED_MAP_KEY") {
234+
msg += "\n\nHint: your shell may have stripped quotes from the SQL string. " +
235+
"Use single quotes for map keys (e.g. info['key']) or pass the query via --file."
236+
}
237+
}
238+
return errors.New(msg)
239+
case sql.StatementStateCanceled:
240+
return errors.New("query was cancelled")
241+
case sql.StatementStateClosed:
242+
return errors.New("query was closed before results could be fetched")
243+
case sql.StatementStatePending, sql.StatementStateRunning, sql.StatementStateSucceeded:
244+
return nil
245+
}
246+
return nil
247+
}
248+
79249
// cleanSQL removes surrounding quotes, empty lines, and SQL comments.
80250
func cleanSQL(s string) string {
81251
s = strings.TrimSpace(s)
82-
// remove surrounding quotes if present
83252
if (strings.HasPrefix(s, `"`) && strings.HasSuffix(s, `"`)) ||
84253
(strings.HasPrefix(s, `'`) && strings.HasSuffix(s, `'`)) {
85254
s = s[1 : len(s)-1]
@@ -88,12 +257,12 @@ func cleanSQL(s string) string {
88257
var lines []string
89258
for _, line := range strings.Split(s, "\n") {
90259
line = strings.TrimSpace(line)
91-
// skip empty lines and single-line comments
92260
if line == "" || strings.HasPrefix(line, "--") {
93261
continue
94262
}
95263
lines = append(lines, line)
96264
}
265+
97266
return strings.Join(lines, "\n")
98267
}
99268

@@ -105,15 +274,13 @@ func formatQueryResult(resp *sql.StatementResponse) (string, error) {
105274
return sb.String(), nil
106275
}
107276

108-
// get column names
109277
var columns []string
110278
if resp.Manifest.Schema != nil {
111279
for _, col := range resp.Manifest.Schema.Columns {
112280
columns = append(columns, col.Name)
113281
}
114282
}
115283

116-
// format as JSON array for consistency with Neon API
117284
var rows []map[string]any
118285
if resp.Result.DataArray != nil {
119286
for _, row := range resp.Result.DataArray {

0 commit comments

Comments
 (0)