Skip to content

Commit 737a1ce

Browse files
authored
SSH: initialize Databricks Connect session in serverless mode (#4590)
## Changes

Initialize the Spark session with Databricks Connect on serverless compute, instead of using `user_namespace_initializer` (which doesn't work there).

## Why

A Databricks Connect session ensures that all Spark workloads go through the Spark Connect router and end up on interactive Spark compute.

## Tests

Existing tests, plus manual testing on both dedicated clusters and SGC.
1 parent 439c5de commit 737a1ce

File tree

6 files changed

+72
-20
lines changed

6 files changed

+72
-20
lines changed

experimental/ssh/cmd/server.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ and proxies them to local SSH daemon processes.
3030
var version string
3131
var secretScopeName string
3232
var authorizedKeySecretName string
33+
var serverless bool
3334

3435
cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID")
3536
cmd.MarkFlagRequired("cluster")
@@ -43,6 +44,7 @@ and proxies them to local SSH daemon processes.
4344
cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients")
4445
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down after no pings from clients")
4546
cmd.Flags().StringVar(&version, "version", "", "Client version of the Databricks CLI")
47+
cmd.Flags().BoolVar(&serverless, "serverless", false, "Enable serverless mode for Jupyter initialization")
4648

4749
cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
4850
// The server can be executed under a directory with an invalid bundle configuration.
@@ -70,6 +72,7 @@ and proxies them to local SSH daemon processes.
7072
AuthorizedKeySecretName: authorizedKeySecretName,
7173
DefaultPort: defaultServerPort,
7274
PortRange: serverPortRange,
75+
Serverless: serverless,
7376
}
7477
return server.Run(ctx, wsc, opts)
7578
}

experimental/ssh/internal/client/client.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
452452
"shutdownDelay": opts.ShutdownDelay.String(),
453453
"maxClients": strconv.Itoa(opts.MaxClients),
454454
"sessionId": sessionID,
455+
"serverless": strconv.FormatBool(opts.IsServerlessMode()),
455456
}
456457

457458
cmdio.LogString(ctx, "Submitting a job to start the ssh server...")

experimental/ssh/internal/client/ssh-server-bootstrap.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
dbutils.widgets.text("authorizedKeySecretName", "")
1818
dbutils.widgets.text("maxClients", "10")
1919
dbutils.widgets.text("shutdownDelay", "10m")
20-
dbutils.widgets.text("sessionId", "") # Required: unique identifier for the session
20+
dbutils.widgets.text("sessionId", "")
21+
dbutils.widgets.text("serverless", "false")
2122

2223

2324
def cleanup():
@@ -115,6 +116,7 @@ def run_ssh_server():
115116
session_id = dbutils.widgets.get("sessionId")
116117
if not session_id:
117118
raise RuntimeError("Session ID is required. Please provide it using the 'sessionId' widget.")
119+
serverless = dbutils.widgets.get("serverless")
118120

119121
arch = platform.machine()
120122
if arch == "x86_64":
@@ -137,6 +139,7 @@ def run_ssh_server():
137139
"server",
138140
f"--cluster={ctx.clusterId}",
139141
f"--session-id={session_id}",
142+
f"--serverless={serverless}",
140143
f"--secret-scope-name={secrets_scope}",
141144
f"--authorized-key-secret-name={public_key_secret_name}",
142145
f"--max-clients={max_clients}",

experimental/ssh/internal/server/jupyter-init.py

Lines changed: 59 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import List, Optional
22
from IPython.core.getipython import get_ipython
33
from IPython.display import display as ip_display
4-
from dbruntime import UserNamespaceInitializer
4+
import os
55

66

77
def _log_exceptions(func):
@@ -18,18 +18,21 @@ def wrapper(*args, **kwargs):
1818
return wrapper
1919

2020

21-
_user_namespace_initializer = UserNamespaceInitializer.getOrCreate()
22-
_entry_point = _user_namespace_initializer.get_spark_entry_point()
23-
_globals = _user_namespace_initializer.get_namespace_globals()
24-
for name, value in _globals.items():
25-
print(f"Registering global: {name} = {value}")
26-
if name not in globals():
27-
globals()[name] = value
21+
@_log_exceptions
22+
def _setup_dedicated_session():
23+
from dbruntime import UserNamespaceInitializer
2824

25+
_user_namespace_initializer = UserNamespaceInitializer.getOrCreate()
26+
_entry_point = _user_namespace_initializer.get_spark_entry_point()
27+
_globals = _user_namespace_initializer.get_namespace_globals()
28+
for name, value in _globals.items():
29+
print(f"Registering global: {name} = {value}")
30+
if name not in globals():
31+
globals()[name] = value
2932

