Skip to content

Commit 6357294

Browse files
authored
SSH: refactor Jupyter init script to use better Spark session initialization flow (#4645)
## Changes * Always use UserNamespaceInitializer to initialize global jupyter variables (sql, table, display, etc), even in serverless context. * Always use Databricks Connect to initialize spark, even on dedicated clusters (where we point it to the cluster itself). ## Why We've had a different code path for serverless spark before, where we didn't use UserNamespaceInitializer, which was not ideal, since the global jupyter scope for serverless and dedicated was different because of that. The reason to use Databricks Connect on dedicated clusters is to expose the same spark connect API in all environments, avoiding compatibility issues. Local spark has access to some internal jvm APIs, which are not available in spark connect mode, but the rest is the same. ## Tests Manually and existing e2e tests
1 parent 2b857d5 commit 6357294

File tree

1 file changed

+62
-39
lines changed

1 file changed

+62
-39
lines changed

experimental/ssh/internal/server/jupyter-init.py

Lines changed: 62 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,6 @@ def wrapper(*args, **kwargs):
1818
return wrapper
1919

2020

21-
@_log_exceptions
22-
def _setup_dedicated_session():
23-
from dbruntime import UserNamespaceInitializer
24-
25-
_user_namespace_initializer = UserNamespaceInitializer.getOrCreate()
26-
_entry_point = _user_namespace_initializer.get_spark_entry_point()
27-
_globals = _user_namespace_initializer.get_namespace_globals()
28-
for name, value in _globals.items():
29-
print(f"Registering global: {name} = {value}")
30-
if name not in globals():
31-
globals()[name] = value
32-
33-
# 'display' from the runtime uses custom widgets that don't work in Jupyter.
34-
# We use the IPython display instead (in combination with the html formatter for DataFrames).
35-
globals()["display"] = ip_display
36-
37-
3821
@_log_exceptions
3922
def _register_runtime_hooks():
4023
from dbruntime.monkey_patches import apply_dataframe_display_patch
@@ -167,7 +150,7 @@ def _register_common_magics():
167150

168151

169152
@_log_exceptions
170-
def _register_pip_magics(user_namespace_initializer: any, entry_point: any):
153+
def _register_pip_magics():
171154
"""Register the pip magic command parser with IPython."""
172155
from dbruntime.DatasetInfo import UserNamespaceDict
173156
from dbruntime.PipMagicOverrides import PipMagicOverrides
@@ -181,7 +164,15 @@ def _register_pip_magics(user_namespace_initializer: any, entry_point: any):
181164
entry_point,
182165
)
183166
ip = get_ipython()
184-
ip.register_magics(PipMagicOverrides(entry_point, globals["sc"]._conf, user_ns))
167+
168+
try:
169+
# Older DBRs
170+
pip_magic = PipMagicOverrides(entry_point, ip.user_ns["sc"]._conf, user_ns)
171+
except Exception:
172+
# Newer DBRs
173+
pip_magic = PipMagicOverrides(entry_point, user_ns, ip)
174+
175+
ip.register_magics(pip_magic)
185176

186177

187178
@_log_exceptions
@@ -198,34 +189,66 @@ def df_html(df: DataFrame) -> str:
198189
html_formatter.for_type(DataFrame, df_html)
199190

200191

201-
@_log_exceptions
202-
def _setup_serverless_session():
203-
import IPython
192+
def _create_spark_session(builder_fn):
204193
from databricks.connect import DatabricksSession
205194

206-
user_ns = getattr(IPython.get_ipython(), "user_ns", {})
207-
existing_session = getattr(user_ns, "spark", None)
195+
user_ns = get_ipython().user_ns
196+
existing_session = user_ns.get("spark")
197+
# Clear the existing local spark session, otherwise DatabricksSession will re-use it.
198+
user_ns["spark"] = None
208199
try:
209-
# Clear the existing local spark session, otherwise DatabricksSession will re-use it.
210-
user_ns["spark"] = None
211-
globals()["spark"] = None
212-
# DatabricksSession will use the existing env vars for the connection.
213-
spark_session = DatabricksSession.builder.serverless(True).getOrCreate()
214-
user_ns["spark"] = spark_session
215-
globals()["spark"] = spark_session
216-
except Exception as e:
200+
return builder_fn(DatabricksSession.builder).getOrCreate()
201+
except Exception:
217202
user_ns["spark"] = existing_session
218-
globals()["spark"] = existing_session
219-
raise e
203+
raise
204+
205+
206+
def _initialize_spark(is_serverless: bool, existing_spark: any):
207+
from pyspark.sql.session import SparkSession
208+
209+
# On serverless always initialize a new remote Databricks Connect session.
210+
if is_serverless:
211+
return _create_spark_session(lambda b: b.serverless(True))
212+
# On dedicated or standard initialize a new remote session if the existing spark session is local.
213+
if existing_spark is None or isinstance(existing_spark, SparkSession):
214+
return _create_spark_session(
215+
lambda b: b.remote(
216+
host=os.environ["DATABRICKS_HOST"],
217+
token=os.environ["DATABRICKS_TOKEN"],
218+
cluster_id=os.environ["DATABRICKS_CLUSTER_ID"],
219+
)
220+
)
221+
# Otherwise re-use the existing remote session.
222+
return existing_spark
220223

221224

222-
if os.environ.get("DATABRICKS_JUPYTER_SERVERLESS") == "true":
223-
_setup_serverless_session()
224-
else:
225-
_setup_dedicated_session()
226-
_register_pip_magics()
225+
@_log_exceptions
226+
def _setup_globals(is_serverless: bool):
227+
from dbruntime import UserNamespaceInitializer
228+
229+
ns = UserNamespaceInitializer.getOrCreate()
230+
ns_globals = ns.get_namespace_globals()
231+
existing_spark = ns_globals.get("spark")
232+
spark = _initialize_spark(is_serverless, existing_spark)
233+
try:
234+
ns.db_connection.spark_provider.set_spark(spark)
235+
except Exception as e:
236+
print(f"Error updating spark provider: {e}")
237+
ns_globals["spark"] = spark
238+
if spark is not None:
239+
ns_globals["table"] = spark.table
240+
ns_globals["sql"] = spark.sql
241+
user_ns = get_ipython().user_ns
242+
for name, value in ns_globals.items():
243+
print(f"Registering global: {name} = {value}")
244+
user_ns[name] = value
245+
# 'display' from the runtime uses custom widgets that don't work in Jupyter.
246+
# We use the IPython display instead (in combination with the html formatter for DataFrames).
247+
user_ns["display"] = ip_display
227248

228249

250+
_setup_globals(os.environ.get("DATABRICKS_JUPYTER_SERVERLESS") == "true")
251+
_register_pip_magics()
229252
_register_common_magics()
230253
_register_formatters()
231254
_register_runtime_hooks()

0 commit comments

Comments
 (0)