Skip to content

Commit 299af6c

Browse files
authored
Respect user's default warehouse override in aitools (#4614)
## Why The `aitools tools get-default-warehouse` command ignores the user's default warehouse override (a per-user preference set via the SQL UI or `databricks warehouses create-default-warehouse-override`). Users who configured a preferred warehouse still get a different one. Additionally, commands like `query` and `discover-schema` fail when the resolved warehouse is stopped. ## Changes Before: the command checked the `DATABRICKS_WAREHOUSE_ID` env var, then fell through to server-side default detection (which does not incorporate user overrides). Stopped warehouses were returned as-is. Now: - After the env var check, the command queries the user's default warehouse override via `GetDefaultWarehouseOverride("default-warehouse-overrides/me")`. If a `CUSTOM` override exists with a valid, non-deleted warehouse, that warehouse is used. All errors silently fall through to existing behavior. `LAST_SELECTED` overrides are skipped since they require UI state not available from the CLI. - `GetWarehouseEndpoint` and `GetWarehouseID` accept an `autoStart` parameter. When true, a stopped warehouse is started and the call blocks until it reaches RUNNING. Enabled for `query` and `discover-schema`, disabled for `get-default-warehouse` and `discover`. New resolution priority: 1. `DATABRICKS_WAREHOUSE_ID` env var 2. User's default warehouse override (`CUSTOM` type only) 3. Server-side "default" warehouse 4. First usable warehouse by state ## Test plan - [x] `make test-exp-aitools` passes (96 unit tests + 25 acceptance tests) - [x] Manual: `get-default-warehouse` returns existing default when no override is set - [x] Manual: set override with `create-default-warehouse-override me CUSTOM --warehouse-id <id>`, verified `get-default-warehouse` returns the overridden warehouse - [x] Manual: deleted override, verified fallback to original behavior
1 parent d00476d commit 299af6c

File tree

5 files changed

+68
-8
lines changed

5 files changed

+68
-8
lines changed

experimental/aitools/cmd/discover_schema.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ For each table, returns:
5050
sess.Set(middlewares.DatabricksClientKey, w)
5151
ctx = session.WithSession(ctx, sess)
5252

53-
warehouseID, err := middlewares.GetWarehouseID(ctx)
53+
warehouseID, err := middlewares.GetWarehouseID(ctx, true)
5454
if err != nil {
5555
return err
5656
}

experimental/aitools/cmd/get_default_warehouse.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ Returns warehouse ID of the default warehouse. Use --output json to get the full
4343
sess.Set(middlewares.DatabricksClientKey, w)
4444
ctx = session.WithSession(ctx, sess)
4545

46-
warehouse, err := middlewares.GetWarehouseEndpoint(ctx)
46+
warehouse, err := middlewares.GetWarehouseEndpoint(ctx, false)
4747
if err != nil {
4848
return err
4949
}