30-
# 'display' from the runtime uses custom widgets that don't work in Jupyter.
31-
# We use the IPython display instead (in combination with the html formatter for DataFrames).
32-
globals()["display"] = ip_display
33+
# 'display' from the runtime uses custom widgets that don't work in Jupyter.
34+
# We use the IPython display instead (in combination with the html formatter for DataFrames).
35+
globals()["display"] = ip_display
3336

3437

3538
@_log_exceptions
@@ -157,19 +160,28 @@ def _parse_line_for_databricks_magics(lines: List[str]) -> List[str]:
157160

158161

159162
@_log_exceptions
160-
def _register_magics():
161-
"""Register the magic command parser with IPython."""
163+
def _register_common_magics():
164+
"""Register the common magic command parser with IPython."""
165+
ip = get_ipython()
166+
ip.input_transformers_cleanup.append(_parse_line_for_databricks_magics)
167+
168+
169+
@_log_exceptions
170+
def _register_pip_magics(user_namespace_initializer: any, entry_point: any):
171+
"""Register the pip magic command parser with IPython."""
162172
from dbruntime.DatasetInfo import UserNamespaceDict
163173
from dbruntime.PipMagicOverrides import PipMagicOverrides
174+
from dbruntime import UserNamespaceInitializer
164175

176+
user_namespace_initializer = UserNamespaceInitializer.getOrCreate()
177+
entry_point = user_namespace_initializer.get_spark_entry_point()
165178
user_ns = UserNamespaceDict(
166-
_user_namespace_initializer.get_namespace_globals(),
167-
_entry_point.getDriverConf(),
168-
_entry_point,
179+
user_namespace_initializer.get_namespace_globals(),
180+
entry_point.getDriverConf(),
181+
entry_point,
169182
)
170183
ip = get_ipython()
171-
ip.input_transformers_cleanup.append(_parse_line_for_databricks_magics)
172-
ip.register_magics(PipMagicOverrides(_entry_point, _globals["sc"]._conf, user_ns))
184+
ip.register_magics(PipMagicOverrides(entry_point, globals["sc"]._conf, user_ns))
173185

174186

175187
@_log_exceptions
@@ -186,6 +198,34 @@ def df_html(df: DataFrame) -> str:
186198
html_formatter.for_type(DataFrame, df_html)
187199

188200

189-
_register_magics()
201+
@_log_exceptions
202+
def _setup_serverless_session():
203+
import IPython
204+
from databricks.connect import DatabricksSession
205+
206+
user_ns = getattr(IPython.get_ipython(), "user_ns", {})
207+
existing_session = getattr(user_ns, "spark", None)
208+
try:
209+
# Clear the existing local spark session, otherwise DatabricksSession will re-use it.
210+
user_ns["spark"] = None
211+
globals()["spark"] = None
212+
# DatabricksSession will use the existing env vars for the connection.
213+
spark_session = DatabricksSession.builder.serverless(True).getOrCreate()
214+
user_ns["spark"] = spark_session
215+
globals()["spark"] = spark_session
216+
except Exception as e:
217+
user_ns["spark"] = existing_session
218+
globals()["spark"] = existing_session
219+
raise e
220+
221+
222+
if os.environ.get("DATABRICKS_JUPYTER_SERVERLESS") == "true":
223+
_setup_serverless_session()
224+
else:
225+
_setup_dedicated_session()
226+
_register_pip_magics()
227+
228+
229+
_register_common_magics()
190230
_register_formatters()
191231
_register_runtime_hooks()

experimental/ssh/internal/server/server.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ type ServerOptions struct {
3434
// SessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless).
3535
// Used for metadata storage path. Defaults to ClusterID if not set.
3636
SessionID string
37+
// Serverless indicates whether the server is running on serverless compute.
38+
Serverless bool
3739
// The directory to store sshd configuration
3840
ConfigDir string
3941
// The name of the secrets scope to use for client and server keys

experimental/ssh/internal/server/sshd.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ func prepareSSHDConfig(ctx context.Context, client *databricks.WorkspaceClient,
6767
setEnv += " GIT_CONFIG_GLOBAL=/Workspace/.proc/self/git/config"
6868
setEnv += " ENABLE_DATABRICKS_CLI=true"
6969
setEnv += " PYTHONPYCACHEPREFIX=/tmp/pycache"
70+
if opts.Serverless {
71+
setEnv += " DATABRICKS_JUPYTER_SERVERLESS=true"
72+
}
7073

7174
sshdConfigContent := "PubkeyAuthentication yes\n" +
7275
"PasswordAuthentication no\n" +

0 commit comments

Comments
 (0)