@@ -18,23 +18,6 @@ def wrapper(*args, **kwargs):
1818 return wrapper
1919
2020
@_log_exceptions
def _setup_dedicated_session():
    """Copy the runtime's namespace globals into this module's globals.

    Names that already exist in the module globals are left untouched, except
    for ``display``, which is always rebound to the IPython display function.
    """
    from dbruntime import UserNamespaceInitializer

    initializer = UserNamespaceInitializer.getOrCreate()
    # The entry point is not used here; the call is kept in case it has
    # initialization side effects inside the runtime.
    initializer.get_spark_entry_point()

    module_ns = globals()
    for key, obj in initializer.get_namespace_globals().items():
        print(f"Registering global: { key } = { obj } ")
        # Only fill in names that are not already defined.
        module_ns.setdefault(key, obj)

    # 'display' from the runtime uses custom widgets that don't work in Jupyter.
    # The IPython display is used instead (together with the html formatter
    # for DataFrames).
    module_ns["display"] = ip_display
36-
37-
3821@_log_exceptions
3922def _register_runtime_hooks ():
4023 from dbruntime .monkey_patches import apply_dataframe_display_patch
@@ -167,7 +150,7 @@ def _register_common_magics():
167150
168151
169152@_log_exceptions
170- def _register_pip_magics (user_namespace_initializer : any , entry_point : any ):
153+ def _register_pip_magics ():
171154 """Register the pip magic command parser with IPython."""
172155 from dbruntime .DatasetInfo import UserNamespaceDict
173156 from dbruntime .PipMagicOverrides import PipMagicOverrides
@@ -181,7 +164,15 @@ def _register_pip_magics(user_namespace_initializer: any, entry_point: any):
181164 entry_point ,
182165 )
183166 ip = get_ipython ()
184- ip .register_magics (PipMagicOverrides (entry_point , globals ["sc" ]._conf , user_ns ))
167+
168+ try :
169+ # Older DBRs
170+ pip_magic = PipMagicOverrides (entry_point , ip .user_ns ["sc" ]._conf , user_ns )
171+ except Exception :
172+ # Newer DBRs
173+ pip_magic = PipMagicOverrides (entry_point , user_ns , ip )
174+
175+ ip .register_magics (pip_magic )
185176
186177
187178@_log_exceptions
@@ -198,34 +189,66 @@ def df_html(df: DataFrame) -> str:
198189 html_formatter .for_type (DataFrame , df_html )
199190
200191
201- @_log_exceptions
202- def _setup_serverless_session ():
203- import IPython
def _create_spark_session(builder_fn):
    """Create a Databricks Connect session from a configured builder.

    Args:
        builder_fn: Callable that receives ``DatabricksSession.builder`` and
            returns the builder on which ``getOrCreate()`` is called.

    Returns:
        The session returned by ``getOrCreate()``.

    Raises:
        Exception: Whatever the builder raises; the previous ``spark`` entry
            of the IPython user namespace is restored before re-raising.
    """
    from databricks.connect import DatabricksSession

    namespace = get_ipython().user_ns
    previous = namespace.get("spark")
    # Clear the existing local spark session, otherwise DatabricksSession
    # will re-use it.
    namespace["spark"] = None
    try:
        configured = builder_fn(DatabricksSession.builder)
        session = configured.getOrCreate()
    except Exception:
        # Put back whatever was there so a failed attempt leaves the
        # namespace as it was found.
        namespace["spark"] = previous
        raise
    return session
204+
205+
def _initialize_spark(is_serverless: bool, existing_spark: object):
    """Return the spark session to expose in the notebook namespace.

    Args:
        is_serverless: When true, a new serverless Databricks Connect session
            is always created.
        existing_spark: The session currently known to the runtime, or
            ``None``.

    Returns:
        A newly created remote session, or ``existing_spark`` unchanged when
        it is neither ``None`` nor a ``pyspark.sql.session.SparkSession``.

    Raises:
        KeyError: If ``DATABRICKS_HOST``, ``DATABRICKS_TOKEN`` or
            ``DATABRICKS_CLUSTER_ID`` is missing when a cluster session has
            to be created.
    """
    from pyspark.sql.session import SparkSession

    # On serverless always initialize a new remote Databricks Connect session.
    if is_serverless:
        return _create_spark_session(lambda b: b.serverless(True))

    # On dedicated or standard initialize a new remote session if the existing
    # spark session is missing or local.
    if existing_spark is None or isinstance(existing_spark, SparkSession):
        # The environment is read lazily inside the lambda, i.e. only once
        # _create_spark_session invokes it.
        return _create_spark_session(
            lambda b: b.remote(
                host=os.environ["DATABRICKS_HOST"],
                token=os.environ["DATABRICKS_TOKEN"],
                cluster_id=os.environ["DATABRICKS_CLUSTER_ID"],
            )
        )

    # Otherwise re-use the existing remote session.
    return existing_spark
220223
221224
222- if os .environ .get ("DATABRICKS_JUPYTER_SERVERLESS" ) == "true" :
223- _setup_serverless_session ()
224- else :
225- _setup_dedicated_session ()
226- _register_pip_magics ()
@_log_exceptions
def _setup_globals(is_serverless: bool):
    """Publish the runtime's namespace globals into the IPython user namespace.

    The spark session is resolved through ``_initialize_spark`` and stored
    under ``spark`` (with ``table`` and ``sql`` bound to its methods). Every
    namespace global is then copied into ``user_ns``, overwriting existing
    names, and ``display`` is rebound to the IPython display function.

    Args:
        is_serverless: Forwarded to ``_initialize_spark``.
    """
    from dbruntime import UserNamespaceInitializer

    initializer = UserNamespaceInitializer.getOrCreate()
    runtime_globals = initializer.get_namespace_globals()
    session = _initialize_spark(is_serverless, runtime_globals.get("spark"))

    # Best effort: a failure to update the provider is reported but does not
    # stop the rest of the setup.
    try:
        initializer.db_connection.spark_provider.set_spark(session)
    except Exception as err:
        print(f"Error updating spark provider: { err } ")

    runtime_globals["spark"] = session
    if session is not None:
        runtime_globals["table"] = session.table
        runtime_globals["sql"] = session.sql

    shell_ns = get_ipython().user_ns
    for key, obj in runtime_globals.items():
        print(f"Registering global: { key } = { obj } ")
        shell_ns[key] = obj

    # 'display' from the runtime uses custom widgets that don't work in Jupyter.
    # The IPython display is used instead (together with the html formatter
    # for DataFrames).
    shell_ns["display"] = ip_display
227248
228249
250+ _setup_globals (os .environ .get ("DATABRICKS_JUPYTER_SERVERLESS" ) == "true" )
251+ _register_pip_magics ()
229252_register_common_magics ()
230253_register_formatters ()
231254_register_runtime_hooks ()
0 commit comments