experimental/aitools/cmd/query.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ Output includes the query results as JSON and row count.`,
4141
sess.Set(middlewares.DatabricksClientKey, w)
4242
ctx = session.WithSession(ctx, sess)
4343

44-
warehouseID, err := middlewares.GetWarehouseID(ctx)
44+
warehouseID, err := middlewares.GetWarehouseID(ctx, true)
4545
if err != nil {
4646
return err
4747
}

experimental/aitools/lib/middlewares/warehouse.go

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/databricks/cli/experimental/aitools/lib/session"
99
"github.com/databricks/cli/libs/databrickscfg/cfgpickers"
1010
"github.com/databricks/cli/libs/env"
11+
"github.com/databricks/databricks-sdk-go"
1112
"github.com/databricks/databricks-sdk-go/service/sql"
1213
)
1314

@@ -39,7 +40,9 @@ func loadWarehouseInBackground(ctx context.Context) {
3940
sess.Set("warehouse_endpoint", warehouse)
4041
}
4142

42-
func GetWarehouseEndpoint(ctx context.Context) (*sql.EndpointInfo, error) {
43+
// GetWarehouseEndpoint returns the resolved warehouse endpoint.
44+
// If autoStart is true and the warehouse is stopped, it will be started automatically.
45+
func GetWarehouseEndpoint(ctx context.Context, autoStart bool) (*sql.EndpointInfo, error) {
4346
sess, err := session.GetSession(ctx)
4447
if err != nil {
4548
return nil, err
@@ -68,23 +71,62 @@ func GetWarehouseEndpoint(ctx context.Context) (*sql.EndpointInfo, error) {
6871
sess.Set("warehouse_endpoint", warehouse)
6972
}
7073

71-
return warehouse.(*sql.EndpointInfo), nil
74+
endpoint := warehouse.(*sql.EndpointInfo)
75+
76+
if autoStart && (endpoint.State == sql.StateStopped || endpoint.State == sql.StateStopping) {
77+
endpoint, err = startWarehouse(ctx, endpoint.Id)
78+
if err != nil {
79+
return nil, err
80+
}
81+
sess.Set("warehouse_endpoint", endpoint)
82+
}
83+
84+
return endpoint, nil
7285
}
7386

74-
func GetWarehouseID(ctx context.Context) (string, error) {
75-
warehouse, err := GetWarehouseEndpoint(ctx)
87+
// GetWarehouseID returns the resolved warehouse ID.
88+
// If autoStart is true and the warehouse is stopped, it will be started automatically.
89+
func GetWarehouseID(ctx context.Context, autoStart bool) (string, error) {
90+
warehouse, err := GetWarehouseEndpoint(ctx, autoStart)
7691
if err != nil {
7792
return "", err
7893
}
7994
return warehouse.Id, nil
8095
}
8196

97+
func startWarehouse(ctx context.Context, id string) (*sql.EndpointInfo, error) {
98+
w, err := GetDatabricksClient(ctx)
99+
if err != nil {
100+
return nil, fmt.Errorf("get databricks client: %w", err)
101+
}
102+
wait, err := w.Warehouses.Start(ctx, sql.StartRequest{Id: id})
103+
if err != nil {
104+
return nil, fmt.Errorf("start warehouse %s: %w", id, err)
105+
}
106+
resp, err := wait.Get()
107+
if err != nil {
108+
return nil, fmt.Errorf("wait for warehouse %s to start: %w", id, err)
109+
}
110+
return &sql.EndpointInfo{
111+
Id: resp.Id,
112+
Name: resp.Name,
113+
State: resp.State,
114+
}, nil
115+
}
116+
82117
func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) {
83118
w, err := GetDatabricksClient(ctx)
84119
if err != nil {
85120
return nil, fmt.Errorf("get databricks client: %w", err)
86121
}
122+
return resolveWarehouse(ctx, w)
123+
}
87124

125+
// resolveWarehouse selects a warehouse using the following priority:
126+
// 1. DATABRICKS_WAREHOUSE_ID env var
127+
// 2. User's default warehouse override (CUSTOM type only)
128+
// 3. Server-side default / first usable warehouse by state
129+
func resolveWarehouse(ctx context.Context, w *databricks.WorkspaceClient) (*sql.EndpointInfo, error) {
88130
// first resolve DATABRICKS_WAREHOUSE_ID env variable
89131
warehouseID := env.Get(ctx, "DATABRICKS_WAREHOUSE_ID")
90132
if warehouseID != "" {
@@ -101,5 +143,23 @@ func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) {
101143
}, nil
102144
}
103145

146+
// Check user's default warehouse override (set via the SQL UI or CLI).
147+
// Only CUSTOM overrides are used; LAST_SELECTED requires UI state we don't have.
148+
override, err := w.Warehouses.GetDefaultWarehouseOverride(ctx, sql.GetDefaultWarehouseOverrideRequest{
149+
Name: "default-warehouse-overrides/me",
150+
})
151+
if err == nil && override.Type == sql.DefaultWarehouseOverrideTypeCustom && override.WarehouseId != "" {
152+
warehouse, err := w.Warehouses.Get(ctx, sql.GetWarehouseRequest{
153+
Id: override.WarehouseId,
154+
})
155+
if err == nil && warehouse.State != sql.StateDeleted && warehouse.State != sql.StateDeleting {
156+
return &sql.EndpointInfo{
157+
Id: warehouse.Id,
158+
Name: warehouse.Name,
159+
State: warehouse.State,
160+
}, nil
161+
}
162+
}
163+
104164
return cfgpickers.GetDefaultWarehouse(ctx, w)
105165
}

experimental/aitools/lib/providers/clitools/discover.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import (
1616
// Discover provides workspace context and workflow guidance.
1717
// Returns L1 (flow) always + L2 (target) for detected target types + L3 (skills) listing.
1818
func Discover(ctx context.Context, workingDirectory string) (string, error) {
19-
warehouse, err := middlewares.GetWarehouseEndpoint(ctx)
19+
warehouse, err := middlewares.GetWarehouseEndpoint(ctx, false)
2020
if err != nil {
2121
log.Debugf(ctx, "Failed to get default warehouse (non-fatal): %v", err)
2222
warehouse = nil

0 commit comments

Comments
 (0)