Skip to content

Commit e8f6ec2

Browse files
hectorcast-dbclaude
andcommitted
Fix OIDC endpoints detection
Fix oidc_endpoints property to separate Databricks and Azure OIDC endpoints. ## Problem The oidc_endpoints property incorrectly returned Azure OIDC endpoints when ARM_CLIENT_ID was set, even for Databricks OAuth flows (like oauth-m2m) that don't use Azure authentication. This caused Databricks M2M OAuth to fail when users set ARM_CLIENT_ID for other purposes. ## Solution - Created databricks_oidc_endpoints property for Databricks OIDC only - Kept oidc_endpoints for backward compatibility (marked as deprecated) - Updated all Databricks OAuth flows to use databricks_oidc_endpoints - Updated Azure-specific flows to explicitly use Azure endpoints ## Tests - Added 4 integration tests for OAuth M2M and Azure Client Secret auth - Added 6 unit tests covering all scenarios including the bug case Mirrors: databricks/databricks-sdk-java#657 Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent a0b7590 commit e8f6ec2

File tree

5 files changed

+283
-11
lines changed

5 files changed

+283
-11
lines changed

databricks/sdk/config.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -486,12 +486,23 @@ def with_user_agent_extra(self, key: str, value: str) -> "Config":
486486
return self
487487

488488
@property
489-
def oidc_endpoints(self) -> Optional[OidcEndpoints]:
489+
def databricks_oidc_endpoints(self) -> Optional[OidcEndpoints]:
490+
"""Get OIDC endpoints for Databricks OAuth.
491+
492+
This method returns the appropriate Databricks OIDC endpoints based on the host type:
493+
- Unified hosts: Returns unified account-scoped endpoints
494+
- Account hosts: Returns traditional account endpoints
495+
- Workspace hosts: Returns workspace endpoints
496+
497+
Note: This method does NOT return Azure Entra ID endpoints. For Azure authentication,
498+
use get_azure_entra_id_workspace_endpoints() directly.
499+
500+
Returns:
501+
OidcEndpoints for Databricks OAuth, or None if host is not configured.
502+
"""
490503
self._fix_host_if_needed()
491504
if not self.host:
492505
return None
493-
if self.is_azure and self.azure_client_id:
494-
return get_azure_entra_id_workspace_endpoints(self.host)
495506

496507
# Handle unified hosts
497508
if self.host_type == HostType.UNIFIED:
@@ -506,6 +517,28 @@ def oidc_endpoints(self) -> Optional[OidcEndpoints]:
506517
# Default to workspace endpoints
507518
return get_workspace_endpoints(self.host)
508519

520+
@property
521+
def oidc_endpoints(self) -> Optional[OidcEndpoints]:
522+
"""[DEPRECATED] Get OIDC endpoints with automatic Azure detection (deprecated).
523+
524+
This method incorrectly returns Azure OIDC endpoints when azure_client_id
525+
is set, even for Databricks OAuth flows that don't use Azure authentication. This caused
526+
bugs where Databricks M2M OAuth would fail when ARM_CLIENT_ID was set for other purposes.
527+
528+
Use instead:
529+
- databricks_oidc_endpoints: For Databricks OAuth (oauth-m2m, external-browser, etc.)
530+
- get_azure_entra_id_workspace_endpoints(): For Azure Entra ID authentication
531+
532+
Returns:
533+
OidcEndpoints (Azure or Databricks depending on config), or None if host is not configured.
534+
"""
535+
self._fix_host_if_needed()
536+
if not self.host:
537+
return None
538+
if self.is_azure and self.azure_client_id:
539+
return get_azure_entra_id_workspace_endpoints(self.host)
540+
return self.databricks_oidc_endpoints
541+
509542
def debug_string(self) -> str:
510543
"""Returns log-friendly representation of configured attributes"""
511544
buf = []

databricks/sdk/credentials_provider.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from google.auth.transport.requests import Request # type: ignore
2121
from google.oauth2 import service_account # type: ignore
2222

23+
from databricks.sdk.oauth import get_azure_entra_id_workspace_endpoints
24+
2325
from . import azure, oauth, oidc, oidc_token_supplier
2426
from .client_types import ClientType
2527

@@ -218,7 +220,7 @@ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
218220
"""Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request,
219221
if /oidc/.well-known/oauth-authorization-server is available on the given host.
220222
"""
221-
oidc = cfg.oidc_endpoints
223+
oidc = cfg.databricks_oidc_endpoints
222224
if oidc is None:
223225
return None
224226

@@ -248,14 +250,21 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
248250
return None
249251

