diff --git a/pom.xml b/pom.xml index c11c2a5295..c761030921 100644 --- a/pom.xml +++ b/pom.xml @@ -188,6 +188,11 @@ jackson-core ${jackson.version} + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + ${jackson.version} + org.immutables value diff --git a/src/main/java/com/databricks/jdbc/api/IDatabricksConnectionContext.java b/src/main/java/com/databricks/jdbc/api/IDatabricksConnectionContext.java index 28cbd7d3c1..4b7f10dcf9 100644 --- a/src/main/java/com/databricks/jdbc/api/IDatabricksConnectionContext.java +++ b/src/main/java/com/databricks/jdbc/api/IDatabricksConnectionContext.java @@ -247,4 +247,10 @@ public interface IDatabricksConnectionContext { /** Returns maximum number of rows that a query returns at a time. */ int getRowsFetchedPerBlock(); + + /** Returns the passphrase used for encrypting/decrypting token cache */ + String getTokenCachePassPhrase(); + + /** Returns whether token caching is enabled for OAuth authentication */ + boolean isTokenCacheEnabled(); } diff --git a/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java b/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java index d2db7188ef..9655aea757 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java +++ b/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java @@ -792,6 +792,16 @@ public int getRowsFetchedPerBlock() { return maxRows; } + @Override + public String getTokenCachePassPhrase() { + return getParameter(DatabricksJdbcUrlParams.TOKEN_CACHE_PASS_PHRASE); + } + + @Override + public boolean isTokenCacheEnabled() { + return getParameter(DatabricksJdbcUrlParams.ENABLE_TOKEN_CACHE).equals("1"); + } + private static boolean nullOrEmptyString(String s) { return s == null || s.isEmpty(); } diff --git a/src/main/java/com/databricks/jdbc/auth/CachingExternalBrowserCredentialsProvider.java b/src/main/java/com/databricks/jdbc/auth/CachingExternalBrowserCredentialsProvider.java new file mode 100644 index 0000000000..3d3b955cea --- /dev/null +++ b/src/main/java/com/databricks/jdbc/auth/CachingExternalBrowserCredentialsProvider.java @@ -0,0 +1,209 @@ +package com.databricks.jdbc.auth; + +import static com.databricks.jdbc.auth.AuthConstants.GRANT_TYPE_KEY; +import static com.databricks.jdbc.auth.AuthConstants.GRANT_TYPE_REFRESH_TOKEN_KEY; + +import com.databricks.jdbc.api.IDatabricksConnectionContext; +import com.databricks.jdbc.common.DatabricksJdbcConstants; +import com.databricks.jdbc.common.util.DatabricksAuthUtil; +import com.databricks.jdbc.log.JdbcLogger; +import com.databricks.jdbc.log.JdbcLoggerFactory; +import com.databricks.sdk.core.CredentialsProvider; +import com.databricks.sdk.core.DatabricksConfig; +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.HeaderFactory; +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.oauth.*; +import com.google.common.annotations.VisibleForTesting; +import java.io.IOException; +import java.time.LocalDateTime; +import java.util.HashMap; +import java.util.Map; +import org.apache.http.HttpHeaders; + +/** + * A {@code CredentialsProvider} which implements the Authorization Code + PKCE flow with token + * caching and automatic token refresh. This provider encrypts and caches tokens in the user's + * temporary directory and reuses them when possible. + * + *

This provider extends {@code RefreshableTokenSource} to handle token refreshing in a + * standardized way. When a token is obtained after successful authentication, it is encrypted and + * stored locally. On subsequent connection attempts, the provider will: + * + *

    + *
  1. Try to load and use a cached token if available + *
  2. If the cached token is expired but has a refresh token, attempt to refresh it using the + * OAuth2 token endpoint + *
  3. If no cached token exists or refresh fails, initiate the browser-based OAuth flow + *
+ * + *

This approach minimizes the need for users to repeatedly authenticate through the browser, + * improving the user experience while maintaining security through encryption of the cached tokens. + */ +public class CachingExternalBrowserCredentialsProvider extends RefreshableTokenSource + implements CredentialsProvider { + + private static final String AUTH_TYPE = "external-browser-with-cache"; + private static final JdbcLogger LOGGER = + JdbcLoggerFactory.getLogger(CachingExternalBrowserCredentialsProvider.class); + private final TokenCache tokenCache; + private final DatabricksConfig config; + private final String tokenEndpoint; + private HttpClient hc; + + /** + * Creates a new CachingExternalBrowserCredentialsProvider with the specified configuration, + * connection context, and token cache. + * + * @param config The Databricks configuration to use for authentication + * @param context The connection context containing OAuth configuration parameters + * @param tokenCache The token cache to use for storing and retrieving tokens + */ + public CachingExternalBrowserCredentialsProvider( + DatabricksConfig config, IDatabricksConnectionContext context, TokenCache tokenCache) { + this.config = config; + this.tokenCache = tokenCache; + this.tokenEndpoint = DatabricksAuthUtil.getTokenEndpoint(config, context); + try { + // Initialize token from cache + this.token = tokenCache.load(); + if (this.token == null) { + LOGGER.debug("No cached token found"); + // Initialize with an expired token to force authentication on first use + this.token = + new Token( + DatabricksJdbcConstants.EMPTY_STRING, + DatabricksJdbcConstants.EMPTY_STRING, + null, + LocalDateTime.now().minusMinutes(1)); + } else { + LOGGER.debug("Cached token found"); + } + } catch (IOException e) { + LOGGER.debug("Failed to load token from cache", e); + // Initialize with an expired token to force authentication on first use + this.token = + new Token( + DatabricksJdbcConstants.EMPTY_STRING, + DatabricksJdbcConstants.EMPTY_STRING, + null, + LocalDateTime.now().minusMinutes(1)); + } + } + + /** + * Returns the authentication type identifier for this provider. + * + * @return The string "external-browser-with-cache" + */ + @Override + public String authType() { + return AUTH_TYPE; + } + + /** + * Configures the authentication by setting up the necessary headers for authenticated requests. + * This method implements the core OAuth flow with caching logic. + * + * @param config The Databricks configuration to use + * @return A HeaderFactory that adds the OAuth authentication header to requests, or null if the + * configuration is not valid for this provider + */ + @Override + public HeaderFactory configure(DatabricksConfig config) { + if (config.getHost() == null || !config.getAuthType().equals(AUTH_TYPE)) { + return null; + } + + if (this.hc == null) { + this.hc = config.getHttpClient(); + } + + return () -> { + Map headers = new HashMap<>(); + headers.put( + HttpHeaders.AUTHORIZATION, getToken().getTokenType() + " " + getToken().getAccessToken()); + return headers; + }; + } + + /** + * Implements the token refresh logic as required by RefreshableTokenSource. This method handles + * token refreshing, falling back to browser authentication if needed. + * + * @return A new or refreshed token + * @throws DatabricksException If there is an error during the authentication process + */ + @Override + protected Token refresh() { + try { + // Try to refresh if we have a refresh token + if (this.token != null && this.token.getRefreshToken() != null) { + try { + LOGGER.debug("Using refresh token to get new access token"); + Token refreshedToken = refreshAccessToken(); + tokenCache.save(refreshedToken); + return refreshedToken; + } catch (Exception e) { + LOGGER.info("Failed to refresh access token, will restart browser auth", e); + // If refresh fails, fall through to browser auth + } + } + + // If we get here, we need to do browser auth + LOGGER.debug("Performing browser authentication to get new access token"); + Token newToken = performBrowserAuth(); + tokenCache.save(newToken); + return newToken; + } catch (Exception e) { + String errorMessage = "Failed to refresh or obtain new token"; + LOGGER.error(errorMessage, e); + throw new DatabricksException(errorMessage, e); + } + } + + /** + * Refreshes an access token using the refresh token from the current token. This method follows + * the OAuth 2.0 refresh token flow by sending a request to the token endpoint with the refresh + * token grant type. + * + * @return A new token with a refreshed access token + * @throws DatabricksException If there is an error during the refresh process or if the token or + * refresh token is not available + */ + @VisibleForTesting + Token refreshAccessToken() throws DatabricksException { + if (this.token == null || this.token.getRefreshToken() == null) { + throw new DatabricksException("oauth2: token is not set or refresh token is not available"); + } + + Map params = new HashMap<>(); + params.put(GRANT_TYPE_KEY, GRANT_TYPE_REFRESH_TOKEN_KEY); + params.put(GRANT_TYPE_REFRESH_TOKEN_KEY, this.token.getRefreshToken()); + Map headers = new HashMap<>(); + return retrieveToken( + hc, + config.getClientId(), + config.getClientSecret(), + tokenEndpoint, + params, + headers, + AuthParameterPosition.BODY); + } + + /** + * Performs browser-based authentication to obtain a new token. This method launches a browser + * window to allow the user to authenticate and authorize the application. + * + * @return A new token obtained through browser authentication + * @throws IOException If there is an error during the authentication process + * @throws DatabricksException If the Databricks API returns an error + */ + @VisibleForTesting + Token performBrowserAuth() throws IOException, DatabricksException { + OAuthClient client = new OAuthClient(config); + Consent consent = client.initiateConsent(); + SessionCredentials creds = consent.launchExternalBrowser(); + return creds.getToken(); + } +} diff --git a/src/main/java/com/databricks/jdbc/auth/TokenCache.java b/src/main/java/com/databricks/jdbc/auth/TokenCache.java new file mode 100644 index 0000000000..96a1ebaa7b --- /dev/null +++ b/src/main/java/com/databricks/jdbc/auth/TokenCache.java @@ -0,0 +1,190 @@ +package com.databricks.jdbc.auth; + +import com.databricks.jdbc.common.util.StringUtil; +import com.databricks.sdk.core.oauth.Token; +import com.databricks.sdk.core.utils.ClockSupplier; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; +import com.google.common.annotations.VisibleForTesting; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.spec.KeySpec; +import java.time.LocalDateTime; +import java.util.Base64; +import javax.crypto.Cipher; +import javax.crypto.SecretKey; +import javax.crypto.SecretKeyFactory; +import javax.crypto.spec.PBEKeySpec; +import javax.crypto.spec.SecretKeySpec; + +/** + * A secure cache for OAuth tokens that encrypts and persists tokens to the local filesystem. + * + *

This class provides functionality to securely store and retrieve OAuth tokens between + * application sessions. Tokens are encrypted using AES encryption with a key derived from the + * provided passphrase using PBKDF2 key derivation. + * + *

The tokens are stored in a file within the system's temporary directory under a '.databricks' + * subdirectory. Each user has their own token cache file identified by the sanitized username. + */ +public class TokenCache { + private static final String CACHE_DIR = ".databricks"; + private static final String CACHE_FILE_SUFFIX = ".databricks_jdbc_token_cache"; + private static final String ALGORITHM = "AES"; + private static final String SECRET_KEY_ALGORITHM = "PBKDF2WithHmacSHA256"; + private static final byte[] SALT = "DatabricksTokenCache".getBytes(); // Fixed salt for simplicity + private static final int ITERATION_COUNT = 65536; + private static final int KEY_LENGTH = 256; + + private final Path cacheFile; + private final String passphrase; + private final ObjectMapper mapper; + + /** + * A serializable version of the Token class that can be serialized/deserialized by Jackson. This + * class extends the Token class from the SDK and adds JSON annotations for proper serialization. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + static class SerializableToken extends Token { + public SerializableToken(String accessToken, String tokenType, LocalDateTime expiry) { + super(accessToken, tokenType, expiry); + } + + public SerializableToken( + String accessToken, String tokenType, LocalDateTime expiry, ClockSupplier clockSupplier) { + super(accessToken, tokenType, expiry, clockSupplier); + } + + @JsonCreator + public SerializableToken( + @JsonProperty("accessToken") String accessToken, + @JsonProperty("tokenType") String tokenType, + @JsonProperty("refreshToken") String refreshToken, + @JsonProperty("expiry") LocalDateTime expiry) { + super(accessToken, tokenType, refreshToken, expiry); + } + + public SerializableToken( + String accessToken, + String tokenType, + String refreshToken, + LocalDateTime expiry, + ClockSupplier clockSupplier) { + super(accessToken, tokenType, refreshToken, expiry, clockSupplier); + } + } + + /** + * Constructs a new TokenCache instance with encryption using the user's system username to + * identify the cache file. + * + * @param passphrase The passphrase used to encrypt/decrypt the token cache + * @throws IllegalArgumentException if the passphrase is null or empty + */ + public TokenCache(String passphrase) { + this( + Paths.get( + System.getProperty("java.io.tmpdir"), + CACHE_DIR, + StringUtil.sanitizeUsernameForFile(System.getProperty("user.name")) + + CACHE_FILE_SUFFIX), + passphrase); + } + + /** + * Constructs a new TokenCache instance with encryption using the cache file path provided. + * + * @param cacheFile The cache file path + * @param passphrase The passphrase used to encrypt/decrypt the token cache + */ + @VisibleForTesting + public TokenCache(Path cacheFile, String passphrase) { + if (passphrase == null || passphrase.isEmpty()) { + throw new IllegalArgumentException( + "Required setting TokenCachePassPhrase has not been provided in connection settings"); + } + this.passphrase = passphrase; + this.cacheFile = cacheFile; + this.mapper = new ObjectMapper(); + this.mapper.registerModule(new JavaTimeModule()); + } + + /** + * Saves a token to the cache file, encrypting it with the configured passphrase. + * + * @param token The token to save to the cache + * @throws IOException If an error occurs writing the token to the file or during encryption + */ + public void save(Token token) throws IOException { + try { + Files.createDirectories(cacheFile.getParent()); + String json = mapper.writeValueAsString(token); + byte[] encrypted = encrypt(json.getBytes()); + Files.write(cacheFile, encrypted); + } catch (Exception e) { + throw new IOException("Failed to save token cache: " + e.getMessage(), e); + } + } + + /** + * Loads a token from the cache file, decrypting it with the configured passphrase. + * + * @return The decrypted token from the cache or null if the cache file doesn't exist + * @throws IOException If an error occurs reading the token from the file or during decryption + */ + public Token load() throws IOException { + try { + if (!Files.exists(cacheFile)) { + return null; + } + byte[] encrypted = Files.readAllBytes(cacheFile); + byte[] decrypted = decrypt(encrypted); + return mapper.readValue(decrypted, SerializableToken.class); + } catch (Exception e) { + throw new IOException("Failed to load token cache: " + e.getMessage(), e); + } + } + + /** + * Generates a secret key from the passphrase using PBKDF2 with HMAC-SHA256. + * + * @return A SecretKey generated from the passphrase + * @throws Exception If an error occurs generating the key + */ + private SecretKey generateSecretKey() throws Exception { + SecretKeyFactory factory = SecretKeyFactory.getInstance(SECRET_KEY_ALGORITHM); + KeySpec spec = new PBEKeySpec(passphrase.toCharArray(), SALT, ITERATION_COUNT, KEY_LENGTH); + return new SecretKeySpec(factory.generateSecret(spec).getEncoded(), ALGORITHM); + } + + /** + * Encrypts the given data using AES encryption with a key derived from the passphrase. + * + * @param data The data to encrypt + * @return The encrypted data, Base64 encoded + * @throws Exception If an error occurs during encryption + */ + private byte[] encrypt(byte[] data) throws Exception { + Cipher cipher = Cipher.getInstance(ALGORITHM); + cipher.init(Cipher.ENCRYPT_MODE, generateSecretKey()); + return Base64.getEncoder().encode(cipher.doFinal(data)); + } + + /** + * Decrypts the given encrypted data using AES decryption with a key derived from the passphrase. + * + * @param encryptedData The encrypted data, Base64 encoded + * @return The decrypted data + * @throws Exception If an error occurs during decryption + */ + private byte[] decrypt(byte[] encryptedData) throws Exception { + Cipher cipher = Cipher.getInstance(ALGORITHM); + cipher.init(Cipher.DECRYPT_MODE, generateSecretKey()); + return cipher.doFinal(Base64.getDecoder().decode(encryptedData)); + } +} diff --git a/src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java b/src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java index 5ee370514b..0174669a28 100644 --- a/src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java +++ b/src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java @@ -106,7 +106,9 @@ public enum DatabricksJdbcUrlParams { DEFAULT_STRING_COLUMN_LENGTH( "DefaultStringColumnLength", "Maximum number of characters that can be contained in STRING columns", - "255"); + "255"), + TOKEN_CACHE_PASS_PHRASE("TokenCachePassPhrase", "Pass phrase to use for OAuth U2M Token Cache"), + ENABLE_TOKEN_CACHE("EnableTokenCache", "Enable caching OAuth tokens", "1"); private final String paramName; private final String defaultValue; diff --git a/src/main/java/com/databricks/jdbc/common/util/StringUtil.java b/src/main/java/com/databricks/jdbc/common/util/StringUtil.java index e9b7f655d2..3fb1a0bf50 100644 --- a/src/main/java/com/databricks/jdbc/common/util/StringUtil.java +++ b/src/main/java/com/databricks/jdbc/common/util/StringUtil.java @@ -59,4 +59,13 @@ public static String getVolumePath(String catalog, String schema, String volume) // We need to escape '' to prevent SQL injection return escapeStringLiteral(String.format("/Volumes/%s/%s/%s/", catalog, schema, volume)); } + + /** + * Sanitizes the given username so it can be safely used in a file name. Replaces all characters + * that are not a-z, A-Z, 0-9, or underscore (_) with an underscore. + */ + public static String sanitizeUsernameForFile(String username) { + if (username == null) return "unknown_user"; + return username.replaceAll("[^a-zA-Z0-9_]", "_"); + } } diff --git a/src/main/java/com/databricks/jdbc/dbclient/impl/common/ClientConfigurator.java b/src/main/java/com/databricks/jdbc/dbclient/impl/common/ClientConfigurator.java index a71895048f..a6fbb04c54 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/impl/common/ClientConfigurator.java +++ b/src/main/java/com/databricks/jdbc/dbclient/impl/common/ClientConfigurator.java @@ -4,9 +4,7 @@ import static com.databricks.jdbc.common.util.DatabricksAuthUtil.initializeConfigWithToken; import com.databricks.jdbc.api.IDatabricksConnectionContext; -import com.databricks.jdbc.auth.AzureMSICredentialProvider; -import com.databricks.jdbc.auth.OAuthRefreshCredentialsProvider; -import com.databricks.jdbc.auth.PrivateKeyClientCredentialProvider; +import com.databricks.jdbc.auth.*; import com.databricks.jdbc.common.AuthMech; import com.databricks.jdbc.common.DatabricksJdbcConstants; import com.databricks.jdbc.common.util.DriverUtil; @@ -20,6 +18,7 @@ import com.databricks.sdk.core.DatabricksException; import com.databricks.sdk.core.ProxyConfig; import com.databricks.sdk.core.commons.CommonsHttpClient; +import com.databricks.sdk.core.oauth.ExternalBrowserCredentialsProvider; import com.databricks.sdk.core.utils.Cloud; import java.security.cert.*; import java.util.Arrays; @@ -130,7 +129,6 @@ public void setupOAuthConfig() throws DatabricksParsingException { /** Setup the OAuth U2M authentication settings in the databricks config. */ public void setupU2MConfig() throws DatabricksParsingException { databricksConfig - .setAuthType(DatabricksJdbcConstants.U2M_AUTH_TYPE) .setHost(connectionContext.getHostForOAuth()) .setClientId(connectionContext.getClientId()) .setClientSecret(connectionContext.getClientSecret()) @@ -138,6 +136,20 @@ public void setupU2MConfig() throws DatabricksParsingException { if (!databricksConfig.isAzure()) { databricksConfig.setScopes(connectionContext.getOAuthScopesForU2M()); } + + CredentialsProvider provider; + if (connectionContext.isTokenCacheEnabled()) { + LOGGER.debug("Using CachingExternalBrowserCredentialsProvider as token caching is enabled"); + TokenCache tokenCache = new TokenCache(connectionContext.getTokenCachePassPhrase()); + provider = + new CachingExternalBrowserCredentialsProvider( + databricksConfig, connectionContext, tokenCache); + } else { + LOGGER.debug("Using ExternalBrowserCredentialsProvider as token caching is disabled"); + provider = new ExternalBrowserCredentialsProvider(); + } + + databricksConfig.setCredentialsProvider(provider).setAuthType(provider.authType()); } /** Setup the PAT authentication settings in the databricks config. */ @@ -162,14 +174,13 @@ public void resetAccessTokenInConfig(String newAccessToken) { /** Setup the OAuth U2M refresh token authentication settings in the databricks config. */ public void setupU2MRefreshConfig() throws DatabricksParsingException { - CredentialsProvider provider = - new OAuthRefreshCredentialsProvider(connectionContext, databricksConfig); databricksConfig .setHost(connectionContext.getHostForOAuth()) - .setAuthType(provider.authType()) // oauth-refresh - .setCredentialsProvider(provider) .setClientId(connectionContext.getClientId()) .setClientSecret(connectionContext.getClientSecret()); + CredentialsProvider provider = + new OAuthRefreshCredentialsProvider(connectionContext, databricksConfig); + databricksConfig.setAuthType(provider.authType()).setCredentialsProvider(provider); } /** Setup the OAuth M2M authentication settings in the databricks config. */ diff --git a/src/test/java/com/databricks/jdbc/api/impl/DatabricksConnectionContextTest.java b/src/test/java/com/databricks/jdbc/api/impl/DatabricksConnectionContextTest.java index fb46338f37..9b99db53ba 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/DatabricksConnectionContextTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/DatabricksConnectionContextTest.java @@ -441,13 +441,39 @@ public void testParsingOfUrlWithSpecifiedCatalogAndSchema() throws DatabricksSQL @Test void testLogLevels() { - assertEquals(getLogLevel(123), LogLevel.OFF); - assertEquals(getLogLevel(0), LogLevel.OFF); - assertEquals(getLogLevel(1), LogLevel.FATAL); - assertEquals(getLogLevel(2), LogLevel.ERROR); - assertEquals(getLogLevel(3), LogLevel.WARN); - assertEquals(getLogLevel(4), LogLevel.INFO); - assertEquals(getLogLevel(5), LogLevel.DEBUG); - assertEquals(getLogLevel(6), LogLevel.TRACE); + assertEquals(LogLevel.OFF, getLogLevel(0)); + assertEquals(LogLevel.FATAL, getLogLevel(1)); + assertEquals(LogLevel.ERROR, getLogLevel(2)); + assertEquals(LogLevel.WARN, getLogLevel(3)); + assertEquals(LogLevel.INFO, getLogLevel(4)); + assertEquals(LogLevel.DEBUG, getLogLevel(5)); + assertEquals(LogLevel.TRACE, getLogLevel(6)); + assertEquals(LogLevel.OFF, getLogLevel(123)); + } + + @Test + public void testIsTokenCacheEnabled() throws DatabricksSQLException { + // Test with EnableTokenCache=1 (default) + Properties properties1 = new Properties(); + DatabricksConnectionContext connectionContext1 = + (DatabricksConnectionContext) + DatabricksConnectionContext.parse(TestConstants.VALID_URL_1, properties1); + assertTrue(connectionContext1.isTokenCacheEnabled()); + + // Test with EnableTokenCache=0 + Properties properties2 = new Properties(); + properties2.setProperty("EnableTokenCache", "0"); + DatabricksConnectionContext connectionContext2 = + (DatabricksConnectionContext) + DatabricksConnectionContext.parse(TestConstants.VALID_URL_1, properties2); + assertFalse(connectionContext2.isTokenCacheEnabled()); + + // Test with EnableTokenCache=1 explicitly set + Properties properties3 = new Properties(); + properties3.setProperty("EnableTokenCache", "1"); + DatabricksConnectionContext connectionContext3 = + (DatabricksConnectionContext) + DatabricksConnectionContext.parse(TestConstants.VALID_URL_1, properties3); + assertTrue(connectionContext3.isTokenCacheEnabled()); } } diff --git a/src/test/java/com/databricks/jdbc/auth/CachingExternalBrowserCredentialsProviderTest.java b/src/test/java/com/databricks/jdbc/auth/CachingExternalBrowserCredentialsProviderTest.java new file mode 100644 index 0000000000..ead75942e5 --- /dev/null +++ b/src/test/java/com/databricks/jdbc/auth/CachingExternalBrowserCredentialsProviderTest.java @@ -0,0 +1,201 @@ +package com.databricks.jdbc.auth; + +import static com.databricks.jdbc.TestConstants.TEST_AUTH_URL; +import static com.databricks.jdbc.TestConstants.TEST_TOKEN_URL; +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +import com.databricks.jdbc.api.IDatabricksConnectionContext; +import com.databricks.jdbc.common.DatabricksJdbcConstants; +import com.databricks.sdk.core.DatabricksConfig; +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.HeaderFactory; +import com.databricks.sdk.core.oauth.OpenIDConnectEndpoints; +import com.databricks.sdk.core.oauth.Token; +import java.io.IOException; +import java.time.LocalDateTime; +import java.util.Map; +import org.apache.http.HttpHeaders; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +public class CachingExternalBrowserCredentialsProviderTest { + private static final String TEST_HOST = "test-host"; + private static final String AUTH_TYPE = "external-browser-with-cache"; + + @Mock private TokenCache tokenCache; + @Mock private IDatabricksConnectionContext connectionContext; + @Mock private DatabricksConfig config; + + private CachingExternalBrowserCredentialsProvider provider; + + @BeforeEach + void setUp() throws IOException { + // Set up necessary mocks + doReturn(new OpenIDConnectEndpoints(TEST_TOKEN_URL, TEST_AUTH_URL)) + .when(config) + .getOidcEndpoints(); + } + + @Test + void testAuthType() { + provider = + spy(new CachingExternalBrowserCredentialsProvider(config, connectionContext, tokenCache)); + assertEquals(AUTH_TYPE, provider.authType()); + } + + @Test + void testConfigureWithInvalidConfig() { + DatabricksConfig invalidConfig = new DatabricksConfig().setAuthType("invalid-type"); + provider = + spy(new CachingExternalBrowserCredentialsProvider(config, connectionContext, tokenCache)); + assertNull(provider.configure(invalidConfig)); + } + + @Test + void testUseValidTokenFromCache() throws IOException { + when(config.getHost()).thenReturn(TEST_HOST); + when(config.getAuthType()).thenReturn(AUTH_TYPE); + // Setup valid token in cache + Token validToken = + new Token( + "cached-token", "Bearer", "cached-refresh-token", LocalDateTime.now().plusHours(1)); + when(tokenCache.load()).thenReturn(validToken); + + provider = + spy(new CachingExternalBrowserCredentialsProvider(config, connectionContext, tokenCache)); + + // Should use cached token without any refresh or browser auth + HeaderFactory headerFactory = provider.configure(config); + assertNotNull(headerFactory); + + Map headers = headerFactory.headers(); + assertEquals("Bearer cached-token", headers.get(HttpHeaders.AUTHORIZATION)); + + // Verify no refresh or browser auth was attempted + verify(provider, never()).refresh(); + } + + @Test + void testRefreshExpiredTokenSuccess() throws IOException, DatabricksException { + when(config.getHost()).thenReturn(TEST_HOST); + when(config.getAuthType()).thenReturn(AUTH_TYPE); + + // Setup expired token in cache + Token expiredToken = + new Token( + "expired-token", "Bearer", "cached-refresh-token", LocalDateTime.now().minusMinutes(5)); + when(tokenCache.load()).thenReturn(expiredToken); + + provider = + spy(new CachingExternalBrowserCredentialsProvider(config, connectionContext, tokenCache)); + + // Setup successful token refresh + Token refreshedToken = + new Token( + "refreshed-token", "Bearer", "new-refresh-token", LocalDateTime.now().plusHours(1)); + doReturn(refreshedToken).when(provider).refreshAccessToken(); + + // Should refresh token using cached refresh token + HeaderFactory headerFactory = provider.configure(config); + assertNotNull(headerFactory); + + Map headers = headerFactory.headers(); + assertEquals("Bearer refreshed-token", headers.get(HttpHeaders.AUTHORIZATION)); + + // Verify refresh was attempted + verify(provider).refreshAccessToken(); + verify(tokenCache).save(refreshedToken); + } + + @Test + void testRefreshTokenFailureFallbackToBrowserAuth() throws IOException, DatabricksException { + when(config.getHost()).thenReturn(TEST_HOST); + when(config.getAuthType()).thenReturn(AUTH_TYPE); + + // Setup expired token in cache + Token expiredToken = + new Token( + "expired-token", + "Bearer", + "invalid-refresh-token", + LocalDateTime.now().minusMinutes(5)); + when(tokenCache.load()).thenReturn(expiredToken); + + provider = + spy(new CachingExternalBrowserCredentialsProvider(config, connectionContext, tokenCache)); + + // Setup failed token refresh + doThrow(new DatabricksException("Invalid refresh token")).when(provider).refreshAccessToken(); + + // Setup successful browser auth + Token newToken = + new Token("new-token", "Bearer", "new-refresh-token", LocalDateTime.now().plusHours(1)); + doReturn(newToken).when(provider).performBrowserAuth(); + + // Configure and get the token to trigger refresh + HeaderFactory headerFactory = provider.configure(config); + assertNotNull(headerFactory); + + Map headers = headerFactory.headers(); + assertEquals("Bearer new-token", headers.get(HttpHeaders.AUTHORIZATION)); + + // Verify both refresh and browser auth were attempted + verify(provider).refreshAccessToken(); + verify(provider).performBrowserAuth(); + verify(tokenCache).save(newToken); + } + + @Test + void testEmptyCacheFallbackToBrowserAuth() throws IOException, DatabricksException { + when(config.getHost()).thenReturn(TEST_HOST); + when(config.getAuthType()).thenReturn(AUTH_TYPE); + + // Setup empty cache with null token + // RefreshableTokenSource initialization will create empty token if null is returned + when(tokenCache.load()).thenReturn(null); + provider = + spy(new CachingExternalBrowserCredentialsProvider(config, connectionContext, tokenCache)); + + // Setup successful browser auth + Token newToken = + new Token("new-token", "Bearer", "new-refresh-token", LocalDateTime.now().plusHours(1)); + doReturn(newToken).when(provider).performBrowserAuth(); + + // Configure and get the token to trigger refresh + HeaderFactory headerFactory = provider.configure(config); + assertNotNull(headerFactory); + + Map headers = headerFactory.headers(); + assertEquals("Bearer new-token", headers.get(HttpHeaders.AUTHORIZATION)); + + // Verify only browser auth was attempted (refresh will not be attempted with null refresh + // token) + verify(provider).performBrowserAuth(); + verify(tokenCache).save(newToken); + } + + @Test + void testRefreshAccessToken() throws DatabricksException, IOException { + // Create a token with a refresh token + Token token = + new Token( + DatabricksJdbcConstants.EMPTY_STRING, + DatabricksJdbcConstants.EMPTY_STRING, + "test-refresh-token", + LocalDateTime.now().minusMinutes(1)); + when(tokenCache.load()).thenReturn(token); + + provider = + spy(new CachingExternalBrowserCredentialsProvider(config, connectionContext, tokenCache)); + + // Test that refreshAccessToken is called by refresh() + doThrow(new DatabricksException("Test exception")).when(provider).refreshAccessToken(); + + assertThrows(DatabricksException.class, () -> provider.refresh()); + } +} diff --git a/src/test/java/com/databricks/jdbc/auth/TokenCacheTest.java b/src/test/java/com/databricks/jdbc/auth/TokenCacheTest.java new file mode 100644 index 0000000000..61c5e8cd84 --- /dev/null +++ b/src/test/java/com/databricks/jdbc/auth/TokenCacheTest.java @@ -0,0 +1,76 @@ +package com.databricks.jdbc.auth; + +import static org.junit.jupiter.api.Assertions.*; + +import com.databricks.jdbc.common.util.StringUtil; +import com.databricks.sdk.core.oauth.Token; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.time.LocalDateTime; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class TokenCacheTest { + private static final String TEST_PASSPHRASE = "test-passphrase"; + private Path cacheFile; + private TokenCache tokenCache; + + @BeforeEach + void setUp() throws IOException { + String sanitizedUsername = StringUtil.sanitizeUsernameForFile(System.getProperty("user.name")); + cacheFile = + Paths.get( + System.getProperty("java.io.tmpdir"), + ".databricks", + sanitizedUsername + ".databricks_jdbc_token_cache_test"); + tokenCache = new TokenCache(cacheFile, TEST_PASSPHRASE); + Files.deleteIfExists(cacheFile); + } + + @AfterEach + void tearDown() throws IOException { + Files.deleteIfExists(cacheFile); + } + + @Test + void testEmptyCache() throws IOException { + assertNull(tokenCache.load()); + } + + @Test + void testSaveAndLoadToken() throws IOException { + LocalDateTime expiry = LocalDateTime.now().plusHours(1); + Token token = new Token("access-token", "Bearer", "refresh-token", expiry); + + tokenCache.save(token); + Token loadedToken = tokenCache.load(); + + assertNotNull(loadedToken); + assertEquals("access-token", loadedToken.getAccessToken()); + assertEquals("Bearer", loadedToken.getTokenType()); + assertEquals("refresh-token", loadedToken.getRefreshToken()); + assertFalse(loadedToken.isExpired()); + } + + @Test + void testInvalidPassphrase() { + assertThrows(IllegalArgumentException.class, () -> new TokenCache(null)); + assertThrows(IllegalArgumentException.class, () -> new TokenCache("")); + } + + @Test + void testOverwriteToken() throws IOException { + Token token1 = new Token("token1", "Bearer", "refresh1", LocalDateTime.now().plusHours(1)); + Token token2 = new Token("token2", "Bearer", "refresh2", LocalDateTime.now().plusHours(2)); + + tokenCache.save(token1); + tokenCache.save(token2); + + Token loadedToken = tokenCache.load(); + assertEquals("token2", loadedToken.getAccessToken()); + assertEquals("refresh2", loadedToken.getRefreshToken()); + } +} diff --git a/src/test/java/com/databricks/jdbc/common/util/StringUtilTest.java b/src/test/java/com/databricks/jdbc/common/util/StringUtilTest.java index b8592198fe..fb9b124835 100644 --- a/src/test/java/com/databricks/jdbc/common/util/StringUtilTest.java +++ b/src/test/java/com/databricks/jdbc/common/util/StringUtilTest.java @@ -73,4 +73,39 @@ public void testEscapeStringLiteral() { String expected = "''1'';select * from other-table"; assertEquals(expected, StringUtil.escapeStringLiteral(sqlValue)); } + + @Test + public void testNullUsername() { + assertEquals("unknown_user", StringUtil.sanitizeUsernameForFile(null)); + } + + @Test + public void testEmptyUsername() { + assertEquals("", StringUtil.sanitizeUsernameForFile("")); + } + + @Test + public void testAlreadySanitizedUsername() { + assertEquals("john_doe123", StringUtil.sanitizeUsernameForFile("john_doe123")); + } + + @Test + public void testUsernameWithSpaces() { + assertEquals("john_doe", StringUtil.sanitizeUsernameForFile("john doe")); + } + + @Test + public void testUsernameWithSpecialCharacters() { + assertEquals("john_doe_123", StringUtil.sanitizeUsernameForFile("john.doe@123")); + } + + @Test + public void testUsernameWithMixedCharacters() { + assertEquals("John_Doe_123_user", StringUtil.sanitizeUsernameForFile("John-Doe#123!user")); + } + + @Test + public void testUsernameWithAllInvalidCharacters() { + assertEquals("______", StringUtil.sanitizeUsernameForFile("!@#$%^")); + } } diff --git a/src/test/java/com/databricks/jdbc/dbclient/impl/common/ClientConfiguratorTest.java b/src/test/java/com/databricks/jdbc/dbclient/impl/common/ClientConfiguratorTest.java index 1f2343c3a1..dfe6d2033a 100644 --- a/src/test/java/com/databricks/jdbc/dbclient/impl/common/ClientConfiguratorTest.java +++ b/src/test/java/com/databricks/jdbc/dbclient/impl/common/ClientConfiguratorTest.java @@ -8,10 +8,12 @@ import com.databricks.jdbc.api.IDatabricksConnectionContext; import com.databricks.jdbc.api.impl.DatabricksConnectionContextFactory; +import com.databricks.jdbc.auth.CachingExternalBrowserCredentialsProvider; import com.databricks.jdbc.auth.PrivateKeyClientCredentialProvider; import com.databricks.jdbc.common.AuthFlow; import com.databricks.jdbc.common.AuthMech; import com.databricks.jdbc.common.DatabricksJdbcConstants; +import com.databricks.jdbc.common.util.DatabricksAuthUtil; import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; import com.databricks.sdk.WorkspaceClient; @@ -20,6 +22,7 @@ import com.databricks.sdk.core.DatabricksException; import com.databricks.sdk.core.ProxyConfig; import com.databricks.sdk.core.commons.CommonsHttpClient; +import com.databricks.sdk.core.oauth.ExternalBrowserCredentialsProvider; import com.databricks.sdk.core.utils.Cloud; import java.io.IOException; import java.util.List; @@ -28,6 +31,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockedStatic; import org.mockito.junit.jupiter.MockitoExtension; @ExtendWith(MockitoExtension.class) @@ -209,6 +213,63 @@ void getWorkspaceClient_OAuthWithBrowserBasedAuthentication_AuthenticatesCorrect assertEquals(DatabricksJdbcConstants.U2M_AUTH_TYPE, config.getAuthType()); } + @Test + void getWorkspaceClient_OAuthWithCachingExternalBrowser_NoPassphrase_ThrowsException() { + when(mockContext.getAuthMech()).thenReturn(AuthMech.OAUTH); + when(mockContext.getAuthFlow()).thenReturn(AuthFlow.BROWSER_BASED_AUTHENTICATION); + when(mockContext.getTokenCachePassPhrase()).thenReturn(null); + when(mockContext.getHttpConnectionPoolSize()).thenReturn(100); + when(mockContext.isOAuthDiscoveryModeEnabled()).thenReturn(true); + when(mockContext.isTokenCacheEnabled()).thenReturn(true); + + assertThrows(IllegalArgumentException.class, () -> new ClientConfigurator(mockContext)); + } + + @Test + void getWorkspaceClient_OAuthBrowserAuth_WithTokenCacheEnabled_UsesCachedProvider() + throws DatabricksParsingException { + try (MockedStatic mockedStatic = mockStatic(DatabricksAuthUtil.class)) { + when(mockContext.getAuthMech()).thenReturn(AuthMech.OAUTH); + when(mockContext.getAuthFlow()).thenReturn(AuthFlow.BROWSER_BASED_AUTHENTICATION); + when(mockContext.getHostForOAuth()).thenReturn("https://oauth-browser.databricks.com"); + when(mockContext.getClientId()).thenReturn("client-id"); + when(mockContext.getClientSecret()).thenReturn("client-secret"); + when(mockContext.getTokenCachePassPhrase()).thenReturn("test-passphrase"); + when(mockContext.getHttpConnectionPoolSize()).thenReturn(100); + when(mockContext.isTokenCacheEnabled()).thenReturn(true); + mockedStatic + .when(() -> DatabricksAuthUtil.getTokenEndpoint(any(), any())) + .thenReturn("https://oauth-browser.databricks.com"); + + configurator = new ClientConfigurator(mockContext); + WorkspaceClient client = configurator.getWorkspaceClient(); + DatabricksConfig config = client.config(); + + assertEquals("external-browser-with-cache", config.getAuthType()); + assertInstanceOf( + CachingExternalBrowserCredentialsProvider.class, config.getCredentialsProvider()); + } + } + + @Test + void getWorkspaceClient_OAuthBrowserAuth_WithTokenCacheDisabled_UsesStandardProvider() + throws DatabricksParsingException { + when(mockContext.getAuthMech()).thenReturn(AuthMech.OAUTH); + when(mockContext.getAuthFlow()).thenReturn(AuthFlow.BROWSER_BASED_AUTHENTICATION); + when(mockContext.getHostForOAuth()).thenReturn("https://oauth-browser.databricks.com"); + when(mockContext.getClientId()).thenReturn("client-id"); + when(mockContext.getClientSecret()).thenReturn("client-secret"); + when(mockContext.getHttpConnectionPoolSize()).thenReturn(100); + when(mockContext.isTokenCacheEnabled()).thenReturn(false); + + configurator = new ClientConfigurator(mockContext); + WorkspaceClient client = configurator.getWorkspaceClient(); + DatabricksConfig config = client.config(); + + assertEquals("external-browser", config.getAuthType()); + assertInstanceOf(ExternalBrowserCredentialsProvider.class, config.getCredentialsProvider()); + } + @Test void testNonOauth() { when(mockContext.getAuthMech()).thenReturn(AuthMech.OTHER);