11from typing import List , Optional
22from IPython .core .getipython import get_ipython
33from IPython .display import display as ip_display
4- from dbruntime import UserNamespaceInitializer
4+ import os
55
66
77def _log_exceptions (func ):
@@ -18,18 +18,21 @@ def wrapper(*args, **kwargs):
1818 return wrapper
1919
2020
21- _user_namespace_initializer = UserNamespaceInitializer .getOrCreate ()
22- _entry_point = _user_namespace_initializer .get_spark_entry_point ()
23- _globals = _user_namespace_initializer .get_namespace_globals ()
24- for name , value in _globals .items ():
25- print (f"Registering global: { name } = { value } " )
26- if name not in globals ():
27- globals ()[name ] = value
21+ @_log_exceptions
22+ def _setup_dedicated_session ():
23+ from dbruntime import UserNamespaceInitializer
2824
25+ _user_namespace_initializer = UserNamespaceInitializer .getOrCreate ()
26+ _entry_point = _user_namespace_initializer .get_spark_entry_point ()
27+ _globals = _user_namespace_initializer .get_namespace_globals ()
28+ for name , value in _globals .items ():
29+ print (f"Registering global: { name } = { value } " )
30+ if name not in globals ():
31+ globals ()[name ] = value
2932
30- # 'display' from the runtime uses custom widgets that don't work in Jupyter.
31- # We use the IPython display instead (in combination with the html formatter for DataFrames).
32- globals ()["display" ] = ip_display
33+ # 'display' from the runtime uses custom widgets that don't work in Jupyter.
34+ # We use the IPython display instead (in combination with the html formatter for DataFrames).
35+ globals ()["display" ] = ip_display
3336
3437
3538@_log_exceptions
@@ -157,19 +160,28 @@ def _parse_line_for_databricks_magics(lines: List[str]) -> List[str]:
157160
158161
159162@_log_exceptions
160- def _register_magics ():
161- """Register the magic command parser with IPython."""
163+ def _register_common_magics ():
164+ """Register the common magic command parser with IPython."""
165+ ip = get_ipython ()
166+ ip .input_transformers_cleanup .append (_parse_line_for_databricks_magics )
167+
168+
169+ @_log_exceptions
170+ def _register_pip_magics (user_namespace_initializer : any , entry_point : any ):
171+ """Register the pip magic command parser with IPython."""
162172 from dbruntime .DatasetInfo import UserNamespaceDict
163173 from dbruntime .PipMagicOverrides import PipMagicOverrides
174+ from dbruntime import UserNamespaceInitializer
164175
176+ user_namespace_initializer = UserNamespaceInitializer .getOrCreate ()
177+ entry_point = user_namespace_initializer .get_spark_entry_point ()
165178 user_ns = UserNamespaceDict (
166- _user_namespace_initializer .get_namespace_globals (),
167- _entry_point .getDriverConf (),
168- _entry_point ,
179+ user_namespace_initializer .get_namespace_globals (),
180+ entry_point .getDriverConf (),
181+ entry_point ,
169182 )
170183 ip = get_ipython ()
171- ip .input_transformers_cleanup .append (_parse_line_for_databricks_magics )
172- ip .register_magics (PipMagicOverrides (_entry_point , _globals ["sc" ]._conf , user_ns ))
184+ ip .register_magics (PipMagicOverrides (entry_point , globals ["sc" ]._conf , user_ns ))
173185
174186
175187@_log_exceptions
@@ -186,6 +198,34 @@ def df_html(df: DataFrame) -> str:
186198 html_formatter .for_type (DataFrame , df_html )
187199
188200
189- _register_magics ()
201+ @_log_exceptions
202+ def _setup_serverless_session ():
203+ import IPython
204+ from databricks .connect import DatabricksSession
205+
206+ user_ns = getattr (IPython .get_ipython (), "user_ns" , {})
207+ existing_session = getattr (user_ns , "spark" , None )
208+ try :
209+ # Clear the existing local spark session, otherwise DatabricksSession will re-use it.
210+ user_ns ["spark" ] = None
211+ globals ()["spark" ] = None
212+ # DatabricksSession will use the existing env vars for the connection.
213+ spark_session = DatabricksSession .builder .serverless (True ).getOrCreate ()
214+ user_ns ["spark" ] = spark_session
215+ globals ()["spark" ] = spark_session
216+ except Exception as e :
217+ user_ns ["spark" ] = existing_session
218+ globals ()["spark" ] = existing_session
219+ raise e
220+
221+
222+ if os .environ .get ("DATABRICKS_JUPYTER_SERVERLESS" ) == "true" :
223+ _setup_serverless_session ()
224+ else :
225+ _setup_dedicated_session ()
226+ _register_pip_magics ()
227+
228+
229+ _register_common_magics ()
190230_register_formatters ()
191231_register_runtime_hooks ()
0 commit comments