250252
client_id, client_secret = None, None
253+
oidc_endpoints = None
251254
if cfg.client_id:
252255
client_id = cfg.client_id
253256
client_secret = cfg.client_secret
257+
oidc_endpoints = cfg.databricks_oidc_endpoints
254258
elif cfg.azure_client_id:
255-
client_id = cfg.azure_client
259+
client_id = cfg.azure_client_id
256260
client_secret = cfg.azure_client_secret
261+
oidc_endpoints = get_azure_entra_id_workspace_endpoints(cfg.host)
257262
if not client_id:
258263
client_id = "databricks-cli"
264+
oidc_endpoints = cfg.databricks_oidc_endpoints
265+
266+
if not oidc_endpoints:
267+
return None
259268

260269
scopes = cfg.get_scopes()
261270
if not cfg.disable_oauth_refresh_token:
@@ -264,7 +273,6 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
264273

265274
# Load cached credentials from disk if they exist. Note that these are
266275
# local to the Python SDK and not reused by other SDKs.
267-
oidc_endpoints = cfg.oidc_endpoints
268276
redirect_url = "http://localhost:8020"
269277
token_cache = oauth.TokenCache(
270278
host=cfg.host,
@@ -390,7 +398,7 @@ def oidc_credentials_provider(cfg, id_token_source: oidc.IdTokenSource) -> Optio
390398

391399
token_source = oidc.DatabricksOidcTokenSource(
392400
host=cfg.host,
393-
token_endpoint=cfg.oidc_endpoints.token_endpoint,
401+
token_endpoint=cfg.databricks_oidc_endpoints.token_endpoint,
394402
client_id=cfg.client_id,
395403
account_id=cfg.account_id,
396404
id_token_source=id_token_source,
@@ -434,7 +442,7 @@ def _oidc_credentials_provider(
434442
if audience is None and cfg.client_type == ClientType.ACCOUNT:
435443
audience = cfg.account_id
436444
if audience is None and cfg.client_type != ClientType.ACCOUNT:
437-
audience = cfg.oidc_endpoints.token_endpoint
445+
audience = cfg.databricks_oidc_endpoints.token_endpoint
438446

439447
# Try to get an OIDC token. If no supplier returns a token, we cannot use this authentication mode.
440448
id_token = supplier.get_oidc_token(audience)
@@ -453,7 +461,7 @@ def token_source_for(audience: str) -> oauth.TokenSource:
453461
return oauth.ClientCredentials(
454462
client_id=cfg.client_id,
455463
client_secret="", # we have no (rotatable) secrets in OIDC flow
456-
token_url=cfg.oidc_endpoints.token_endpoint,
464+
token_url=cfg.databricks_oidc_endpoints.token_endpoint,
457465
endpoint_params={
458466
"subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
459467
"subject_token": id_token,
@@ -528,7 +536,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
528536
aad_endpoint = cfg.arm_environment.active_directory_endpoint
529537
if not cfg.azure_tenant_id:
530538
# detect Azure AD Tenant ID if it's not specified directly
531-
token_endpoint = cfg.oidc_endpoints.token_endpoint
539+
token_endpoint = get_azure_entra_id_workspace_endpoints(cfg.host).token_endpoint
532540
cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, "").split("/")[0]
533541

534542
inner = oauth.ClientCredentials(

databricks/sdk/oauth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,7 @@ def noop_credentials(_: any):
692692
return lambda: {}
693693

694694
config = Config(host=host, credentials_strategy=noop_credentials)
695-
oidc = config.oidc_endpoints
695+
oidc = config.databricks_oidc_endpoints
696696
if not oidc:
697697
raise ValueError(f"{host} does not support OAuth")
698698
return OAuthClient(oidc, redirect_url, client_id, scopes, client_secret)

tests/integration/test_auth.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,3 +269,98 @@ def test_wif_workspace(ucacct, env_or_skip, random):
269269
)
270270

271271
ws.current_user.me()
272+
273+
274+
def test_workspace_oauth_m2m_auth(w, env_or_skip):
275+
env_or_skip("CLOUD_ENV")
276+
277+
# Get environment variables
278+
host = env_or_skip("DATABRICKS_HOST")
279+
client_id = env_or_skip("TEST_DATABRICKS_CLIENT_ID")
280+
client_secret = env_or_skip("TEST_DATABRICKS_CLIENT_SECRET")
281+
282+
# Create workspace client with OAuth M2M authentication
283+
ws = WorkspaceClient(
284+
host=host,
285+
client_id=client_id,
286+
client_secret=client_secret,
287+
auth_type="oauth-m2m",
288+
)
289+
290+
# Call the "me" API
291+
me = ws.current_user.me()
292+
293+
# Verify we got a valid response
294+
assert me.user_name, "expected non-empty user_name"
295+
296+
297+
def test_workspace_azure_client_secret_auth(w, env_or_skip):
298+
env_or_skip("CLOUD_ENV")
299+
300+
host = env_or_skip("DATABRICKS_HOST")
301+
azure_client_id = env_or_skip("ARM_CLIENT_ID")
302+
azure_client_secret = env_or_skip("ARM_CLIENT_SECRET")
303+
azure_tenant_id = env_or_skip("ARM_TENANT_ID")
304+
305+
# Create workspace client with Azure client secret authentication
306+
ws = WorkspaceClient(
307+
host=host,
308+
azure_client_id=azure_client_id,
309+
azure_client_secret=azure_client_secret,
310+
azure_tenant_id=azure_tenant_id,
311+
auth_type="azure-client-secret",
312+
)
313+
314+
# Call the "me" API
315+
me = ws.current_user.me()
316+
317+
# Verify we got a valid response
318+
assert me.user_name, "expected non-empty user_name"
319+
320+
321+
def test_account_oauth_m2m_auth(a, env_or_skip):
322+
env_or_skip("CLOUD_ENV")
323+
324+
# Get environment variables
325+
host = env_or_skip("DATABRICKS_HOST")
326+
account_id = env_or_skip("DATABRICKS_ACCOUNT_ID")
327+
client_id = env_or_skip("TEST_DATABRICKS_CLIENT_ID")
328+
client_secret = env_or_skip("TEST_DATABRICKS_CLIENT_SECRET")
329+
330+
# Create account client with OAuth M2M authentication
331+
ac = AccountClient(
332+
host=host,
333+
account_id=account_id,
334+
client_id=client_id,
335+
client_secret=client_secret,
336+
auth_type="oauth-m2m",
337+
)
338+
339+
# List service principals to verify authentication works
340+
sps = ac.service_principals.list()
341+
next(sps)
342+
343+
344+
def test_account_azure_client_secret_auth(a, env_or_skip):
345+
env_or_skip("CLOUD_ENV")
346+
347+
# Get environment variables
348+
host = env_or_skip("DATABRICKS_HOST")
349+
account_id = env_or_skip("DATABRICKS_ACCOUNT_ID")
350+
azure_client_id = env_or_skip("ARM_CLIENT_ID")
351+
azure_client_secret = env_or_skip("ARM_CLIENT_SECRET")
352+
azure_tenant_id = env_or_skip("ARM_TENANT_ID")
353+
354+
# Create account client with Azure client secret authentication
355+
ac = AccountClient(
356+
host=host,
357+
account_id=account_id,
358+
azure_client_id=azure_client_id,
359+
azure_client_secret=azure_client_secret,
360+
azure_tenant_id=azure_tenant_id,
361+
auth_type="azure-client-secret",
362+
)
363+
364+
# List service principals to verify authentication works
365+
sps = ac.service_principals.list()
366+
next(sps)

tests/test_config.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,142 @@ def test_oidc_endpoints_unified_missing_ids():
422422
assert "Unified host requires account_id" in str(exc_info.value)
423423

424424

425+
def test_databricks_oidc_endpoints_ignores_azure_client_id(mocker, requests_mock):
426+
"""Test that databricks_oidc_endpoints returns Databricks endpoints even when azure_client_id is set."""
427+
requests_mock.get(
428+
"https://adb-123.4.azuredatabricks.net/oidc/.well-known/oauth-authorization-server",
429+
json={
430+
"authorization_endpoint": "https://adb-123.4.azuredatabricks.net/oidc/v1/authorize",
431+
"token_endpoint": "https://adb-123.4.azuredatabricks.net/oidc/v1/token",
432+
},
433+
)
434+
435+
config = Config(
436+
host="https://adb-123.4.azuredatabricks.net",
437+
azure_client_id="test-azure-client-id", # This should be ignored by databricks_oidc_endpoints
438+
token="test-token",
439+
)
440+
441+
endpoints = config.databricks_oidc_endpoints
442+
assert endpoints is not None
443+
assert "https://adb-123.4.azuredatabricks.net/oidc/v1/authorize" == endpoints.authorization_endpoint
444+
assert "https://adb-123.4.azuredatabricks.net/oidc/v1/token" == endpoints.token_endpoint
445+
446+
447+
def test_databricks_oidc_endpoints_unified_workspace(mocker, requests_mock):
448+
"""Test that databricks_oidc_endpoints returns unified endpoints for workspace on unified host."""
449+
requests_mock.get(
450+
"https://unified.databricks.com/oidc/accounts/test-account/.well-known/oauth-authorization-server",
451+
json={
452+
"authorization_endpoint": "https://unified.databricks.com/oidc/accounts/test-account/v1/authorize",
453+
"token_endpoint": "https://unified.databricks.com/oidc/accounts/test-account/v1/token",
454+
},
455+
)
456+
457+
config = Config(
458+
host="https://unified.databricks.com",
459+
workspace_id="test-workspace",
460+
account_id="test-account",
461+
experimental_is_unified_host=True,
462+
token="test-token",
463+
)
464+
465+
endpoints = config.databricks_oidc_endpoints
466+
assert endpoints is not None
467+
assert "accounts/test-account" in endpoints.authorization_endpoint
468+
assert "accounts/test-account" in endpoints.token_endpoint
469+
470+
471+
def test_databricks_oidc_endpoints_account(mocker, requests_mock):
472+
"""Test that databricks_oidc_endpoints returns account endpoints for account hosts."""
473+
requests_mock.get(
474+
"https://accounts.cloud.databricks.com/oidc/accounts/test-account/.well-known/oauth-authorization-server",
475+
json={
476+
"authorization_endpoint": "https://accounts.cloud.databricks.com/oidc/accounts/test-account/v1/authorize",
477+
"token_endpoint": "https://accounts.cloud.databricks.com/oidc/accounts/test-account/v1/token",
478+
},
479+
)
480+
481+
config = Config(
482+
host="https://accounts.cloud.databricks.com",
483+
account_id="test-account",
484+
token="test-token",
485+
)
486+
487+
endpoints = config.databricks_oidc_endpoints
488+
assert endpoints is not None
489+
assert "accounts/test-account" in endpoints.authorization_endpoint
490+
assert "accounts/test-account" in endpoints.token_endpoint
491+
492+
493+
def test_databricks_oidc_endpoints_workspace(mocker, requests_mock):
494+
"""Test that databricks_oidc_endpoints returns workspace endpoints for workspace hosts."""
495+
requests_mock.get(
496+
"https://test-workspace.cloud.databricks.com/oidc/.well-known/oauth-authorization-server",
497+
json={
498+
"authorization_endpoint": "https://test-workspace.cloud.databricks.com/oidc/v1/authorize",
499+
"token_endpoint": "https://test-workspace.cloud.databricks.com/oidc/v1/token",
500+
},
501+
)
502+
503+
config = Config(
504+
host="https://test-workspace.cloud.databricks.com",
505+
token="test-token",
506+
)
507+
508+
endpoints = config.databricks_oidc_endpoints
509+
assert endpoints is not None
510+
assert "https://test-workspace.cloud.databricks.com/oidc/v1/authorize" == endpoints.authorization_endpoint
511+
assert "https://test-workspace.cloud.databricks.com/oidc/v1/token" == endpoints.token_endpoint
512+
513+
514+
def test_oidc_endpoints_returns_azure_when_azure_client_id_set(mocker):
515+
"""Test that deprecated oidc_endpoints returns Azure endpoints when azure_client_id is set on Azure.
516+
517+
This tests the deprecated behavior that is maintained for backward compatibility.
518+
"""
519+
# Mock the Azure endpoint detection
520+
mocker.patch(
521+
"databricks.sdk.config.get_azure_entra_id_workspace_endpoints",
522+
return_value=mocker.Mock(
523+
authorization_endpoint="https://login.microsoftonline.com/tenant-id/oauth2/v2.0/authorize",
524+
token_endpoint="https://login.microsoftonline.com/tenant-id/oauth2/v2.0/token",
525+
),
526+
)
527+
528+
config = Config(
529+
host="https://adb-123.4.azuredatabricks.net",
530+
azure_client_id="test-azure-client-id",
531+
token="test-token",
532+
)
533+
534+
endpoints = config.oidc_endpoints
535+
assert endpoints is not None
536+
assert "login.microsoftonline.com" in endpoints.authorization_endpoint
537+
assert "login.microsoftonline.com" in endpoints.token_endpoint
538+
539+
540+
def test_oidc_endpoints_falls_back_to_databricks_when_no_azure_client_id(mocker, requests_mock):
541+
"""Test that deprecated oidc_endpoints falls back to Databricks endpoints when azure_client_id is not set."""
542+
requests_mock.get(
543+
"https://adb-123.4.azuredatabricks.net/oidc/.well-known/oauth-authorization-server",
544+
json={
545+
"authorization_endpoint": "https://adb-123.4.azuredatabricks.net/oidc/v1/authorize",
546+
"token_endpoint": "https://adb-123.4.azuredatabricks.net/oidc/v1/token",
547+
},
548+
)
549+
550+
config = Config(
551+
host="https://adb-123.4.azuredatabricks.net",
552+
token="test-token",
553+
)
554+
555+
endpoints = config.oidc_endpoints
556+
assert endpoints is not None
557+
assert "https://adb-123.4.azuredatabricks.net/oidc/v1/authorize" == endpoints.authorization_endpoint
558+
assert "https://adb-123.4.azuredatabricks.net/oidc/v1/token" == endpoints.token_endpoint
559+
560+
425561
def test_workspace_org_id_header_on_unified_host(requests_mock):
426562
"""Test that X-Databricks-Org-Id header is added for workspace clients on unified hosts."""
427563

0 commit comments

Comments
 (0)