diff --git a/.plans/JDBC_REMOVAL_PLAN.md b/.plans/JDBC_REMOVAL_PLAN.md index c6a4cb46b..c781b4ed6 100644 --- a/.plans/JDBC_REMOVAL_PLAN.md +++ b/.plans/JDBC_REMOVAL_PLAN.md @@ -12,6 +12,8 @@ eliminate the JDBC dependency entirely. **This project is a pure mechanical migration. There must be no functional changes.** +- **Always replicate** — no matter how many dependencies, replicate the JDBC original + verbatim. Never create "simplified replacements" or "minimal viable alternatives." - Copy JDBC classes verbatim into this repo. Compare against the actual source file in the JDBC repo (not decompiled output). Do not refactor, simplify, rename fields, change method signatures, strip comments/Javadoc, or improve logic — even where @@ -278,8 +280,12 @@ is replaced. | Step 9c — Swap telemetry imports | ✅ Open | #1131 | | Step 10a — Replicate SFException, ExecTimeTelemetryData, etc. | ✅ Open | #1132 | | Step 10b — Swap SFException imports | ✅ Open | #1134 | -| Step 10c — Remove SFSession/SFBaseSession | ⬜ TODO | — | -| Step 10d — Demote JDBC to test scope | ⬜ TODO | — | +| Step 10c — Remove SFSession from storage stack | ✅ Open | #1135 | +| Step 10c2 — Remove SFSession from exceptions + telemetry | ✅ Open | #1136 | +| Step 11a — Replace JDBC HTTP calls with HttpRequestHelper | ⬜ TODO | — | +| Step 11b — Remove FQN SnowflakeSQLException from throws | ⬜ TODO | — | +| Step 11c — Clean up remaining FQN JDBC references | ⬜ TODO | — | +| Step 11d — Demote JDBC to test scope | ⬜ TODO | — | **Closed PRs:** #1117 (reverted 7b approach), #1122 (reverted 8c approach) **Other PRs:** #1118 (error/exception tests on master), #1133 (Maven retry config) @@ -616,34 +622,74 @@ NOT swapped — they interact with JDBC's `RestRequest.executeWithRetries()`. --- -### Step 10c — Remove SFSession/SFBaseSession ⬜ TODO +### Step 10c — Remove SFSession from storage stack ✅ Open (PR #1135) -SFSession/SFBaseSession are always null from ingest callers. 
Not feasible -to replicate (1498+1404 lines, 156-class transitive closure = 40K lines). -Need to remove these parameter types. +**Done:** Remove SFSession/SFBaseSession parameters and dead session-based +code from storage clients, interface, strategies, factory, agent, config. +Session was always null from ingest callers. -336 lines. --- -### Step 10d — Demote JDBC to test scope ⬜ TODO +### Step 10c2 — Remove SFSession from exceptions + telemetry ✅ Open (PR #1136) -Remaining 27 JDBC imports after Step 10b (all unreplicable due to massive -dependency chains): -- `SFSession`/`SFBaseSession` (15) — parameter types, always null -- `HttpUtil` (2) — GCS client + TelemetryClient -- `RestRequest` (1) — GCS client -- `SnowflakeConnectionV1` (1) — TelemetryClient session path -- `SnowflakeSQLException` (JDBC's, 1) — TelemetryClient -- `ExecTimeTelemetryData`/`HttpResponseContextDto` (2) — GCS client - (replicated but interact with JDBC RestRequest) -- IB `Telemetry`/`TelemetryField`/`TelemetryUtil` (3) — - interact with session.getTelemetryClient() -- `SFSession` in `SnowflakeSQLLoggedException` (2) — parameter types +**Done:** Remove SFSession/SFBaseSession from SnowflakeSQLLoggedException +(all 15 constructors) and TelemetryClient (session-based code). Remove IB +telemetry dead code. Update all callers (~12 files). -339 lines. -Then: +After 10c2: 6 JDBC imports remain + ~70 FQN JDBC references in throws/params. + +--- + +### Step 11a — Replace JDBC HTTP calls with HttpRequestHelper ⬜ TODO + +Create `HttpRequestHelper` utility with retry logic (replaces JDBC's +`RestRequest.executeWithRetries` and `HttpUtil.executeGeneralRequest`). +Replace the 6 remaining JDBC imports: + +- `TelemetryClient`: replace `HttpUtil.executeGeneralRequest()` + + `SnowflakeSQLException` catch +- `SnowflakeGCSClient`: replace `HttpUtil.getHttpClientWithoutDecompression()`, + `HttpUtil.getHttpClient()`, `RestRequest.executeWithRetries()`, + `HttpUtil.getSocketTimeout()`. 
Remove `ExecTimeTelemetryData`, + `HttpResponseContextDto`, `RestRequest` imports. + +--- + +### Step 11b — Remove FQN SnowflakeSQLException from throws ⬜ TODO + +Mechanical removal of `, net.snowflake.client.jdbc.SnowflakeSQLException` +from ~47 throws clauses across all replicated storage clients, interface, +strategies, factory, and GCS client. + +--- + +### Step 11c — Clean up remaining FQN JDBC references ⬜ TODO + +Swap remaining FQN JDBC type references to ingest versions: +- `net.snowflake.client.core.HttpClientSettingsKey` → `HttpClientSettingsKey` + (same package, 8 occurrences in S3HttpUtil, SnowflakeFileTransferAgent, + SnowflakeGCSClient, SnowflakeStorageClient) +- `net.snowflake.client.core.HttpProtocol` → `HttpProtocol` (same package, + 1 occurrence in S3HttpUtil) +- `net.snowflake.client.core.OCSPMode` → `net.snowflake.ingest.utils.OCSPMode` + (2 occurrences in SnowflakeFileTransferAgent) +- `net.snowflake.client.jdbc.SnowflakeUtil.convertProxyPropertiesToHttpClientKey` + → `StorageClientUtil.convertProxyPropertiesToHttpClientKey` (2 occurrences + in SnowflakeFileTransferAgent) +- `import static net.snowflake.client.core.HttpUtil.setSessionlessProxyForAzure` + → replicate method in StorageClientUtil (4 occurrences in SnowflakeAzureClient) +- `net.snowflake.client.jdbc.cloud.storage.AwsSdkGCPSigner` — JDBC class + reference in GCSAccessStrategyAwsSdk (string constant + class reference, + 3 occurrences) + +--- + +### Step 11d — Demote JDBC to test scope ⬜ TODO + +After all FQN references are cleaned up: 1. Demote `snowflake-jdbc-thin` to `test` scope in `pom.xml` 2. Remove JDBC shade relocation rules from Maven Shade plugin 3.
Run full test suite --- diff --git a/pom.xml b/pom.xml index 5341d3187..1d39ef1bb 100644 --- a/pom.xml +++ b/pom.xml @@ -887,11 +887,6 @@ net.minidev json-smart - - - net.snowflake - snowflake-jdbc-thin - org.apache.commons commons-lang3 @@ -1017,6 +1012,12 @@ junit test + + + net.snowflake + snowflake-jdbc-thin + test + org.assertj assertj-core @@ -1514,14 +1515,7 @@ com.nimbusds ${shadeBase}.com.nimbusds - - net.snowflake.client - ${shadeBase}.net.snowflake.client - - - com.snowflake.client.jdbc - ${shadeBase}.com.snowflake.client.jdbc - + org.bouncycastle ${shadeBase}.org.bouncycastle diff --git a/src/main/java/net/snowflake/ingest/connection/telemetry/TelemetryClient.java b/src/main/java/net/snowflake/ingest/connection/telemetry/TelemetryClient.java index 1c88d210f..070669923 100644 --- a/src/main/java/net/snowflake/ingest/connection/telemetry/TelemetryClient.java +++ b/src/main/java/net/snowflake/ingest/connection/telemetry/TelemetryClient.java @@ -11,9 +11,9 @@ import java.util.LinkedList; import java.util.Objects; import java.util.concurrent.Future; -import net.snowflake.client.core.HttpUtil; -import net.snowflake.client.jdbc.SnowflakeSQLException; +import net.snowflake.ingest.streaming.internal.fileTransferAgent.JdbcHttpUtil; import net.snowflake.ingest.streaming.internal.fileTransferAgent.ObjectMapperFactory; +import net.snowflake.ingest.streaming.internal.fileTransferAgent.SnowflakeSQLException; import net.snowflake.ingest.streaming.internal.fileTransferAgent.TelemetryThreadPool; import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.SFLogger; import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.SFLoggerFactory; @@ -275,11 +275,11 @@ private boolean sendBatch() throws IOException { try { response = - HttpUtil.executeGeneralRequest( + JdbcHttpUtil.executeGeneralRequest( post, TELEMETRY_HTTP_RETRY_TIMEOUT_IN_SEC, 0, - (int) HttpUtil.getSocketTimeout().toMillis(), + (int) JdbcHttpUtil.getSocketTimeout().toMillis(), 0, 
this.httpClient); stopwatch.stop(); diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/AttributeEnhancingHttpRequestRetryHandler.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/AttributeEnhancingHttpRequestRetryHandler.java new file mode 100644 index 000000000..a73b2d4c8 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/AttributeEnhancingHttpRequestRetryHandler.java @@ -0,0 +1,32 @@ +/* + * Replicated from snowflake-jdbc (v3.25.1) + * Source: https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/core/AttributeEnhancingHttpRequestRetryHandler.java + * + * Permitted differences: package declaration. + */ +package net.snowflake.ingest.streaming.internal.fileTransferAgent; + +import java.io.IOException; +import org.apache.http.impl.client.DefaultHttpRequestRetryHandler; +import org.apache.http.protocol.HttpContext; + +/** + * Extends {@link DefaultHttpRequestRetryHandler} to store the current execution count (attempt + * number) in the {@link HttpContext}. This allows interceptors to identify retry attempts. + * + *

The execution count is stored using the key defined by {@link #EXECUTION_COUNT_ATTRIBUTE}. + */ +class AttributeEnhancingHttpRequestRetryHandler extends DefaultHttpRequestRetryHandler { + /** + * The key used to store the current execution count (attempt number) in the {@link HttpContext}. + * Interceptors can use this key to retrieve the count. The value stored will be an {@link + * Integer}. + */ + static final String EXECUTION_COUNT_ATTRIBUTE = "net.snowflake.client.core.execution-count"; + + @Override + public boolean retryRequest(IOException exception, int executionCount, HttpContext context) { + context.setAttribute(EXECUTION_COUNT_ATTRIBUTE, executionCount); + return super.retryRequest(exception, executionCount, context); + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/AwsSdkGCPSigner.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/AwsSdkGCPSigner.java new file mode 100644 index 000000000..ddaef10d0 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/AwsSdkGCPSigner.java @@ -0,0 +1,56 @@ +/* + * Replicated from snowflake-jdbc (v3.25.1) + * Source: https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/jdbc/cloud/storage/AwsSdkGCPSigner.java + * + * Permitted differences: package. @SnowflakeJdbcInternalApi removed. 
+ */ +package net.snowflake.ingest.streaming.internal.fileTransferAgent; + +import com.amazonaws.SignableRequest; +import com.amazonaws.auth.AWS4Signer; +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.http.HttpMethodName; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Collectors; + +public class AwsSdkGCPSigner extends AWS4Signer { + private static final Map headerMap = + new HashMap() { + { + put("x-amz-storage-class", "x-goog-storage-class"); + put("x-amz-acl", "x-goog-acl"); + put("x-amz-date", "x-goog-date"); + put("x-amz-copy-source", "x-goog-copy-source"); + put("x-amz-metadata-directive", "x-goog-metadata-directive"); + put("x-amz-copy-source-if-match", "x-goog-copy-source-if-match"); + put("x-amz-copy-source-if-none-match", "x-goog-copy-source-if-none-match"); + put("x-amz-copy-source-if-unmodified-since", "x-goog-copy-source-if-unmodified-since"); + put("x-amz-copy-source-if-modified-since", "x-goog-copy-source-if-modified-since"); + } + }; + + @Override + public void sign(SignableRequest request, AWSCredentials credentials) { + if (credentials.getAWSAccessKeyId() != null && !"".equals(credentials.getAWSAccessKeyId())) { + request.addHeader("Authorization", "Bearer " + credentials.getAWSAccessKeyId()); + } + + if (request.getHttpMethod() == HttpMethodName.GET) { + request.addHeader("Accept-Encoding", "gzip,deflate"); + } + + Map headerCopy = + request.getHeaders().entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + for (Map.Entry entry : headerCopy.entrySet()) { + String entryKey = entry.getKey().toLowerCase(); + if (headerMap.containsKey(entryKey)) { + request.addHeader(headerMap.get(entryKey), entry.getValue()); + } else if (entryKey.startsWith("x-amz-meta-")) { + request.addHeader(entryKey.replace("x-amz-meta-", "x-goog-meta-"), entry.getValue()); + } + } + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/Constants.java 
b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/Constants.java new file mode 100644 index 000000000..80152dfa9 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/Constants.java @@ -0,0 +1,66 @@ +/* + * Replicated from snowflake-jdbc (v3.25.1) + * Source: https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/core/Constants.java + * + * Permitted differences: package declaration, + * SnowflakeUtil.systemGetProperty -> StorageClientUtil.systemGetProperty. + */ +package net.snowflake.ingest.streaming.internal.fileTransferAgent; + +import static net.snowflake.ingest.streaming.internal.fileTransferAgent.StorageClientUtil.systemGetProperty; + +/* + * Constants used in JDBC implementation + */ +public final class Constants { + // Session expired error code as returned from Snowflake + public static final int SESSION_EXPIRED_GS_CODE = 390112; + + // Cloud storage credentials expired error code + public static final int CLOUD_STORAGE_CREDENTIALS_EXPIRED = 240001; + + // Session gone error code as returned from Snowflake + public static final int SESSION_GONE = 390111; + + // Error code for all invalid id token cases during login request + public static final int ID_TOKEN_INVALID_LOGIN_REQUEST_GS_CODE = 390195; + + public static final int OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE = 390318; + + public static final int OAUTH_ACCESS_TOKEN_INVALID_GS_CODE = 390303; + + // Error message for IOException when no space is left for GET + public static final String NO_SPACE_LEFT_ON_DEVICE_ERR = "No space left on device"; + + public enum OS { + WINDOWS, + LINUX, + MAC, + SOLARIS + } + + private static OS os = null; + + public static synchronized OS getOS() { + if (os == null) { + String operSys = systemGetProperty("os.name").toLowerCase(); + if (operSys.contains("win")) { + os = OS.WINDOWS; + } else if (operSys.contains("nix") || operSys.contains("nux") || operSys.contains("aix")) { + os 
= OS.LINUX; + } else if (operSys.contains("mac")) { + os = OS.MAC; + } else if (operSys.contains("sunos")) { + os = OS.SOLARIS; + } + } + return os; + } + + public static void clearOSForTesting() { + os = null; + } + + public static final int MB = 1024 * 1024; + public static final long GB = 1024 * 1024 * 1024; +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/DecorrelatedJitterBackoff.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/DecorrelatedJitterBackoff.java new file mode 100644 index 000000000..5aeaaa287 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/DecorrelatedJitterBackoff.java @@ -0,0 +1,39 @@ +/* + * Replicated from snowflake-jdbc (v3.25.1) + * Source: https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/util/DecorrelatedJitterBackoff.java + * + * Permitted differences: package. + */ +package net.snowflake.ingest.streaming.internal.fileTransferAgent; + +import java.util.concurrent.ThreadLocalRandom; + +/** + * Decorrelated Jitter backoff + * + *

https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ + */ +public class DecorrelatedJitterBackoff { + private final long base; + private final long cap; + + public DecorrelatedJitterBackoff(long base, long cap) { + this.base = base; + this.cap = cap; + } + + public long nextSleepTime(long sleep) { + long correctedSleep = sleep <= base ? base + 1 : sleep; + return Math.min(cap, ThreadLocalRandom.current().nextLong(base, correctedSleep)); + } + + public long getJitterForLogin(long currentTime) { + double multiplicationFactor = chooseRandom(-1, 1); + long jitter = (long) (multiplicationFactor * currentTime * 0.5); + return jitter; + } + + public double chooseRandom(double min, double max) { + return min + (Math.random() * (max - min)); + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/FileCacheManager.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/FileCacheManager.java new file mode 100644 index 000000000..92b9bdbb4 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/FileCacheManager.java @@ -0,0 +1,419 @@ +/* + * Replicated from snowflake-jdbc (v3.25.1) + * Source: https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/core/FileCacheManager.java + * + * Permitted differences: package declaration, + * net.snowflake.client.log.* -> ...fileTransferAgent.log.*, + * SnowflakeUtil.systemGetEnv/systemGetProperty -> StorageClientUtil.*, + * SnowflakeUtil.isWindows -> StorageClientUtil.isWindows, + * FileUtil/Constants/ObjectMapperFactory referenced from same package. 
+ */ +package net.snowflake.ingest.streaming.internal.fileTransferAgent; + +import static net.snowflake.ingest.streaming.internal.fileTransferAgent.FileUtil.isWritable; +import static net.snowflake.ingest.streaming.internal.fileTransferAgent.StorageClientUtil.isWindows; +import static net.snowflake.ingest.streaming.internal.fileTransferAgent.StorageClientUtil.systemGetEnv; +import static net.snowflake.ingest.streaming.internal.fileTransferAgent.StorageClientUtil.systemGetProperty; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.io.Reader; +import java.io.Writer; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.attribute.BasicFileAttributes; +import java.nio.file.attribute.PosixFilePermission; +import java.nio.file.attribute.PosixFilePermissions; +import java.util.Date; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.SFLogger; +import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.SFLoggerFactory; + +class FileCacheManager { + private static final SFLogger logger = SFLoggerFactory.getLogger(FileCacheManager.class); + + /** Object mapper for JSON encoding and decoding */ + private static final ObjectMapper OBJECT_MAPPER = ObjectMapperFactory.getObjectMapper(); + + private static final Charset DEFAULT_FILE_ENCODING = StandardCharsets.UTF_8; + + private String cacheDirectorySystemProperty; + private String cacheDirectoryEnvironmentVariable; + private String baseCacheFileName; + private long cacheFileLockExpirationInMilliseconds; + + private 
File cacheFile; + private File cacheLockFile; + + private File cacheDir; + + private boolean onlyOwnerPermissions = true; + + private FileCacheManager() {} + + static FileCacheManager builder() { + return new FileCacheManager(); + } + + FileCacheManager setCacheDirectorySystemProperty(String cacheDirectorySystemProperty) { + this.cacheDirectorySystemProperty = cacheDirectorySystemProperty; + return this; + } + + FileCacheManager setCacheDirectoryEnvironmentVariable(String cacheDirectoryEnvironmentVariable) { + this.cacheDirectoryEnvironmentVariable = cacheDirectoryEnvironmentVariable; + return this; + } + + FileCacheManager setBaseCacheFileName(String baseCacheFileName) { + this.baseCacheFileName = baseCacheFileName; + return this; + } + + FileCacheManager setCacheFileLockExpirationInSeconds(long cacheFileLockExpirationInSeconds) { + this.cacheFileLockExpirationInMilliseconds = cacheFileLockExpirationInSeconds * 1000; + return this; + } + + FileCacheManager setOnlyOwnerPermissions(boolean onlyOwnerPermissions) { + this.onlyOwnerPermissions = onlyOwnerPermissions; + return this; + } + + synchronized String getCacheFilePath() { + return cacheFile.getAbsolutePath(); + } + + /** + * Override the cache file. + * + * @param newCacheFile a file object to override the default one. + */ + synchronized void overrideCacheFile(File newCacheFile) { + if (!FileUtil.exists(newCacheFile)) { + logger.debug("Cache file doesn't exist. 
File: {}", newCacheFile); + } + if (onlyOwnerPermissions) { + FileUtil.handleWhenFilePermissionsWiderThanUserOnly(newCacheFile, "Override cache file"); + FileUtil.handleWhenParentDirectoryPermissionsWiderThanUserOnly( + newCacheFile, "Override cache file"); + } else { + FileUtil.logFileUsage(cacheFile, "Override cache file", false); + } + this.cacheFile = newCacheFile; + this.cacheDir = newCacheFile.getParentFile(); + this.baseCacheFileName = newCacheFile.getName(); + } + + synchronized FileCacheManager build() { + // try to get cacheDir from system property or environment variable + String cacheDirPath = + this.cacheDirectorySystemProperty != null + ? systemGetProperty(this.cacheDirectorySystemProperty) + : null; + if (cacheDirPath == null) { + try { + cacheDirPath = + this.cacheDirectoryEnvironmentVariable != null + ? systemGetEnv(this.cacheDirectoryEnvironmentVariable) + : null; + } catch (Throwable ex) { + logger.debug( + "Cannot get environment variable for cache directory, skip using cache", false); + // In Boomi cloud, System.getenv is not allowed due to policy, + // so we catch the exception and skip cache completely + return this; + } + } + + if (cacheDirPath != null) { + this.cacheDir = new File(cacheDirPath); + } else { + this.cacheDir = getDefaultCacheDir(); + } + if (cacheDir == null) { + return this; + } + if (!cacheDir.exists()) { + try { + if (!isWindows() && onlyOwnerPermissions) { + Files.createDirectories( + cacheDir.toPath(), + PosixFilePermissions.asFileAttribute( + Stream.of( + PosixFilePermission.OWNER_READ, + PosixFilePermission.OWNER_WRITE, + PosixFilePermission.OWNER_EXECUTE) + .collect(Collectors.toSet()))); + } else { + Files.createDirectories(cacheDir.toPath()); + } + } catch (IOException e) { + logger.info( + "Failed to create the cache directory: {}. Ignored. {}", + e.getMessage(), + cacheDir.getAbsoluteFile()); + return this; + } + } + if (!this.cacheDir.exists()) { + logger.debug( + "Cannot create the cache directory {}. 
Giving up.", this.cacheDir.getAbsolutePath()); + return this; + } + logger.debug("Verified Directory {}", this.cacheDir.getAbsolutePath()); + + File cacheFileTmp = new File(this.cacheDir, this.baseCacheFileName).getAbsoluteFile(); + try { + // create an empty file if not exists and return true. + // If exists. the method returns false. + // In this particular case, it doesn't matter as long as the file is + // writable. + if (!cacheFileTmp.exists()) { + if (!isWindows() && onlyOwnerPermissions) { + Files.createFile( + cacheFileTmp.toPath(), + PosixFilePermissions.asFileAttribute( + Stream.of(PosixFilePermission.OWNER_READ, PosixFilePermission.OWNER_WRITE) + .collect(Collectors.toSet()))); + } else { + Files.createFile(cacheFileTmp.toPath()); + } + logger.debug("Successfully created a cache file {}", cacheFileTmp); + } else { + logger.debug("Cache file already exists {}", cacheFileTmp); + } + FileUtil.logFileUsage(cacheFileTmp, "Cache file creation", false); + this.cacheFile = cacheFileTmp.getCanonicalFile(); + this.cacheLockFile = + new File(this.cacheFile.getParentFile(), this.baseCacheFileName + ".lck"); + } catch (IOException | SecurityException ex) { + logger.info( + "Failed to touch the cache file: {}. Ignored. {}", + ex.getMessage(), + cacheFileTmp.getAbsoluteFile()); + } + return this; + } + + static File getDefaultCacheDir() { + if (Constants.getOS() == Constants.OS.LINUX) { + String xdgCacheHome = getXdgCacheHome(); + if (xdgCacheHome != null) { + return new File(xdgCacheHome, "snowflake"); + } + } + + String homeDir = getHomeDirProperty(); + if (homeDir == null) { + // if still home directory is null, no cache dir is set. 
+ return null; + } + if (Constants.getOS() == Constants.OS.WINDOWS) { + return new File( + new File(new File(new File(homeDir, "AppData"), "Local"), "Snowflake"), "Caches"); + } else if (Constants.getOS() == Constants.OS.MAC) { + return new File(new File(new File(homeDir, "Library"), "Caches"), "Snowflake"); + } else { + return new File(new File(homeDir, ".cache"), "snowflake"); + } + } + + private static String getXdgCacheHome() { + String xdgCacheHome = systemGetEnv("XDG_CACHE_HOME"); + if (xdgCacheHome != null && isWritable(xdgCacheHome)) { + return xdgCacheHome; + } + return null; + } + + private static String getHomeDirProperty() { + String homeDir = systemGetProperty("user.home"); + if (homeDir != null && isWritable(homeDir)) { + return homeDir; + } + return null; + } + + synchronized T withLock(Supplier supplier) { + if (cacheFile == null) { + logger.error("No cache file assigned", false); + return null; + } + if (cacheLockFile == null) { + logger.error("No cache lock file assigned", false); + return null; + } else if (cacheLockFile.exists()) { + deleteCacheLockIfExpired(); + } + + if (!tryToLockCacheFile()) { + logger.debug("Failed to lock the file. Skipping cache operation", false); + return null; + } + try { + return supplier.get(); + } finally { + if (!unlockCacheFile()) { + logger.debug("Failed to unlock cache file", false); + } + } + } + + /** Reads the cache file. */ + synchronized JsonNode readCacheFile() { + try { + if (!FileUtil.exists(cacheFile)) { + logger.debug("Cache file doesn't exist. Ignoring read. 
File: {}", cacheFile); + return null; + } + + try (Reader reader = + new InputStreamReader(new FileInputStream(cacheFile), DEFAULT_FILE_ENCODING)) { + + if (onlyOwnerPermissions) { + FileUtil.handleWhenFilePermissionsWiderThanUserOnly(cacheFile, "Read cache"); + FileUtil.handleWhenParentDirectoryPermissionsWiderThanUserOnly(cacheFile, "Read cache"); + FileUtil.throwWhenOwnerDifferentThanCurrentUser(cacheFile, "Read cache"); + } else { + FileUtil.logFileUsage(cacheFile, "Read cache", false); + } + return OBJECT_MAPPER.readTree(reader); + } + } catch (IOException ex) { + logger.debug("Failed to read the cache file. No worry. File: {}, Err: {}", cacheFile, ex); + } + return null; + } + + synchronized void writeCacheFile(JsonNode input) { + logger.debug("Writing cache file. File: {}", cacheFile); + try { + if (input == null || !FileUtil.exists(cacheFile)) { + logger.debug( + "Cache file doesn't exist or input is null. Ignoring write. File: {}", cacheFile); + return; + } + try (Writer writer = + new OutputStreamWriter(new FileOutputStream(cacheFile), DEFAULT_FILE_ENCODING)) { + if (onlyOwnerPermissions) { + FileUtil.handleWhenFilePermissionsWiderThanUserOnly(cacheFile, "Write to cache"); + FileUtil.handleWhenParentDirectoryPermissionsWiderThanUserOnly( + cacheFile, "Write to cache"); + } else { + FileUtil.logFileUsage(cacheFile, "Write to cache", false); + } + writer.write(input.toString()); + } + } catch (IOException ex) { + logger.debug("Failed to write the cache file. File: {}", cacheFile); + } + } + + synchronized void deleteCacheFile() { + logger.debug("Deleting cache file. 
File: {}, lock file: {}", cacheFile, cacheLockFile); + + if (cacheFile == null) { + return; + } + + unlockCacheFile(); + if (!cacheFile.delete()) { + logger.debug("Failed to delete the file: {}", cacheFile); + } + } + + /** + * Tries to lock the cache file + * + * @return true if success or false + */ + private synchronized boolean tryToLockCacheFile() { + int cnt = 0; + boolean locked = false; + while (cnt < 5 && !(locked = lockCacheFile())) { + try { + Thread.sleep(10); + } catch (InterruptedException ex) { + // doesn't matter + } + ++cnt; + } + if (!locked) { + deleteCacheLockIfExpired(); + if (!lockCacheFile()) { + logger.debug("Failed to lock the cache file.", false); + } + } + return locked; + } + + private synchronized void deleteCacheLockIfExpired() { + long currentTime = new Date().getTime(); + long lockFileTs = fileCreationTime(cacheLockFile); + if (lockFileTs < 0) { + logger.debug("Failed to get the timestamp of lock directory"); + } else if (lockFileTs < currentTime - this.cacheFileLockExpirationInMilliseconds) { + // old lock file + try { + if (!cacheLockFile.delete()) { + logger.debug("Failed to delete the directory. Dir: {}", cacheLockFile); + } else { + logger.debug("Deleted expired cache lock directory.", false); + } + } catch (Exception e) { + logger.debug( + "Failed to delete the directory. Dir: {}, Error: {}", cacheLockFile, e.getMessage()); + } + } + } + + /** + * Gets file/dir creation time in epoch (ms) + * + * @return epoch time in ms + */ + private static synchronized long fileCreationTime(File targetFile) { + if (!FileUtil.exists(targetFile)) { + logger.debug("File does not exist. File: {}", targetFile); + return -1; + } + try { + Path cacheFileLockPath = Paths.get(targetFile.getAbsolutePath()); + BasicFileAttributes attr = Files.readAttributes(cacheFileLockPath, BasicFileAttributes.class); + return attr.creationTime().toMillis(); + } catch (IOException ex) { + logger.debug("Failed to get creation time. 
File/Dir: {}, Err: {}", targetFile, ex); + } + return -1; + } + + /** + * Lock cache file by creating a lock directory + * + * @return true if success or false + */ + private synchronized boolean lockCacheFile() { + return cacheLockFile.mkdirs(); + } + + /** + * Unlock cache file by deleting a lock directory + * + * @return true if success or false + */ + private synchronized boolean unlockCacheFile() { + return cacheLockFile.delete(); + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/FileUtil.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/FileUtil.java index b0f17d1f9..f441db0c0 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/FileUtil.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/FileUtil.java @@ -14,6 +14,8 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.attribute.FileOwnerAttributeView; import java.nio.file.attribute.PosixFilePermission; import java.util.Arrays; import java.util.Collection; @@ -84,6 +86,127 @@ private static String getContextStr(String context) { return StorageClientUtil.isNullOrEmpty(context) ? 
"" : context + ": "; } + public static void logFileUsage(String stringPath, String context, boolean logReadAccess) { + Path path = Paths.get(stringPath); + logFileUsage(path, context, logReadAccess); + } + + public static boolean isWritable(String path) { + File file = new File(path); + if (!file.canWrite()) { + logger.debug("File/directory not writeable: {}", path); + return false; + } + return true; + } + + public static void handleWhenParentDirectoryPermissionsWiderThanUserOnly( + File file, String context) { + handleWhenDirectoryPermissionsWiderThanUserOnly(file.getParentFile(), context); + } + + public static void handleWhenFilePermissionsWiderThanUserOnly(File file, String context) { + if (Files.isSymbolicLink(file.toPath())) { + throw new SecurityException("Symbolic link is not allowed for file cache: " + file); + } + handleWhenPermissionsWiderThanUserOnly(file.toPath(), context, false); + } + + public static void handleWhenDirectoryPermissionsWiderThanUserOnly(File file, String context) { + handleWhenPermissionsWiderThanUserOnly(file.toPath(), context, true); + } + + public static void handleWhenPermissionsWiderThanUserOnly( + Path filePath, String context, boolean isDirectory) { + // we do not check the permissions for Windows + if (isWindows()) { + return; + } + + try { + Collection filePermissions = Files.getPosixFilePermissions(filePath); + boolean isWritableByOthers = isPermPresent(filePermissions, WRITE_BY_OTHERS); + boolean isReadableByOthers = isPermPresent(filePermissions, READ_BY_OTHERS); + boolean isExecutable = isPermPresent(filePermissions, EXECUTABLE); + + boolean permissionsTooOpen; + if (isDirectory) { + permissionsTooOpen = isWritableByOthers || isReadableByOthers; + } else { + permissionsTooOpen = isWritableByOthers || isReadableByOthers || isExecutable; + } + if (permissionsTooOpen) { + logger.debug( + "{}File/directory {} access rights: {}", + getContextStr(context), + filePath, + filePermissions); + String message = + String.format( + 
"Access to file or directory %s is wider than allowed. Remove cache file/directory" + + " and re-run the driver.", + filePath); + if (isDirectory) { + logger.warn(message); + } else { + throw new SecurityException(message); + } + } else { + if (!isDirectory && Files.isSymbolicLink(filePath)) { + throw new SecurityException("Symbolic link is not allowed for file cache: " + filePath); + } + } + } catch (IOException e) { + String message = + String.format( + "%s Unable to access the file/directory to check the permissions. Error: %s", + filePath, e); + if (isDirectory) { + logger.warn(message); + } else { + throw new SecurityException(message); + } + } + } + + public static void throwWhenOwnerDifferentThanCurrentUser(File file, String context) { + // we do not check the permissions for Windows + if (isWindows()) { + return; + } + + Path filePath = Paths.get(file.getPath()); + + try { + String fileOwnerName = getFileOwnerName(filePath); + String currentUser = System.getProperty("user.name"); + if (!currentUser.equalsIgnoreCase(fileOwnerName)) { + logger.debug( + "The file owner: {} is different than current user: {}", fileOwnerName, currentUser); + throw new SecurityException("The file owner is different than current user"); + } + } catch (IOException e) { + logger.warn( + "{}Unable to access the file to check the owner: {}. 
Error: {}", + getContextStr(context), + filePath, + e); + } + } + + static String getFileOwnerName(Path filePath) throws IOException { + FileOwnerAttributeView ownerAttributeView = + Files.getFileAttributeView(filePath, FileOwnerAttributeView.class); + return ownerAttributeView.getOwner().getName(); + } + + public static boolean exists(File file) { + if (file == null) { + return false; + } + return file.exists(); + } + private static boolean isWindows() { return StorageClientUtil.isWindows(); } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/GCSAccessStrategy.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/GCSAccessStrategy.java index 4ce9b261f..d3de304fb 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/GCSAccessStrategy.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/GCSAccessStrategy.java @@ -42,7 +42,7 @@ boolean handleStorageException( String command, String queryId, SnowflakeGCSClient gcsClient) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException; + throws SnowflakeSQLException; void shutdown(); } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/GCSAccessStrategyAwsSdk.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/GCSAccessStrategyAwsSdk.java index 15c4b7c3b..2357a6591 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/GCSAccessStrategyAwsSdk.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/GCSAccessStrategyAwsSdk.java @@ -39,8 +39,7 @@ class GCSAccessStrategyAwsSdk implements GCSAccessStrategy { private static final SFLogger logger = SFLoggerFactory.getLogger(GCSAccessStrategyAwsSdk.class); private final AmazonS3 amazonClient; - GCSAccessStrategyAwsSdk(StageInfo stage) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + 
GCSAccessStrategyAwsSdk(StageInfo stage) throws SnowflakeSQLException { String accessToken = (String) stage.getCredentials().get("GCS_ACCESS_TOKEN"); Optional oEndpoint = stage.gcsCustomEndpoint(); @@ -64,9 +63,10 @@ class GCSAccessStrategyAwsSdk implements GCSAccessStrategy { ClientConfiguration clientConfig = new ClientConfiguration(); SignerFactory.registerSigner( - "net.snowflake.client.jdbc.cloud.storage.AwsSdkGCPSigner", - net.snowflake.client.jdbc.cloud.storage.AwsSdkGCPSigner.class); - clientConfig.setSignerOverride("net.snowflake.client.jdbc.cloud.storage.AwsSdkGCPSigner"); + "net.snowflake.ingest.streaming.internal.fileTransferAgent.AwsSdkGCPSigner", + net.snowflake.ingest.streaming.internal.fileTransferAgent.AwsSdkGCPSigner.class); + clientConfig.setSignerOverride( + "net.snowflake.ingest.streaming.internal.fileTransferAgent.AwsSdkGCPSigner"); clientConfig .getApacheHttpClientConfig() @@ -223,7 +223,7 @@ public boolean handleStorageException( String command, String queryId, SnowflakeGCSClient gcsClient) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { if (ex instanceof AmazonClientException) { logger.debug("GCSAccessStrategyAwsSdk: " + ex.getMessage()); diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/GCSDefaultAccessStrategy.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/GCSDefaultAccessStrategy.java index c710a1ead..2e08910b6 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/GCSDefaultAccessStrategy.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/GCSDefaultAccessStrategy.java @@ -178,7 +178,7 @@ public boolean handleStorageException( String command, String queryId, SnowflakeGCSClient gcsClient) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { if (ex instanceof StorageException) { // 
NOTE: this code path only handle Access token based operation, // presigned URL is not covered. Presigned Url do not raise diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/HexUtil.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/HexUtil.java new file mode 100644 index 000000000..301541e49 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/HexUtil.java @@ -0,0 +1,27 @@ +/* + * Replicated from snowflake-jdbc (v3.25.1) + * Source: https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/core/HexUtil.java + * + * Permitted differences: package declaration. + */ +package net.snowflake.ingest.streaming.internal.fileTransferAgent; + +class HexUtil { + + /** + * Converts Byte array to hex string + * + * @param bytes a byte array + * @return a string in hexadecimal code + */ + static String byteToHexString(byte[] bytes) { + final char[] hexArray = "0123456789ABCDEF".toCharArray(); + char[] hexChars = new char[bytes.length * 2]; + for (int j = 0; j < bytes.length; j++) { + int v = bytes[j] & 0xFF; + hexChars[j * 2] = hexArray[v >>> 4]; + hexChars[j * 2 + 1] = hexArray[v & 0x0F]; + } + return new String(hexChars); + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/HttpClientSettingsKey.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/HttpClientSettingsKey.java index 58dd883bc..605c6474d 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/HttpClientSettingsKey.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/HttpClientSettingsKey.java @@ -1,5 +1,5 @@ /* - * Replicated from snowflake-jdbc: net.snowflake.client.core.HttpClientSettingsKey + * Replicated from snowflake-jdbc: HttpClientSettingsKey * Tag: v3.25.1 * Source: 
https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/core/HttpClientSettingsKey.java * diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/HttpExecutingContext.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/HttpExecutingContext.java new file mode 100644 index 000000000..8a9c07cb6 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/HttpExecutingContext.java @@ -0,0 +1,311 @@ +/* + * Replicated from snowflake-jdbc (v3.25.1) + * Source: https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/core/HttpExecutingContext.java + * + * Permitted differences: package declaration, import swaps for already-replicated classes, + * @SnowflakeJdbcInternalApi annotation removed. + */ +package net.snowflake.ingest.streaming.internal.fileTransferAgent; + +import java.util.concurrent.atomic.AtomicBoolean; + +public class HttpExecutingContext { + + // min backoff in milli before we retry due to transient issues + private static final long minBackoffMillis = 1000; + + // max backoff in milli before we retry due to transient issues + // we double the backoff after each retry till we reach the max backoff + private static final long maxBackoffMillis = 16000; + + // retry at least once even if timeout limit has been reached + private static final int MIN_RETRY_COUNT = 1; + + // retry at least once even if timeout limit has been reached + private static final int DEFAULT_RETRY_TIMEOUT = 300; + + private final String requestId; + private final String requestInfoScrubbed; + private final long startTime; + // start time for each request, + // used for keeping track how much time we have spent + // due to network issues so that we can compare against the user + // specified network timeout to make sure we do not retry infinitely + // when there are transient network/GS issues + private long startTimePerRequest; + 
// Used to indicate that this is a login/auth request and will be using the new retry strategy. + private boolean isLoginRequest; + // Tracks the total time spent handling transient network issues and retries during HTTP requests + private long elapsedMilliForTransientIssues; + private long retryTimeout; + private long authTimeout; + private DecorrelatedJitterBackoff backoff; + private long backoffInMillis; + private int origSocketTimeout; + private String breakRetryReason; + private String breakRetryEventName; + private String lastStatusCodeForRetry; + private int retryCount; + private int maxRetries; + private boolean noRetry; + private int injectSocketTimeout; + private boolean retryHTTP403; + private boolean shouldRetry; + private boolean skipRetriesBecauseOf200; // todo create skip retry reason enum + private boolean withoutCookies; + private boolean includeRetryParameters; + private boolean includeRequestGuid; + private boolean unpackResponse; + private AtomicBoolean canceling; + + public HttpExecutingContext(String requestIdStr, String requestInfoScrubbed) { + this.requestId = requestIdStr; + this.requestInfoScrubbed = requestInfoScrubbed; + this.startTime = System.currentTimeMillis(); + this.startTimePerRequest = startTime; + this.backoff = new DecorrelatedJitterBackoff(getMinBackoffInMillis(), getMaxBackoffInMilli()); + this.backoffInMillis = minBackoffMillis; + } + + public String getRequestId() { + return requestId; + } + + public long getStartTime() { + return startTime; + } + + public long getStartTimePerRequest() { + return startTimePerRequest; + } + + public void setStartTimePerRequest(long startTimePerRequest) { + this.startTimePerRequest = startTimePerRequest; + } + + public boolean isLoginRequest() { + return isLoginRequest; + } + + public void setLoginRequest(boolean loginRequest) { + isLoginRequest = loginRequest; + } + + public long getElapsedMilliForTransientIssues() { + return elapsedMilliForTransientIssues; + } + + public long 
getRetryTimeoutInMilliseconds() { + return retryTimeout * 1000; + } + + public long getRetryTimeout() { + return retryTimeout; + } + + public void setRetryTimeout(long retryTimeout) { + this.retryTimeout = retryTimeout; + } + + public long getMinBackoffInMillis() { + return minBackoffMillis; + } + + public long getBackoffInMillis() { + return backoffInMillis; + } + + public void setBackoffInMillis(long backoffInMillis) { + this.backoffInMillis = backoffInMillis; + } + + public long getMaxBackoffInMilli() { + return maxBackoffMillis; + } + + public long getAuthTimeout() { + return authTimeout; + } + + public long getAuthTimeoutInMilliseconds() { + return authTimeout * 1000; + } + + public void setAuthTimeout(long authTimeout) { + this.authTimeout = authTimeout; + } + + public DecorrelatedJitterBackoff getBackoff() { + return backoff; + } + + public void setBackoff(DecorrelatedJitterBackoff backoff) { + this.backoff = backoff; + } + + public int getOrigSocketTimeout() { + return origSocketTimeout; + } + + public void setOrigSocketTimeout(int origSocketTimeout) { + this.origSocketTimeout = origSocketTimeout; + } + + public String getBreakRetryReason() { + return breakRetryReason; + } + + public void setBreakRetryReason(String breakRetryReason) { + this.breakRetryReason = breakRetryReason; + } + + public String getBreakRetryEventName() { + return breakRetryEventName; + } + + public void setBreakRetryEventName(String breakRetryEventName) { + this.breakRetryEventName = breakRetryEventName; + } + + public String getLastStatusCodeForRetry() { + return lastStatusCodeForRetry; + } + + public void setLastStatusCodeForRetry(String lastStatusCodeForRetry) { + this.lastStatusCodeForRetry = lastStatusCodeForRetry; + } + + public int getRetryCount() { + return retryCount; + } + + public void setRetryCount(int retryCount) { + this.retryCount = retryCount; + } + + public void resetRetryCount() { + this.retryCount = 0; + } + + public void incrementRetryCount() { + this.retryCount++; 
+ } + + public int getMaxRetries() { + return maxRetries; + } + + public void setMaxRetries(int maxRetries) { + this.maxRetries = maxRetries; + } + + public String getRequestInfoScrubbed() { + return requestInfoScrubbed; + } + + public boolean isNoRetry() { + return noRetry; + } + + public void setNoRetry(boolean noRetry) { + this.noRetry = noRetry; + } + + public boolean isRetryHTTP403() { + return retryHTTP403; + } + + public void setRetryHTTP403(boolean retryHTTP403) { + this.retryHTTP403 = retryHTTP403; + } + + public boolean isShouldRetry() { + return shouldRetry; + } + + public void setShouldRetry(boolean shouldRetry) { + this.shouldRetry = shouldRetry; + } + + public void increaseElapsedMilliForTransientIssues(long elapsedMilliForLastCall) { + this.elapsedMilliForTransientIssues += elapsedMilliForLastCall; + } + + public boolean elapsedTimeExceeded() { + return elapsedMilliForTransientIssues > getRetryTimeoutInMilliseconds(); + } + + public boolean moreThanMinRetries() { + return retryCount >= MIN_RETRY_COUNT; + } + + public boolean maxRetriesExceeded() { + return maxRetries > 0 && retryCount >= maxRetries; + } + + public boolean socketOrConnectTimeoutReached() { + return authTimeout > 0 + && elapsedMilliForTransientIssues > getAuthTimeoutInMilliseconds() + && (origSocketTimeout == 0 || elapsedMilliForTransientIssues < origSocketTimeout); + } + + public AtomicBoolean getCanceling() { + return canceling; + } + + public void setCanceling(AtomicBoolean canceling) { + this.canceling = canceling; + } + + public boolean isIncludeRequestGuid() { + return includeRequestGuid; + } + + public void setIncludeRequestGuid(boolean includeRequestGuid) { + this.includeRequestGuid = includeRequestGuid; + } + + public boolean isWithoutCookies() { + return withoutCookies; + } + + public void setWithoutCookies(boolean withoutCookies) { + this.withoutCookies = withoutCookies; + } + + public int isInjectSocketTimeout() { + return injectSocketTimeout; + } + + public void 
setInjectSocketTimeout(int injectSocketTimeout) { + this.injectSocketTimeout = injectSocketTimeout; + } + + public int getInjectSocketTimeout() { + return injectSocketTimeout; + } + + public boolean isIncludeRetryParameters() { + return includeRetryParameters; + } + + public boolean isUnpackResponse() { + return unpackResponse; + } + + public void setUnpackResponse(boolean unpackResponse) { + this.unpackResponse = unpackResponse; + } + + public void setIncludeRetryParameters(boolean includeRetryParameters) { + this.includeRetryParameters = includeRetryParameters; + } + + public boolean isSkipRetriesBecauseOf200() { + return skipRetriesBecauseOf200; + } + + public void setSkipRetriesBecauseOf200(boolean skipRetriesBecauseOf200) { + this.skipRetriesBecauseOf200 = skipRetriesBecauseOf200; + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/HttpExecutingContextBuilder.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/HttpExecutingContextBuilder.java new file mode 100644 index 000000000..d990a9e2a --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/HttpExecutingContextBuilder.java @@ -0,0 +1,285 @@ +/* + * Replicated from snowflake-jdbc (v3.25.1) + * Source: https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/core/HttpExecutingContextBuilder.java + * + * Permitted differences: package declaration, @SnowflakeJdbcInternalApi annotation removed. + */ +package net.snowflake.ingest.streaming.internal.fileTransferAgent; + +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Builder class for {@link HttpExecutingContext}. Provides a fluent interface for constructing + * HttpExecutingContext instances with many optional parameters. 
+ */ +public class HttpExecutingContextBuilder { + private final String requestId; + private final String requestInfoScrubbed; + private long retryTimeout; + private long authTimeout; + private int origSocketTimeout; + private int maxRetries; + private int injectSocketTimeout; + private AtomicBoolean canceling; + private boolean withoutCookies; + private boolean includeRetryParameters; + private boolean includeRequestGuid; + private boolean retryHTTP403; + private boolean noRetry; + private boolean unpackResponse; + private boolean isLoginRequest; + + /** + * Creates a new builder instance with required parameters. + * + * @param requestId Request ID for logging and tracking + * @param requestInfoScrubbed Scrubbed request info for logging + */ + public HttpExecutingContextBuilder(String requestId, String requestInfoScrubbed) { + this.requestId = requestId; + this.requestInfoScrubbed = requestInfoScrubbed; + } + + /** + * Copy constructor to create a new builder from an existing HttpExecutingContext. 
+ * + * @param context The context to copy settings from + */ + public HttpExecutingContextBuilder(HttpExecutingContext context) { + this.requestId = context.getRequestId(); + this.requestInfoScrubbed = context.getRequestInfoScrubbed(); + this.retryTimeout = context.getRetryTimeout(); + this.authTimeout = context.getAuthTimeout(); + this.origSocketTimeout = context.getOrigSocketTimeout(); + this.maxRetries = context.getMaxRetries(); + this.injectSocketTimeout = context.getInjectSocketTimeout(); + this.canceling = context.getCanceling(); + this.withoutCookies = context.isWithoutCookies(); + this.includeRetryParameters = context.isIncludeRetryParameters(); + this.includeRequestGuid = context.isIncludeRequestGuid(); + this.retryHTTP403 = context.isRetryHTTP403(); + this.noRetry = context.isNoRetry(); + this.unpackResponse = context.isUnpackResponse(); + this.isLoginRequest = context.isLoginRequest(); + } + + /** + * Creates a new builder for a login request with common defaults. + * + * @param requestId Request ID for logging and tracking + * @param requestInfoScrubbed Scrubbed request info for logging + * @return A new builder instance configured for login requests + */ + public static HttpExecutingContextBuilder forLogin(String requestId, String requestInfoScrubbed) { + return new HttpExecutingContextBuilder(requestId, requestInfoScrubbed) + .loginRequest(true) + .includeRequestGuid(true) + .retryHTTP403(true); + } + + /** + * Creates a new builder for a query request with common defaults. 
+ * + * @param requestId Request ID for logging and tracking + * @param requestInfoScrubbed Scrubbed request info for logging + * @return A new builder instance configured for query requests + */ + public static HttpExecutingContextBuilder forQuery(String requestId, String requestInfoScrubbed) { + return new HttpExecutingContextBuilder(requestId, requestInfoScrubbed) + .includeRetryParameters(true) + .includeRequestGuid(true) + .unpackResponse(true); + } + + /** + * Creates a new builder for a simple HTTP request with minimal retry settings. + * + * @param requestId Request ID for logging and tracking + * @param requestInfoScrubbed Scrubbed request info for logging + * @return A new builder instance configured for simple requests + */ + public static HttpExecutingContextBuilder forSimpleRequest( + String requestId, String requestInfoScrubbed) { + return new HttpExecutingContextBuilder(requestId, requestInfoScrubbed) + .noRetry(true) + .includeRequestGuid(true); + } + + /** + * Creates a new builder with default settings for retryable requests. + * + * @param requestId Request ID for logging and tracking + * @param requestInfoScrubbed Scrubbed request info for logging + * @return A new builder instance with default retry settings + */ + public static HttpExecutingContextBuilder withRequest( + String requestId, String requestInfoScrubbed) { + return new HttpExecutingContextBuilder(requestId, requestInfoScrubbed); + } + + /** + * Sets the retry timeout in seconds. + * + * @param retryTimeout Retry timeout in seconds + * @return this builder instance + */ + public HttpExecutingContextBuilder retryTimeout(long retryTimeout) { + this.retryTimeout = retryTimeout; + return this; + } + + /** + * Sets the authentication timeout in seconds. 
+ * + * @param authTimeout Authentication timeout in seconds + * @return this builder instance + */ + public HttpExecutingContextBuilder authTimeout(long authTimeout) { + this.authTimeout = authTimeout; + return this; + } + + /** + * Sets the original socket timeout in milliseconds. + * + * @param origSocketTimeout Socket timeout in milliseconds + * @return this builder instance + */ + public HttpExecutingContextBuilder origSocketTimeout(int origSocketTimeout) { + this.origSocketTimeout = origSocketTimeout; + return this; + } + + /** + * Sets the maximum number of retries. + * + * @param maxRetries Maximum number of retries + * @return this builder instance + */ + public HttpExecutingContextBuilder maxRetries(int maxRetries) { + this.maxRetries = maxRetries; + return this; + } + + /** + * Sets the injected socket timeout for testing. + * + * @param injectSocketTimeout Socket timeout to inject + * @return this builder instance + */ + public HttpExecutingContextBuilder injectSocketTimeout(int injectSocketTimeout) { + this.injectSocketTimeout = injectSocketTimeout; + return this; + } + + /** + * Sets the canceling flag. + * + * @param canceling AtomicBoolean for cancellation + * @return this builder instance + */ + public HttpExecutingContextBuilder canceling(AtomicBoolean canceling) { + this.canceling = canceling; + return this; + } + + /** + * Sets whether to disable cookies. + * + * @param withoutCookies true to disable cookies + * @return this builder instance + */ + public HttpExecutingContextBuilder withoutCookies(boolean withoutCookies) { + this.withoutCookies = withoutCookies; + return this; + } + + /** + * Sets whether to include retry parameters in requests. 
+ * + * @param includeRetryParameters true to include retry parameters + * @return this builder instance + */ + public HttpExecutingContextBuilder includeRetryParameters(boolean includeRetryParameters) { + this.includeRetryParameters = includeRetryParameters; + return this; + } + + /** + * Sets whether to include request GUID. + * + * @param includeRequestGuid true to include request GUID + * @return this builder instance + */ + public HttpExecutingContextBuilder includeRequestGuid(boolean includeRequestGuid) { + this.includeRequestGuid = includeRequestGuid; + return this; + } + + /** + * Sets whether to retry on HTTP 403 errors. + * + * @param retryHTTP403 true to retry on HTTP 403 + * @return this builder instance + */ + public HttpExecutingContextBuilder retryHTTP403(boolean retryHTTP403) { + this.retryHTTP403 = retryHTTP403; + return this; + } + + /** + * Sets whether to disable retries. + * + * @param noRetry true to disable retries + * @return this builder instance + */ + public HttpExecutingContextBuilder noRetry(boolean noRetry) { + this.noRetry = noRetry; + return this; + } + + /** + * Sets whether to unpack the response. + * + * @param unpackResponse true to unpack response + * @return this builder instance + */ + public HttpExecutingContextBuilder unpackResponse(boolean unpackResponse) { + this.unpackResponse = unpackResponse; + return this; + } + + /** + * Sets whether this is a login request. + * + * @param isLoginRequest true if this is a login request + * @return this builder instance + */ + public HttpExecutingContextBuilder loginRequest(boolean isLoginRequest) { + this.isLoginRequest = isLoginRequest; + return this; + } + + /** + * Builds and returns a new HttpExecutingContext instance with the configured parameters. 
+ * + * @return A new HttpExecutingContext instance + */ + public HttpExecutingContext build() { + HttpExecutingContext context = new HttpExecutingContext(requestId, requestInfoScrubbed); + context.setRetryTimeout(retryTimeout); + context.setAuthTimeout(authTimeout); + context.setOrigSocketTimeout(origSocketTimeout); + context.setMaxRetries(maxRetries); + context.setInjectSocketTimeout(injectSocketTimeout); + context.setCanceling(canceling); + context.setWithoutCookies(withoutCookies); + context.setIncludeRetryParameters(includeRetryParameters); + context.setIncludeRequestGuid(includeRequestGuid); + context.setRetryHTTP403(retryHTTP403); + context.setNoRetry(noRetry); + context.setUnpackResponse(unpackResponse); + context.setLoginRequest(isLoginRequest); + return context; + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/IngestSSLConnectionSocketFactory.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/IngestSSLConnectionSocketFactory.java index a484f14f8..55a946d9d 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/IngestSSLConnectionSocketFactory.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/IngestSSLConnectionSocketFactory.java @@ -75,7 +75,10 @@ private static String[] decideCipherSuites() { // cipher suites need to be picked up in code explicitly for jdk 1.7 // https://stackoverflow.com/questions/44378970/ - logger.trace("Cipher suites used: {}", Arrays.toString(cipherSuites)); + logger.trace( + "Cipher suites used: {}", + (net.snowflake.ingest.streaming.internal.fileTransferAgent.log.ArgSupplier) + () -> Arrays.toString(cipherSuites)); return cipherSuites; } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/JdbcHttpUtil.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/JdbcHttpUtil.java new file mode 100644 index 000000000..c4a5c1ea3 --- /dev/null +++ 
b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/JdbcHttpUtil.java @@ -0,0 +1,1187 @@ +/* + * Replicated from snowflake-jdbc (v3.25.1) + * Source: https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/core/HttpUtil.java + * + * Permitted differences: package declaration, class renamed from HttpUtil to JdbcHttpUtil + * (to avoid collision with ingest's net.snowflake.ingest.utils.HttpUtil), + * import swaps for already-replicated classes, + * @SnowflakeJdbcInternalApi annotations removed, + * SFSSLConnectionSocketFactory -> IngestSSLConnectionSocketFactory, + * SystemUtil.convertSystemPropertyToIntValue inlined, + * SnowflakeDriver.implementVersion -> SnowflakeDriverConstants.implementVersion, + * SnowflakeMutableProxyRoutePlanner uses ingest version (same package), + * SessionUtil.isNewRetryStrategyRequest inlined from JDBC SessionUtil. + * SnowflakeUtil.isNullOrEmpty -> StorageClientUtil.isNullOrEmpty. + * SnowflakeUtil.systemGetProperty -> StorageClientUtil.systemGetProperty. 
+ * net.snowflake.client.log.* -> ...fileTransferAgent.log.* + * net.snowflake.client.util.SecretDetector -> ...log.SecretDetector + * net.snowflake.client.util.Stopwatch -> net.snowflake.ingest.utils.Stopwatch + */ +package net.snowflake.ingest.streaming.internal.fileTransferAgent; + +import static net.snowflake.ingest.streaming.internal.fileTransferAgent.StorageClientUtil.isNullOrEmpty; +import static net.snowflake.ingest.streaming.internal.fileTransferAgent.StorageClientUtil.systemGetProperty; +import static org.apache.http.client.config.CookieSpecs.DEFAULT; +import static org.apache.http.client.config.CookieSpecs.IGNORE_COOKIES; + +import com.amazonaws.ClientConfiguration; +import com.google.common.annotations.VisibleForTesting; +import com.microsoft.azure.storage.OperationContext; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.net.InetSocketAddress; +import java.net.Proxy; +import java.net.Socket; +import java.net.URI; +import java.security.KeyManagementException; +import java.security.NoSuchAlgorithmException; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import javax.annotation.Nullable; +import javax.net.ssl.TrustManager; +import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.ArgSupplier; +import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.SFLogger; +import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.SFLoggerFactory; +import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.SFLoggerUtil; +import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.SecretDetector; +import net.snowflake.ingest.utils.OCSPMode; +import net.snowflake.ingest.utils.SFSessionProperty; +import 
net.snowflake.ingest.utils.Stopwatch; +import org.apache.http.HttpHost; +import org.apache.http.auth.AuthScope; +import org.apache.http.auth.Credentials; +import org.apache.http.auth.UsernamePasswordCredentials; +import org.apache.http.client.CredentialsProvider; +import org.apache.http.client.config.RequestConfig; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpRequestBase; +import org.apache.http.config.Registry; +import org.apache.http.config.RegistryBuilder; +import org.apache.http.conn.socket.ConnectionSocketFactory; +import org.apache.http.conn.socket.PlainConnectionSocketFactory; +import org.apache.http.impl.client.BasicCredentialsProvider; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.DefaultRedirectStrategy; +import org.apache.http.impl.client.HttpClientBuilder; +import org.apache.http.impl.conn.PoolingHttpClientConnectionManager; +import org.apache.http.protocol.HttpContext; +import org.apache.http.ssl.SSLInitializationException; + +/** HttpUtil class */ +public class JdbcHttpUtil { + private static final SFLogger logger = SFLoggerFactory.getLogger(JdbcHttpUtil.class); + + static final int DEFAULT_MAX_CONNECTIONS = 300; + static final int DEFAULT_MAX_CONNECTIONS_PER_ROUTE = 300; + private static final int DEFAULT_HTTP_CLIENT_CONNECTION_TIMEOUT_IN_MS = 60000; + static final int DEFAULT_HTTP_CLIENT_SOCKET_TIMEOUT_IN_MS = 300000; // ms + static final int DEFAULT_TTL = 60; // secs + static final int DEFAULT_IDLE_CONNECTION_TIMEOUT = 5; // secs + static final int DEFAULT_DOWNLOADED_CONDITION_TIMEOUT = 3600; // secs + + public static final String JDBC_TTL = "net.snowflake.jdbc.ttl"; + public static final String JDBC_MAX_CONNECTIONS_PROPERTY = "net.snowflake.jdbc.max_connections"; + public static final String JDBC_MAX_CONNECTIONS_PER_ROUTE_PROPERTY = + "net.snowflake.jdbc.max_connections_per_route"; + private static Duration connectionTimeout; + private static 
Duration socketTimeout; + + /** + * The unique httpClient shared by all connections. This will benefit long-lived clients. Key = + * proxy host + proxy port + nonProxyHosts, Value = Map of [OCSPMode, HttpClient] + */ + public static Map httpClient = + new ConcurrentHashMap<>(); + + /** + * The unique httpClient map shared by all connections that don't want decompression. This will + * benefit long-lived clients. Key = proxy host + proxy port + nonProxyHosts, Value = Map + * [OCSPMode, HttpClient] + */ + private static Map httpClientWithoutDecompression = + new ConcurrentHashMap<>(); + + /** The map of snowflake route planners */ + static Map httpClientRoutePlanner = + new ConcurrentHashMap<>(); + + /** Handle on the static connection manager, to gather statistics mainly */ + private static PoolingHttpClientConnectionManager connectionManager = null; + + /** default request configuration, to be copied on individual requests. */ + private static RequestConfig DefaultRequestConfig = null; + + private static boolean socksProxyDisabled = false; + + public static void reset() { + httpClient.clear(); + httpClientWithoutDecompression.clear(); + httpClientRoutePlanner.clear(); + } + + public static Duration getConnectionTimeout() { + return connectionTimeout != null + ? connectionTimeout + : Duration.ofMillis(DEFAULT_HTTP_CLIENT_CONNECTION_TIMEOUT_IN_MS); + } + + public static Duration getSocketTimeout() { + return socketTimeout != null + ? 
socketTimeout + : Duration.ofMillis(DEFAULT_HTTP_CLIENT_SOCKET_TIMEOUT_IN_MS); + } + + public static void setConnectionTimeout(int timeout) { + connectionTimeout = Duration.ofMillis(timeout); + initDefaultRequestConfig(connectionTimeout.toMillis(), getSocketTimeout().toMillis()); + } + + public static void setSocketTimeout(int timeout) { + socketTimeout = Duration.ofMillis(timeout); + initDefaultRequestConfig(getConnectionTimeout().toMillis(), socketTimeout.toMillis()); + } + + public static long getDownloadedConditionTimeoutInSeconds() { + return DEFAULT_DOWNLOADED_CONDITION_TIMEOUT; + } + + public static void closeExpiredAndIdleConnections() { + if (connectionManager != null) { + synchronized (connectionManager) { + logger.debug("Connection pool stats: {}", connectionManager.getTotalStats()); + connectionManager.closeExpiredConnections(); + connectionManager.closeIdleConnections(DEFAULT_IDLE_CONNECTION_TIMEOUT, TimeUnit.SECONDS); + } + } + } + + /** + * A static function to set S3 proxy params when there is a valid session + * + * @param key key to HttpClient map containing OCSP and proxy info + * @param clientConfig the configuration needed by S3 to set the proxy + * @deprecated Use {@link S3HttpUtil#setProxyForS3(HttpClientSettingsKey, ClientConfiguration)} + * instead + */ + @Deprecated + public static void setProxyForS3(HttpClientSettingsKey key, ClientConfiguration clientConfig) { + S3HttpUtil.setProxyForS3(key, clientConfig); + } + + /** + * A static function to set S3 proxy params for sessionless connections using the proxy params + * from the StageInfo + * + * @param proxyProperties proxy properties + * @param clientConfig the configuration needed by S3 to set the proxy + * @throws SnowflakeSQLException when exception encountered + * @deprecated Use {@link S3HttpUtil#setSessionlessProxyForS3(Properties, ClientConfiguration)} + * instead + */ + @Deprecated + public static void setSessionlessProxyForS3( + Properties proxyProperties, ClientConfiguration 
clientConfig) throws SnowflakeSQLException { + S3HttpUtil.setSessionlessProxyForS3(proxyProperties, clientConfig); + } + + /** + * A static function to set Azure proxy params for sessionless connections using the proxy params + * from the StageInfo + * + * @param proxyProperties proxy properties + * @param opContext the configuration needed by Azure to set the proxy + * @throws SnowflakeSQLException when invalid proxy properties encountered + */ + public static void setSessionlessProxyForAzure( + Properties proxyProperties, OperationContext opContext) throws SnowflakeSQLException { + if (proxyProperties != null + && proxyProperties.size() > 0 + && proxyProperties.getProperty(SFSessionProperty.USE_PROXY.getPropertyKey()) != null) { + Boolean useProxy = + Boolean.valueOf( + proxyProperties.getProperty(SFSessionProperty.USE_PROXY.getPropertyKey())); + if (useProxy) { + String proxyHost = + proxyProperties.getProperty(SFSessionProperty.PROXY_HOST.getPropertyKey()); + int proxyPort; + try { + proxyPort = + Integer.parseInt( + proxyProperties.getProperty(SFSessionProperty.PROXY_PORT.getPropertyKey())); + } catch (NumberFormatException | NullPointerException e) { + throw new SnowflakeSQLException( + ErrorCode.INVALID_PROXY_PROPERTIES, "Could not parse port number"); + } + Proxy azProxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxyHost, proxyPort)); + logger.debug("Setting sessionless Azure proxy. 
Host: {}, port: {}", proxyHost, proxyPort); + opContext.setProxy(azProxy); + } else { + logger.debug("Omitting sessionless Azure proxy setup as proxy is disabled"); + } + } else { + logger.debug("Omitting sessionless Azure proxy setup"); + } + } + + /** + * A static function to set Azure proxy params when there is a valid session + * + * @param key key to HttpClient map containing OCSP and proxy info + * @param opContext the configuration needed by Azure to set the proxy + */ + public static void setProxyForAzure(HttpClientSettingsKey key, OperationContext opContext) { + if (key != null && key.usesProxy()) { + Proxy azProxy = + new Proxy(Proxy.Type.HTTP, new InetSocketAddress(key.getProxyHost(), key.getProxyPort())); + logger.debug( + "Setting Azure proxy. Host: {}, port: {}", key.getProxyHost(), key.getProxyPort()); + opContext.setProxy(azProxy); + } else { + logger.debug("Omitting Azure proxy setup"); + } + } + + /** + * Constructs a user-agent header with the following pattern: connector_name/connector_version + * (os-platform_info) language_implementation/language_version + * + * @param customSuffix custom suffix that would be appended to user agent to identify the jdbc + * usage. + * @return string for user-agent header + */ + @VisibleForTesting + static String buildUserAgent(String customSuffix) { + // Start with connector name + StringBuilder builder = new StringBuilder("JDBC/"); + // Append connector version and parenthesis start + builder.append(SnowflakeDriverConstants.implementVersion); + builder.append(" ("); + // Generate OS platform and version from system properties + String osPlatform = (systemGetProperty("os.name") != null) ? systemGetProperty("os.name") : ""; + String osVersion = + (systemGetProperty("os.version") != null) ? 
systemGetProperty("os.version") : ""; + // Append OS platform and version separated by a space + builder.append(osPlatform); + builder.append(" "); + builder.append(osVersion); + // Append language name + builder.append(") JAVA/"); + // Generate string for language version from system properties and append it + String languageVersion = + (systemGetProperty("java.version") != null) ? systemGetProperty("java.version") : ""; + builder.append(languageVersion); + if (!customSuffix.isEmpty()) { + builder.append(" " + customSuffix); + } + String userAgent = builder.toString(); + return userAgent; + } + + /** + * Build an Http client using our set of default. + * + * @param key Key to HttpClient hashmap containing OCSP mode and proxy information, could be null + * @param ocspCacheFile OCSP response cache file. If null, the default OCSP response file will be + * used. + * @param downloadUnCompressed Whether the HTTP client should be built requesting no decompression + * @return HttpClient object + */ + public static CloseableHttpClient buildHttpClient( + @Nullable HttpClientSettingsKey key, File ocspCacheFile, boolean downloadUnCompressed) { + return buildHttpClient(key, ocspCacheFile, downloadUnCompressed, null); + } + + /** + * Build an Http client using our set of default. + * + * @param key Key to HttpClient hashmap containing OCSP mode and proxy information, could be null + * @param ocspCacheFile OCSP response cache file. If null, the default OCSP response file will be + * used. 
+ * @param downloadUnCompressed Whether the HTTP client should be built requesting no decompression + * @param httpHeadersCustomizers List of HTTP headers customizers + * @return HttpClient object + */ + public static CloseableHttpClient buildHttpClient( + @Nullable HttpClientSettingsKey key, + File ocspCacheFile, + boolean downloadUnCompressed, + List httpHeadersCustomizers) { + logger.debug( + "Building http client with client settings key: {}, ocsp cache file: {}, download" + + " uncompressed: {}", + key != null ? key.toString() : null, + ocspCacheFile, + downloadUnCompressed); + // set timeout so that we don't wait forever. + // Setup the default configuration for all requests on this client + int timeToLive = convertSystemPropertyToIntValue(JDBC_TTL, DEFAULT_TTL); + long connectTimeout = getConnectionTimeout().toMillis(); + long socketTimeout = getSocketTimeout().toMillis(); + logger.debug( + "Connection pooling manager connect timeout: {} ms, socket timeout: {} ms, ttl: {} s", + connectTimeout, + socketTimeout, + timeToLive); + + // Create default request config without proxy since different connections could use different + // proxies in multi tenant environments + // Proxy is set later with route planner + if (DefaultRequestConfig == null) { + initDefaultRequestConfig(connectTimeout, socketTimeout); + } + + TrustManager[] trustManagers = null; + if (key != null && key.getOcspMode() != OCSPMode.DISABLE_OCSP_CHECKS) { + // A custom TrustManager is required only if disableOCSPChecks is disabled, + // which is by default in the production. disableOCSPChecks can be enabled + // 1) OCSP service is down for reasons, 2) PowerMock test that doesn't + // care OCSP checks. 
+ // OCSP FailOpen is ON by default + try { + if (ocspCacheFile == null) { + logger.debug("Instantiating trust manager with default ocsp cache file"); + } else { + logger.debug("Instantiating trust manager with ocsp cache file: {}", ocspCacheFile); + } + TrustManager[] tm = {new SFTrustManager(key, ocspCacheFile)}; + trustManagers = tm; + } catch (Exception | Error err) { + // dump error stack + StringWriter errors = new StringWriter(); + err.printStackTrace(new PrintWriter(errors)); + logger.error(errors.toString(), true); + throw new RuntimeException(err); // rethrow the exception + } + } else if (key != null) { + logger.debug( + "Omitting trust manager instantiation as OCSP mode is set to {}", key.getOcspMode()); + } else { + logger.debug("Omitting trust manager instantiation as configuration is not provided"); + } + try { + logger.debug( + "Registering https connection socket factory with socks proxy disabled: {} and http " + + "connection socket factory", + socksProxyDisabled); + + Registry registry = + RegistryBuilder.create() + .register( + "https", new IngestSSLConnectionSocketFactory(trustManagers, socksProxyDisabled)) + .register("http", new SFConnectionSocketFactory()) + .build(); + + // Build a connection manager with enough connections + connectionManager = + new PoolingHttpClientConnectionManager( + registry, null, null, null, timeToLive, TimeUnit.SECONDS); + int maxConnections = + convertSystemPropertyToIntValue(JDBC_MAX_CONNECTIONS_PROPERTY, DEFAULT_MAX_CONNECTIONS); + int maxConnectionsPerRoute = + convertSystemPropertyToIntValue( + JDBC_MAX_CONNECTIONS_PER_ROUTE_PROPERTY, DEFAULT_MAX_CONNECTIONS_PER_ROUTE); + logger.debug( + "Max connections total in connection pooling manager: {}; max connections per route: {}", + maxConnections, + maxConnectionsPerRoute); + connectionManager.setMaxTotal(maxConnections); + connectionManager.setDefaultMaxPerRoute(maxConnectionsPerRoute); + + logger.debug("Disabling cookie management for http client"); + String 
userAgentSuffix = key != null ? key.getUserAgentSuffix() : ""; + HttpClientBuilder httpClientBuilder = + HttpClientBuilder.create() + .setConnectionManager(connectionManager) + // Support JVM proxy settings + .useSystemProperties() + .setRedirectStrategy(new DefaultRedirectStrategy()) + .setUserAgent(buildUserAgent(userAgentSuffix)) // needed for Okta + .disableCookieManagement() // SNOW-39748 + .setDefaultRequestConfig(DefaultRequestConfig); + if (key != null && key.usesProxy()) { + HttpHost proxy = + new HttpHost( + key.getProxyHost(), key.getProxyPort(), key.getProxyHttpProtocol().getScheme()); + logger.debug( + "Configuring proxy and route planner - host: {}, port: {}, scheme: {}, nonProxyHosts:" + + " {}", + key.getProxyHost(), + key.getProxyPort(), + key.getProxyHttpProtocol().getScheme(), + key.getNonProxyHosts()); + // use the custom proxy properties + SnowflakeMutableProxyRoutePlanner sdkProxyRoutePlanner = + httpClientRoutePlanner.computeIfAbsent( + key, + k -> + new SnowflakeMutableProxyRoutePlanner( + key.getProxyHost(), + key.getProxyPort(), + key.getProxyHttpProtocol(), + key.getNonProxyHosts())); + httpClientBuilder.setProxy(proxy).setRoutePlanner(sdkProxyRoutePlanner); + if (!isNullOrEmpty(key.getProxyUser()) && !isNullOrEmpty(key.getProxyPassword())) { + Credentials credentials = + new UsernamePasswordCredentials(key.getProxyUser(), key.getProxyPassword()); + AuthScope authScope = new AuthScope(key.getProxyHost(), key.getProxyPort()); + CredentialsProvider credentialsProvider = new BasicCredentialsProvider(); + logger.debug( + "Using user: {}, password is {} for proxy host: {}, port: {}", + key.getProxyUser(), + SFLoggerUtil.isVariableProvided(key.getProxyPassword()), + key.getProxyHost(), + key.getProxyPort()); + credentialsProvider.setCredentials(authScope, credentials); + httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider); + } + } + if (downloadUnCompressed) { + logger.debug("Disabling content compression for http client"); + 
httpClientBuilder.disableContentCompression(); + } + if (httpHeadersCustomizers != null && !httpHeadersCustomizers.isEmpty()) { + logger.debug("Setting up http headers customizers"); + httpClientBuilder.setRetryHandler(new AttributeEnhancingHttpRequestRetryHandler()); + httpClientBuilder.addInterceptorLast( + new HeaderCustomizerHttpRequestInterceptor(httpHeadersCustomizers)); + } + return httpClientBuilder.build(); + } catch (NoSuchAlgorithmException | KeyManagementException ex) { + throw new SSLInitializationException(ex.getMessage(), ex); + } + } + + private static void initDefaultRequestConfig(long connectTimeout, long socketTimeout) { + RequestConfig.Builder builder = + RequestConfig.custom() + .setConnectTimeout((int) connectTimeout) + .setConnectionRequestTimeout((int) connectTimeout) + .setSocketTimeout((int) socketTimeout); + logger.debug( + "Rebuilding request config. Connect timeout: {} ms, connection request timeout: {} ms," + + " socket timeout: {} ms", + connectTimeout, + connectTimeout, + socketTimeout); + DefaultRequestConfig = builder.build(); + } + + public static void updateRoutePlanner(HttpClientSettingsKey key) { + if (httpClientRoutePlanner.containsKey(key) + && !httpClientRoutePlanner + .get(key) + .getNonProxyHosts() + .equalsIgnoreCase(key.getNonProxyHosts())) { + logger.debug( + "Updating route planner non-proxy hosts for proxy: {}:{} to: {}", + key.getProxyHost(), + key.getProxyPort(), + key.getNonProxyHosts()); + httpClientRoutePlanner.get(key).setNonProxyHosts(key.getNonProxyHosts()); + } + } + + /** + * Gets HttpClient with insecureMode false + * + * @param ocspAndProxyKey OCSP mode and proxy settings for httpclient + * @return HttpClient object shared across all connections + */ + public static CloseableHttpClient getHttpClient(HttpClientSettingsKey ocspAndProxyKey) { + return initHttpClient(ocspAndProxyKey, null, null); + } + + /** + * Gets HttpClient with insecureMode false and disabling decompression + * + * @param ocspAndProxyKey 
OCSP mode and proxy settings for httpclient + * @return HttpClient object shared across all connections + */ + public static CloseableHttpClient getHttpClientWithoutDecompression( + HttpClientSettingsKey ocspAndProxyKey) { + return initHttpClientWithoutDecompression(ocspAndProxyKey, null, null); + } + + /** + * Gets HttpClient with insecureMode false + * + * @param ocspAndProxyKey OCSP mode and proxy settings for httpclient + * @param httpHeadersCustomizers List of HTTP headers customizers + * @return HttpClient object shared across all connections + */ + public static CloseableHttpClient getHttpClient( + HttpClientSettingsKey ocspAndProxyKey, List httpHeadersCustomizers) { + return initHttpClient(ocspAndProxyKey, null, httpHeadersCustomizers); + } + + /** + * Gets HttpClient with insecureMode false and disabling decompression + * + * @param ocspAndProxyKey OCSP mode and proxy settings for httpclient + * @param httpHeadersCustomizers List of HTTP headers customizers + * @return HttpClient object shared across all connections + */ + public static CloseableHttpClient getHttpClientWithoutDecompression( + HttpClientSettingsKey ocspAndProxyKey, List httpHeadersCustomizers) { + return initHttpClientWithoutDecompression(ocspAndProxyKey, null, httpHeadersCustomizers); + } + + /** + * Accessor for the HTTP client singleton. + * + * @param key contains information needed to build specific HttpClient + * @param ocspCacheFile OCSP response cache file name. if null, the default file will be used. + * @return HttpClient object shared across all connections + */ + public static CloseableHttpClient initHttpClientWithoutDecompression( + HttpClientSettingsKey key, File ocspCacheFile) { + updateRoutePlanner(key); + return httpClientWithoutDecompression.computeIfAbsent( + key, k -> buildHttpClient(key, ocspCacheFile, true, null)); + } + + /** + * Accessor for the HTTP client singleton. 
+ * + * @param key contains information needed to build specific HttpClient + * @param ocspCacheFile OCSP response cache file name. if null, the default file will be used. + * @return HttpClient object shared across all connections + */ + public static CloseableHttpClient initHttpClient(HttpClientSettingsKey key, File ocspCacheFile) { + updateRoutePlanner(key); + return httpClient.computeIfAbsent( + key, k -> buildHttpClient(key, ocspCacheFile, key.getGzipDisabled(), null)); + } + + /** + * Accessor for the HTTP client singleton. + * + * @param key contains information needed to build specific HttpClient + * @param ocspCacheFile OCSP response cache file name. if null, the default file will be used. + * @param httpHeadersCustomizers List of HTTP headers customizers + * @return HttpClient object shared across all connections + */ + public static CloseableHttpClient initHttpClientWithoutDecompression( + HttpClientSettingsKey key, + File ocspCacheFile, + List httpHeadersCustomizers) { + updateRoutePlanner(key); + return httpClientWithoutDecompression.computeIfAbsent( + key, k -> buildHttpClient(key, ocspCacheFile, true, httpHeadersCustomizers)); + } + + /** + * Accessor for the HTTP client singleton. + * + * @param key contains information needed to build specific HttpClient + * @param ocspCacheFile OCSP response cache file name. if null, the default file will be used. + * @param httpHeadersCustomizers List of HTTP headers customizers + * @return HttpClient object shared across all connections + */ + public static CloseableHttpClient initHttpClient( + HttpClientSettingsKey key, + File ocspCacheFile, + List httpHeadersCustomizers) { + updateRoutePlanner(key); + return httpClient.computeIfAbsent( + key, + k -> buildHttpClient(key, ocspCacheFile, key.getGzipDisabled(), httpHeadersCustomizers)); + } + + /** + * Return a request configuration inheriting from the default request configuration of the shared + * HttpClient with a different socket timeout. 
+ * + * @param soTimeoutMs - custom socket timeout in milli-seconds + * @param withoutCookies - whether this request should ignore cookies or not + * @return RequestConfig object + */ + public static RequestConfig getDefaultRequestConfigWithSocketTimeout( + int soTimeoutMs, boolean withoutCookies) { + final String cookieSpec = withoutCookies ? IGNORE_COOKIES : DEFAULT; + return RequestConfig.copy(DefaultRequestConfig) + .setSocketTimeout(soTimeoutMs) + .setCookieSpec(cookieSpec) + .build(); + } + + /** + * Return a request configuration inheriting from the default request configuration of the shared + * HttpClient with a different socket and connect timeout. + * + * @param requestSocketAndConnectTimeout - custom socket and connect timeout in milli-seconds + * @param withoutCookies - whether this request should ignore cookies or not + * @return RequestConfig object + */ + public static RequestConfig getDefaultRequestConfigWithSocketAndConnectTimeout( + int requestSocketAndConnectTimeout, boolean withoutCookies) { + final String cookieSpec = withoutCookies ? IGNORE_COOKIES : DEFAULT; + return RequestConfig.copy(DefaultRequestConfig) + .setSocketTimeout(requestSocketAndConnectTimeout) + .setConnectTimeout(requestSocketAndConnectTimeout) + .setCookieSpec(cookieSpec) + .build(); + } + + /** + * Return a request configuration inheriting from the default request configuration of the shared + * HttpClient with the cookie spec set to ignore. + * + * @return RequestConfig object + */ + public static RequestConfig getRequestConfigWithoutCookies() { + return RequestConfig.copy(DefaultRequestConfig).setCookieSpec(IGNORE_COOKIES).build(); + } + + public static void setRequestConfig(RequestConfig requestConfig) { + logger.debug("Setting default request config to: {}", requestConfig); + DefaultRequestConfig = requestConfig; + } + + /** + * Accessor for the HTTP client singleton. 
+ * + * @return HTTP Client stats in string representation + */ + private static String getHttpClientStats() { + return connectionManager == null ? "" : connectionManager.getTotalStats().toString(); + } + + /** + * Enables/disables use of the SOCKS proxy when creating sockets + * + * @param socksProxyDisabled new value + */ + public static void setSocksProxyDisabled(boolean socksProxyDisabled) { + logger.debug("Setting socks proxy disabled to {}", socksProxyDisabled); + JdbcHttpUtil.socksProxyDisabled = socksProxyDisabled; + } + + /** + * Returns whether the SOCKS proxy is disabled for this JVM + * + * @return whether the SOCKS proxy is disabled + */ + public static boolean isSocksProxyDisabled() { + return JdbcHttpUtil.socksProxyDisabled; + } + + /** + * Executes an HTTP request with the cookie spec set to IGNORE_COOKIES + * + * @param httpRequest HttpRequestBase + * @param retryTimeout retry timeout + * @param authTimeout authenticator specific timeout + * @param socketTimeout socket timeout (in ms) + * @param retryCount max retry count for the request - if it is set to 0, it will be ignored and + * only retryTimeout will determine when to end the retries + * @param injectSocketTimeout injecting socket timeout + * @param canceling canceling? 
+ * @param ocspAndProxyKey OCSP mode and proxy settings for httpclient + * @return response + * @throws SnowflakeSQLException if Snowflake error occurs + * @throws IOException raises if a general IO error occurs + */ + static String executeRequestWithoutCookies( + HttpRequestBase httpRequest, + int retryTimeout, + int authTimeout, + int socketTimeout, + int retryCount, + int injectSocketTimeout, + AtomicBoolean canceling, + HttpClientSettingsKey ocspAndProxyKey) + throws SnowflakeSQLException, IOException { + logger.debug("Executing request without cookies"); + return executeRequestInternal( + httpRequest, + retryTimeout, + authTimeout, + socketTimeout, + retryCount, + injectSocketTimeout, + canceling, + true, // no cookie + false, // no retry parameter + true, // guid? (do we need this?) + false, // no retry on HTTP 403 + getHttpClient(ocspAndProxyKey, null), + new ExecTimeTelemetryData(), + null); + } + + /** + * Executes an HTTP request for Snowflake. + * + * @param httpRequest HttpRequestBase + * @param retryTimeout retry timeout + * @param authTimeout authenticator specific timeout + * @param socketTimeout socket timeout (in ms) + * @param retryCount max retry count for the request - if it is set to 0, it will be ignored and + * only retryTimeout will determine when to end the retries + * @param ocspAndProxyAndGzipKey OCSP mode and proxy settings for httpclient + * @return response + * @throws SnowflakeSQLException if Snowflake error occurs + * @throws IOException raises if a general IO error occurs + */ + public static String executeGeneralRequest( + HttpRequestBase httpRequest, + int retryTimeout, + int authTimeout, + int socketTimeout, + int retryCount, + HttpClientSettingsKey ocspAndProxyAndGzipKey) + throws SnowflakeSQLException, IOException { + return executeGeneralRequest( + httpRequest, + retryTimeout, + authTimeout, + socketTimeout, + retryCount, + ocspAndProxyAndGzipKey, + null); + } + + public static String executeGeneralRequestOmitRequestGuid( + 
HttpRequestBase httpRequest, + int retryTimeout, + int authTimeout, + int socketTimeout, + int retryCount, + HttpClientSettingsKey ocspAndProxyAndGzipKey) + throws SnowflakeSQLException, IOException { + return executeRequestInternal( + httpRequest, + retryTimeout, + authTimeout, + socketTimeout, + retryCount, + 0, + null, + false, + false, + false, + false, + getHttpClient(ocspAndProxyAndGzipKey, null), + new ExecTimeTelemetryData(), + null); + } + + /** + * Executes an HTTP request for Snowflake. + * + * @param httpRequest HttpRequestBase + * @param retryTimeout retry timeout + * @param authTimeout authenticator specific timeout + * @param socketTimeout socket timeout (in ms) + * @param retryCount max retry count for the request - if it is set to 0, it will be ignored and + * only retryTimeout will determine when to end the retries + * @param ocspAndProxyAndGzipKey OCSP mode and proxy settings for httpclient + * @param retryContextManager RetryContext used to customize retry handling functionality + * @return response + * @throws SnowflakeSQLException if Snowflake error occurs + * @throws IOException raises if a general IO error occurs + */ + public static String executeGeneralRequest( + HttpRequestBase httpRequest, + int retryTimeout, + int authTimeout, + int socketTimeout, + int retryCount, + HttpClientSettingsKey ocspAndProxyAndGzipKey, + RetryContextManager retryContextManager) + throws SnowflakeSQLException, IOException { + logger.debug("Executing general request"); + return executeRequest( + httpRequest, + retryTimeout, + authTimeout, + socketTimeout, + retryCount, + 0, // no inject socket timeout + null, // no canceling + false, // no retry parameter + false, // no retry on HTTP 403 + ocspAndProxyAndGzipKey, + new ExecTimeTelemetryData(), + retryContextManager); + } + + /** + * Executes an HTTP request for Snowflake + * + * @param httpRequest HttpRequestBase + * @param retryTimeout retry timeout + * @param authTimeout authenticator specific timeout + * 
@param socketTimeout socket timeout (in ms) + * @param retryCount max retry count for the request - if it is set to 0, it will be ignored and + * only retryTimeout will determine when to end the retries + * @param httpClient client object used to communicate with other machine + * @return response + * @throws SnowflakeSQLException if Snowflake error occurs + * @throws IOException raises if a general IO error occurs + */ + public static String executeGeneralRequest( + HttpRequestBase httpRequest, + int retryTimeout, + int authTimeout, + int socketTimeout, + int retryCount, + CloseableHttpClient httpClient) + throws SnowflakeSQLException, IOException { + logger.debug("Executing general request"); + return executeRequestInternal( + httpRequest, + retryTimeout, + authTimeout, + socketTimeout, + retryCount, + 0, // no inject socket timeout + null, // no canceling + false, // with cookie + false, // no retry parameter + true, // include request GUID + false, // no retry on HTTP 403 + httpClient, + new ExecTimeTelemetryData(), + null); + } + + /** + * Executes an HTTP request for Snowflake. + * + * @param httpRequest HttpRequestBase + * @param retryTimeout retry timeout + * @param authTimeout authenticator timeout + * @param socketTimeout socket timeout (in ms) + * @param maxRetries retry count for the request + * @param injectSocketTimeout injecting socket timeout + * @param canceling canceling? 
+ * @param includeRetryParameters whether to include retry parameters in retried requests + * @param retryOnHTTP403 whether to retry on HTTP 403 or not + * @param ocspAndProxyKey OCSP mode and proxy settings for httpclient + * @param execTimeData query execution time telemetry data object + * @return response + * @throws SnowflakeSQLException if Snowflake error occurs + * @throws IOException raises if a general IO error occurs + */ + public static String executeRequest( + HttpRequestBase httpRequest, + int retryTimeout, + int authTimeout, + int socketTimeout, + int maxRetries, + int injectSocketTimeout, + AtomicBoolean canceling, + boolean includeRetryParameters, + boolean retryOnHTTP403, + HttpClientSettingsKey ocspAndProxyKey, + ExecTimeTelemetryData execTimeData) + throws SnowflakeSQLException, IOException { + return executeRequest( + httpRequest, + retryTimeout, + authTimeout, + socketTimeout, + maxRetries, + injectSocketTimeout, + canceling, + includeRetryParameters, + retryOnHTTP403, + ocspAndProxyKey, + execTimeData, + null); + } + + /** + * Executes an HTTP request for Snowflake. + * + * @param httpRequest HttpRequestBase + * @param retryTimeout retry timeout + * @param authTimeout authenticator timeout + * @param socketTimeout socket timeout (in ms) + * @param maxRetries retry count for the request + * @param injectSocketTimeout injecting socket timeout + * @param canceling canceling? 
+ * @param includeRetryParameters whether to include retry parameters in retried requests + * @param retryOnHTTP403 whether to retry on HTTP 403 or not + * @param ocspAndProxyKey OCSP mode and proxy settings for httpclient + * @param execTimeData query execution time telemetry data object + * @param retryContextManager RetryContext used to customize retry handling functionality + * @return response + * @throws SnowflakeSQLException if Snowflake error occurs + * @throws IOException raises if a general IO error occurs + */ + public static String executeRequest( + HttpRequestBase httpRequest, + int retryTimeout, + int authTimeout, + int socketTimeout, + int maxRetries, + int injectSocketTimeout, + AtomicBoolean canceling, + boolean includeRetryParameters, + boolean retryOnHTTP403, + HttpClientSettingsKey ocspAndProxyKey, + ExecTimeTelemetryData execTimeData, + RetryContextManager retryContextManager) + throws SnowflakeSQLException, IOException { + boolean ocspEnabled = !(ocspAndProxyKey.getOcspMode().equals(OCSPMode.DISABLE_OCSP_CHECKS)); + logger.debug("Executing request with OCSP enabled: {}", ocspEnabled); + execTimeData.setOCSPStatus(ocspEnabled); + return executeRequestInternal( + httpRequest, + retryTimeout, + authTimeout, + socketTimeout, + maxRetries, + injectSocketTimeout, + canceling, + false, // with cookie (do we need cookie?) + includeRetryParameters, + true, // include request GUID + retryOnHTTP403, + getHttpClient(ocspAndProxyKey, null), + execTimeData, + retryContextManager); + } + + /** + * Helper to execute a request with retry and check and throw exception if response is not + * success. This should be used only for small request has it execute the REST request and get + * back the result as a string. + * + *

Connection under the httpRequest is released. + * + * @param httpRequest request object contains all the information + * @param retryTimeout retry timeout (in seconds) + * @param authTimeout authenticator specific timeout (in seconds) + * @param socketTimeout socket timeout (in ms) + * @param maxRetries retry count for the request + * @param injectSocketTimeout simulate socket timeout + * @param canceling canceling flag + * @param withoutCookies whether this request should ignore cookies + * @param includeRetryParameters whether to include retry parameters in retried requests + * @param includeRequestGuid whether to include request_guid + * @param retryOnHTTP403 whether to retry on HTTP 403 + * @param httpClient client object used to communicate with other machine + * @param retryContextManager RetryContext used to customize retry handling functionality + * @return response in String + * @throws SnowflakeSQLException if Snowflake error occurs + * @throws IOException raises if a general IO error occurs + */ + private static String executeRequestInternal( + HttpRequestBase httpRequest, + int retryTimeout, + int authTimeout, + int socketTimeout, + int maxRetries, + int injectSocketTimeout, + AtomicBoolean canceling, + boolean withoutCookies, + boolean includeRetryParameters, + boolean includeRequestGuid, + boolean retryOnHTTP403, + CloseableHttpClient httpClient, + ExecTimeTelemetryData execTimeData, + RetryContextManager retryContextManager) + throws SnowflakeSQLException, IOException { + String requestInfoScrubbed = SecretDetector.maskSASToken(httpRequest.toString()); + String responseText = ""; + + logger.debug( + "Pool: {} Executing: {}", + (ArgSupplier) JdbcHttpUtil::getHttpClientStats, + requestInfoScrubbed); + + CloseableHttpResponse response = null; + Stopwatch stopwatch = null; + + String requestIdStr = URLUtil.getRequestIdLogStr(httpRequest.getURI()); + HttpExecutingContext context = + HttpExecutingContextBuilder.forSimpleRequest(requestIdStr, 
requestInfoScrubbed) + .retryTimeout(retryTimeout) + .authTimeout(authTimeout) + .origSocketTimeout(socketTimeout) + .maxRetries(maxRetries) + .injectSocketTimeout(injectSocketTimeout) + .canceling(canceling) + .withoutCookies(withoutCookies) + .includeRetryParameters(includeRetryParameters) + .includeRequestGuid(includeRequestGuid) + .retryHTTP403(retryOnHTTP403) + .unpackResponse(true) + .noRetry(false) + .loginRequest(isNewRetryStrategyRequest(httpRequest)) + .build(); + responseText = + RestRequest.executeWithRetries( + httpClient, httpRequest, context, execTimeData, retryContextManager) + .getUnpackedCloseableHttpResponse(); + + logger.debug( + "Pool: {} Request returned for: {} took {} ms", + (ArgSupplier) JdbcHttpUtil::getHttpClientStats, + requestInfoScrubbed, + stopwatch == null ? "n/a" : stopwatch.elapsedMillis()); + + return responseText; + } + + // This is a workaround for JDK-7036144. + // + // The GZIPInputStream prematurely closes its input if a) it finds + // a whole GZIP block and b) input.available() returns 0. In order + // to work around this issue, we inject a thin wrapper for the + // InputStream whose available() method always returns at least 1. + // + // Further details on this bug: + // http://bugs.java.com/bugdatabase/view_bug.do?bug_id=7036144 + public static final class HttpInputStream extends InputStream { + private final InputStream httpIn; + + public HttpInputStream(InputStream httpIn) { + this.httpIn = httpIn; + } + + // This is the only modified function, all other + // methods are simple wrapper around the HTTP stream. + @Override + public final int available() throws IOException { + int available = httpIn.available(); + return available == 0 ? 1 : available; + } + + // ONLY WRAPPER METHODS FROM HERE ON. 
+ @Override + public final int read() throws IOException { + return httpIn.read(); + } + + @Override + public final int read(byte b[]) throws IOException { + return httpIn.read(b); + } + + @Override + public final int read(byte b[], int off, int len) throws IOException { + return httpIn.read(b, off, len); + } + + @Override + public final long skip(long n) throws IOException { + return httpIn.skip(n); + } + + @Override + public final void close() throws IOException { + httpIn.close(); + } + + @Override + public synchronized void mark(int readlimit) { + httpIn.mark(readlimit); + } + + @Override + public synchronized void reset() throws IOException { + httpIn.reset(); + } + + @Override + public final boolean markSupported() { + return httpIn.markSupported(); + } + } + + static final class SFConnectionSocketFactory extends PlainConnectionSocketFactory { + @Override + public Socket createSocket(HttpContext ctx) throws IOException { + if (socksProxyDisabled) { + logger.trace("Creating socket with no proxy"); + return new Socket(Proxy.NO_PROXY); + } + logger.trace("Creating socket with proxy"); + return super.createSocket(ctx); + } + } + + /** + * Helper function to attach additional headers to a request if present. This takes a (nullable) + * map of headers in format and adds them to the incoming request using addHeader. + * + *

Snowsight uses this to attach headers with additional telemetry information, see + * https://snowflakecomputing.atlassian.net/wiki/spaces/EN/pages/2960557006/GS+Communication + * + * @param request The request to add headers to. Must not be null. + * @param additionalHeaders The headers to add. May be null. + */ + static void applyAdditionalHeadersForSnowsight( + HttpRequestBase request, Map additionalHeaders) { + if (additionalHeaders != null && !additionalHeaders.isEmpty()) { + additionalHeaders.forEach(request::addHeader); + } + } + + /** + * Inlined from SystemUtil.convertSystemPropertyToIntValue. Source: + * https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/core/SystemUtil.java + * + * @param systemProperty name of the system property + * @param defaultValue default value used + * @return the value of the system property, else the default value + */ + static int convertSystemPropertyToIntValue(String systemProperty, int defaultValue) { + String systemPropertyValue = systemGetProperty(systemProperty); + int returnVal = defaultValue; + if (systemPropertyValue != null) { + try { + returnVal = Integer.parseInt(systemPropertyValue); + } catch (NumberFormatException ex) { + logger.warn( + "Failed to parse the system parameter {} with value {}", + systemProperty, + systemPropertyValue); + } + } + return returnVal; + } + + /** + * Inlined from SessionUtil.isNewRetryStrategyRequest. Source: + * https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/core/SessionUtil.java + * + *

Helper method to check if the request path is a login/auth request to use for retry + * strategy. + * + * @param request the post request + * @return true if this is a login/auth request, false otherwise + */ + public static boolean isNewRetryStrategyRequest(HttpRequestBase request) { + // These constants are from SessionUtil in JDBC + String SF_PATH_LOGIN_REQUEST = "/session/v1/login-request"; + String SF_PATH_AUTHENTICATOR_REQUEST = "/session/authenticator-request"; + String SF_PATH_TOKEN_REQUEST = "/session/token-request"; + String SF_PATH_OKTA_TOKEN_REQUEST_SUFFIX = "/api/v1/authn"; + String SF_PATH_OKTA_SSO_REQUEST_SUFFIX = "/sso/saml"; + + URI requestURI = request.getURI(); + String requestPath = requestURI.getPath(); + if (requestPath != null) { + return requestPath.equals(SF_PATH_LOGIN_REQUEST) + || requestPath.equals(SF_PATH_AUTHENTICATOR_REQUEST) + || requestPath.equals(SF_PATH_TOKEN_REQUEST) + || requestPath.contains(SF_PATH_OKTA_TOKEN_REQUEST_SUFFIX) + || requestPath.contains(SF_PATH_OKTA_SSO_REQUEST_SUFFIX); + } + return false; + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/OCSPErrorCode.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/OCSPErrorCode.java new file mode 100644 index 000000000..43afa4da0 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/OCSPErrorCode.java @@ -0,0 +1,26 @@ +/* + * Replicated from snowflake-jdbc (v3.25.1) + * Source: https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/jdbc/OCSPErrorCode.java + * + * Permitted differences: package declaration. 
+ */ +package net.snowflake.ingest.streaming.internal.fileTransferAgent; + +public enum OCSPErrorCode { + CERTIFICATE_STATUS_GOOD, + CERTIFICATE_STATUS_REVOKED, + CERTIFICATE_STATUS_UNKNOWN, + OCSP_CACHE_DOWNLOAD_TIMEOUT, + OCSP_RESPONSE_FETCH_TIMEOUT, + OCSP_RESPONSE_FETCH_FAILURE, + INVALID_CACHE_SERVER_URL, + EXPIRED_OCSP_SIGNING_CERTIFICATE, + INVALID_CERTIFICATE_SIGNATURE, + INVALID_OCSP_RESPONSE_SIGNATURE, + INVALID_OCSP_RESPONSE_VALIDITY, + INVALID_OCSP_RESPONSE, + REVOCATION_CHECK_FAILURE, + INVALID_SSD, + NO_OCSP_URL_ATTACHED, + NO_ROOTCA_FOUND +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/OCSPTelemetryData.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/OCSPTelemetryData.java new file mode 100644 index 000000000..d791e014b --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/OCSPTelemetryData.java @@ -0,0 +1,79 @@ +/* + * Replicated from snowflake-jdbc (v3.25.1) + * Source: https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/core/OCSPTelemetryData.java + * + * Permitted differences: package declaration, + * OCSPMode -> net.snowflake.ingest.utils.OCSPMode, + * TelemetryService uses ingest replicated version (same package). 
+ */ +package net.snowflake.ingest.streaming.internal.fileTransferAgent; + +import java.security.cert.CertificateException; +import net.minidev.json.JSONObject; +import net.snowflake.ingest.utils.OCSPMode; + +public class OCSPTelemetryData { + private String certId; + private String sfcPeerHost; + private String ocspUrl; + private String ocspReq; + private Boolean cacheEnabled; + private Boolean cacheHit; + private OCSPMode ocspMode; + + public OCSPTelemetryData() { + this.ocspMode = OCSPMode.FAIL_OPEN; + this.cacheEnabled = true; + } + + public void setCertId(String certId) { + this.certId = certId; + } + + public void setSfcPeerHost(String sfcPeerHost) { + this.sfcPeerHost = sfcPeerHost; + } + + public void setOcspUrl(String ocspUrl) { + this.ocspUrl = ocspUrl; + } + + public void setOcspReq(String ocspReq) { + this.ocspReq = ocspReq; + } + + public void setCacheEnabled(Boolean cacheEnabled) { + this.cacheEnabled = cacheEnabled; + if (!cacheEnabled) { + this.cacheHit = false; + } + } + + public void setCacheHit(Boolean cacheHit) { + if (!this.cacheEnabled) { + this.cacheHit = false; + } else { + this.cacheHit = cacheHit; + } + } + + public void setOCSPMode(OCSPMode ocspMode) { + this.ocspMode = ocspMode; + } + + public String generateTelemetry(String eventType, CertificateException ex) { + JSONObject value = new JSONObject(); + String valueStr; + value.put("eventType", eventType); + value.put("sfcPeerHost", this.sfcPeerHost); + value.put("certId", this.certId); + value.put("ocspResponderURL", this.ocspUrl); + value.put("ocspReqBase64", this.ocspReq); + value.put("ocspMode", this.ocspMode.name()); + value.put("cacheEnabled", this.cacheEnabled); + value.put("cacheHit", this.cacheHit); + valueStr = value.toString(); // Avoid adding exception stacktrace to user logs. 
+ TelemetryService.getInstance().logOCSPExceptionTelemetryEvent(eventType, value, ex); + return valueStr; + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/RestRequest.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/RestRequest.java new file mode 100644 index 000000000..d907e7f7d --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/RestRequest.java @@ -0,0 +1,1304 @@ +/* + * Replicated from snowflake-jdbc (v3.25.1) + * Source: https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/jdbc/RestRequest.java + * + * Permitted differences: package declaration, import swaps for already-replicated classes, + * @SnowflakeJdbcInternalApi annotation removed. + * SnowflakeUtil.isNullOrEmpty -> StorageClientUtil.isNullOrEmpty (already in package). + * SnowflakeUtil.logResponseDetails -> FQN net.snowflake.client.jdbc.SnowflakeUtil.logResponseDetails. + * SessionUtil.isNewRetryStrategyRequest -> JdbcHttpUtil.isNewRetryStrategyRequest. + * HttpUtil references swapped to JdbcHttpUtil (replicated in same package). 
+ */ +package net.snowflake.ingest.streaming.internal.fileTransferAgent; + +import static net.snowflake.ingest.streaming.internal.fileTransferAgent.StorageClientUtil.isNullOrEmpty; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.net.URISyntaxException; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; +import javax.net.ssl.SSLHandshakeException; +import javax.net.ssl.SSLKeyException; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLProtocolException; +import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.ArgSupplier; +import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.Event; +import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.EventUtil; +import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.SFLogger; +import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.SFLoggerFactory; +import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.SecretDetector; +import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.UUIDUtils; +import net.snowflake.ingest.utils.Stopwatch; +import org.apache.commons.io.IOUtils; +import org.apache.http.Header; +import org.apache.http.HttpResponse; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpRequestBase; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.util.EntityUtils; + +/** + * This is an abstraction on top of http client. + * + *

Currently it only has one method for retrying http request execution so that the same logic + * doesn't have to be replicated at difference places where retry is needed. + */ +public class RestRequest { + private static final SFLogger logger = SFLoggerFactory.getLogger(RestRequest.class); + + // Request guid per HTTP request + private static final String SF_REQUEST_GUID = "request_guid"; + + // min backoff in milli before we retry due to transient issues + private static final long minBackoffInMilli = 1000; + + // max backoff in milli before we retry due to transient issues + // we double the backoff after each retry till we reach the max backoff + private static final long maxBackoffInMilli = 16000; + + // retry at least once even if timeout limit has been reached + private static final int MIN_RETRY_COUNT = 1; + + static final String ERROR_FIELD_NAME = "error"; + static final String ERROR_USE_DPOP_NONCE = "use_dpop_nonce"; + static final String DPOP_NONCE_HEADER_NAME = "dpop-nonce"; + + static final Set> sslExceptions = + new HashSet<>( + Arrays.asList( + SSLHandshakeException.class, + SSLKeyException.class, + SSLPeerUnverifiedException.class, + SSLProtocolException.class)); + + /** + * Execute an HTTP request with retry logic. + * + * @param httpClient client object used to communicate with other machine + * @param httpRequest request object contains all the request information + * @param retryTimeout : retry timeout (in seconds) + * @param authTimeout : authenticator specific timeout (in seconds) + * @param socketTimeout : curl timeout (in ms) + * @param maxRetries : max retry count for the request + * @param injectSocketTimeout : simulate socket timeout + * @param canceling canceling flag + * @param withoutCookies whether the cookie spec should be set to IGNORE or not + * @param includeRetryParameters whether to include retry parameters in retried requests. Only + * needs to be true for JDBC statement execution (query requests to Snowflake server). 
+ * @param includeRequestGuid whether to include request_guid parameter + * @param retryHTTP403 whether to retry on HTTP 403 or not should be executed before and/or after + * the retry + * @return HttpResponse Object get from server + * @throws net.snowflake.ingest.streaming.internal.fileTransferAgent.SnowflakeSQLException Request + * timeout Exception or Illegal State Exception i.e. connection is already shutdown etc + */ + public static CloseableHttpResponse execute( + CloseableHttpClient httpClient, + HttpRequestBase httpRequest, + long retryTimeout, + long authTimeout, + int socketTimeout, + int maxRetries, + int injectSocketTimeout, + AtomicBoolean canceling, + boolean withoutCookies, + boolean includeRetryParameters, + boolean includeRequestGuid, + boolean retryHTTP403, + ExecTimeTelemetryData execTimeTelemetryData) + throws SnowflakeSQLException { + return execute( + httpClient, + httpRequest, + retryTimeout, + authTimeout, + socketTimeout, + maxRetries, + injectSocketTimeout, + canceling, + withoutCookies, + includeRetryParameters, + includeRequestGuid, + retryHTTP403, + false, // noRetry + execTimeTelemetryData, + null); + } + + /** + * Execute an HTTP request with retry logic. + * + * @param httpClient client object used to communicate with other machine + * @param httpRequest request object contains all the request information + * @param retryTimeout : retry timeout (in seconds) + * @param authTimeout : authenticator specific timeout (in seconds) + * @param socketTimeout : curl timeout (in ms) + * @param maxRetries : max retry count for the request + * @param injectSocketTimeout : simulate socket timeout + * @param canceling canceling flag + * @param withoutCookies whether the cookie spec should be set to IGNORE or not + * @param includeRetryParameters whether to include retry parameters in retried requests. Only + * needs to be true for JDBC statement execution (query requests to Snowflake server). 
+ * @param includeRequestGuid whether to include request_guid parameter + * @param retryHTTP403 whether to retry on HTTP 403 or not should be executed before and/or after + * the retry + * @return HttpResponse Object get from server + * @throws net.snowflake.ingest.streaming.internal.fileTransferAgent.SnowflakeSQLException Request + * timeout Exception or Illegal State Exception i.e. connection is already shutdown etc + */ + public static CloseableHttpResponse execute( + CloseableHttpClient httpClient, + HttpRequestBase httpRequest, + long retryTimeout, + long authTimeout, + int socketTimeout, + int maxRetries, + int injectSocketTimeout, + AtomicBoolean canceling, + boolean withoutCookies, + boolean includeRetryParameters, + boolean includeRequestGuid, + boolean retryHTTP403, + boolean noRetry, + ExecTimeTelemetryData execTimeData) + throws SnowflakeSQLException { + return execute( + httpClient, + httpRequest, + retryTimeout, + authTimeout, + socketTimeout, + maxRetries, + injectSocketTimeout, + canceling, + withoutCookies, + includeRetryParameters, + includeRequestGuid, + retryHTTP403, + noRetry, + execTimeData, + null); + } + + /** + * Execute an HTTP request with retry logic. + * + * @param httpClient client object used to communicate with other machine + * @param httpRequest request object contains all the request information + * @param retryTimeout : retry timeout (in seconds) + * @param authTimeout : authenticator specific timeout (in seconds) + * @param socketTimeout : curl timeout (in ms) + * @param maxRetries : max retry count for the request + * @param injectSocketTimeout : simulate socket timeout + * @param canceling canceling flag + * @param withoutCookies whether the cookie spec should be set to IGNORE or not + * @param includeRetryParameters whether to include retry parameters in retried requests. Only + * needs to be true for JDBC statement execution (query requests to Snowflake server). 
+ * @param includeRequestGuid whether to include request_guid parameter + * @param retryHTTP403 whether to retry on HTTP 403 or not + * @param execTimeData ExecTimeTelemetryData should be executed before and/or after the retry + * @return HttpResponse Object get from server + * @throws net.snowflake.ingest.streaming.internal.fileTransferAgent.SnowflakeSQLException Request + * timeout Exception or Illegal State Exception i.e. connection is already shutdown etc + */ + public static CloseableHttpResponse execute( + CloseableHttpClient httpClient, + HttpRequestBase httpRequest, + long retryTimeout, + long authTimeout, + int socketTimeout, + int maxRetries, + int injectSocketTimeout, + AtomicBoolean canceling, + boolean withoutCookies, + boolean includeRetryParameters, + boolean includeRequestGuid, + boolean retryHTTP403, + ExecTimeTelemetryData execTimeData, + RetryContextManager retryContextManager) + throws SnowflakeSQLException { + return execute( + httpClient, + httpRequest, + retryTimeout, + authTimeout, + socketTimeout, + maxRetries, + injectSocketTimeout, + canceling, + withoutCookies, + includeRetryParameters, + includeRequestGuid, + retryHTTP403, + false, // noRetry + execTimeData, + retryContextManager); + } + + /** + * Execute an HTTP request with retry logic. + * + * @param httpClient client object used to communicate with other machine + * @param httpRequest request object contains all the request information + * @param retryTimeout : retry timeout (in seconds) + * @param authTimeout : authenticator specific timeout (in seconds) + * @param socketTimeout : curl timeout (in ms) + * @param maxRetries : max retry count for the request + * @param injectSocketTimeout : simulate socket timeout + * @param canceling canceling flag + * @param withoutCookies whether the cookie spec should be set to IGNORE or not + * @param includeRetryParameters whether to include retry parameters in retried requests. 
Only + * needs to be true for JDBC statement execution (query requests to Snowflake server). + * @param includeRequestGuid whether to include request_guid parameter + * @param retryHTTP403 whether to retry on HTTP 403 or not + * @param noRetry should we disable retry on non-successful http resp code + * @param execTimeData ExecTimeTelemetryData + * @param retryManager RetryContextManager - object allowing to optionally pass custom logic that + * should be executed before and/or after the retry + * @return HttpResponse Object get from server + * @throws net.snowflake.ingest.streaming.internal.fileTransferAgent.SnowflakeSQLException Request + * timeout Exception or Illegal State Exception i.e. connection is already shutdown etc + */ + public static CloseableHttpResponse execute( + CloseableHttpClient httpClient, + HttpRequestBase httpRequest, + long retryTimeout, + long authTimeout, + int socketTimeout, + int maxRetries, + int injectSocketTimeout, + AtomicBoolean canceling, + boolean withoutCookies, + boolean includeRetryParameters, + boolean includeRequestGuid, + boolean retryHTTP403, + boolean noRetry, + ExecTimeTelemetryData execTimeData, + RetryContextManager retryManager) + throws SnowflakeSQLException { + return executeWithRetries( + httpClient, + httpRequest, + retryTimeout, + authTimeout, + socketTimeout, + maxRetries, + injectSocketTimeout, + canceling, // no canceling + withoutCookies, // no cookie + includeRetryParameters, // no retry + includeRequestGuid, // no request_guid + retryHTTP403, // retry on HTTP 403 + noRetry, + new ExecTimeTelemetryData()) + .getHttpResponse(); + } + + static long getNewBackoffInMilli( + long previousBackoffInMilli, + boolean isLoginRequest, + DecorrelatedJitterBackoff decorrelatedJitterBackoff, + int retryCount, + long retryTimeoutInMilliseconds, + long elapsedMilliForTransientIssues) { + long backoffInMilli; + if (isLoginRequest) { + long jitteredBackoffInMilli = + 
decorrelatedJitterBackoff.getJitterForLogin(previousBackoffInMilli); + backoffInMilli = + (long) + decorrelatedJitterBackoff.chooseRandom( + jitteredBackoffInMilli + previousBackoffInMilli, + Math.pow(2, retryCount) + jitteredBackoffInMilli); + } else { + + backoffInMilli = decorrelatedJitterBackoff.nextSleepTime(previousBackoffInMilli); + } + + backoffInMilli = Math.min(maxBackoffInMilli, Math.max(previousBackoffInMilli, backoffInMilli)); + + if (retryTimeoutInMilliseconds > 0 + && (elapsedMilliForTransientIssues + backoffInMilli) > retryTimeoutInMilliseconds) { + // If the timeout will be reached before the next backoff, just use the remaining + // time (but cannot be negative) - this is the only place when backoff is not in range + // min-max. + backoffInMilli = + Math.max( + 0, + Math.min( + backoffInMilli, retryTimeoutInMilliseconds - elapsedMilliForTransientIssues)); + logger.debug( + "We are approaching retry timeout {}ms, setting backoff to {}ms", + retryTimeoutInMilliseconds, + backoffInMilli); + } + return backoffInMilli; + } + + static boolean isNonRetryableHTTPCode(CloseableHttpResponse response, boolean retryHTTP403) { + return (response != null) + && (response.getStatusLine().getStatusCode() < 500 + || // service unavailable + response.getStatusLine().getStatusCode() >= 600) + && // gateway timeout + response.getStatusLine().getStatusCode() != 408 + && // retry + response.getStatusLine().getStatusCode() != 429 + && // request timeout + (!retryHTTP403 || response.getStatusLine().getStatusCode() != 403); + } + + private static boolean isCertificateRevoked(Exception ex) { + if (ex == null) { + return false; + } + Throwable ex0 = getRootCause(ex); + if (!(ex0 instanceof SFOCSPException)) { + return false; + } + SFOCSPException cause = (SFOCSPException) ex0; + return cause.getErrorCode() == OCSPErrorCode.CERTIFICATE_STATUS_REVOKED; + } + + private static Throwable getRootCause(Throwable ex) { + Throwable ex0 = ex; + while (ex0.getCause() != null) { + ex0 = 
ex0.getCause(); + } + return ex0; + } + + private static void setRequestConfig( + HttpRequestBase httpRequest, + boolean withoutCookies, + int injectSocketTimeout, + String requestIdStr, + long authTimeoutInMilli) { + if (withoutCookies) { + httpRequest.setConfig(JdbcHttpUtil.getRequestConfigWithoutCookies()); + } + + // For first call, simulate a socket timeout by setting socket timeout + // to the injected socket timeout value + if (injectSocketTimeout != 0) { + // test code path + logger.debug( + "{}Injecting socket timeout by setting socket timeout to {} ms", + requestIdStr, + injectSocketTimeout); + httpRequest.setConfig( + JdbcHttpUtil.getDefaultRequestConfigWithSocketTimeout( + injectSocketTimeout, withoutCookies)); + } + + // When the auth timeout is set, set the socket timeout as the authTimeout + // so that it can be renewed in time and pass it to the http request configuration. + if (authTimeoutInMilli > 0) { + int requestSocketAndConnectTimeout = (int) authTimeoutInMilli; + logger.debug( + "{}Setting auth timeout as the socket timeout: {} ms", requestIdStr, authTimeoutInMilli); + httpRequest.setConfig( + JdbcHttpUtil.getDefaultRequestConfigWithSocketAndConnectTimeout( + requestSocketAndConnectTimeout, withoutCookies)); + } + } + + private static void setRequestURI( + HttpRequestBase httpRequest, + String requestIdStr, + boolean includeRetryParameters, + boolean includeRequestGuid, + int retryCount, + String lastStatusCodeForRetry, + long startTime, + String requestInfoScrubbed) + throws URISyntaxException { + /* + * Add retryCount if the first request failed + * GS can use the parameter for optimization. Specifically GS + * will only check metadata database to see if a query has been running + * for a retry request. This way for the majority of query requests + * which are not part of retry we don't have to pay the performance + * overhead of looking up in metadata database. 
+ */ + URIBuilder builder = new URIBuilder(httpRequest.getURI()); + // If HTAP + if ("true".equalsIgnoreCase(System.getenv("HTAP_SIMULATION")) + && builder.getPathSegments().contains("query-request")) { + logger.debug("{}Setting htap simulation", requestIdStr); + builder.setParameter("target", "htap_simulation"); + } + if (includeRetryParameters && retryCount > 0) { + updateRetryParameters(builder, retryCount, lastStatusCodeForRetry, startTime); + } + + if (includeRequestGuid) { + UUID guid = UUIDUtils.getUUID(); + logger.debug("{}Request {} guid: {}", requestIdStr, requestInfoScrubbed, guid.toString()); + // Add request_guid for better tracing + builder.setParameter(SF_REQUEST_GUID, guid.toString()); + } + + httpRequest.setURI(builder.build()); + } + + /** + * Execute an HTTP request with retry logic. + * + * @param httpClient client object used to communicate with other machine + * @param httpRequest request object contains all the request information + * @param retryTimeout : retry timeout (in seconds) + * @param authTimeout : authenticator specific timeout (in seconds) + * @param socketTimeout : curl timeout (in ms) + * @param maxRetries : max retry count for the request + * @param injectSocketTimeout : simulate socket timeout + * @param canceling canceling flag + * @param withoutCookies whether the cookie spec should be set to IGNORE or not + * @param includeRetryParameters whether to include retry parameters in retried requests. Only + * needs to be true for JDBC statement execution (query requests to Snowflake server). + * @param includeRequestGuid whether to include request_guid parameter + * @param retryHTTP403 whether to retry on HTTP 403 or not + * @return HttpResponseContextDto Object get from server or exception + * @throws net.snowflake.ingest.streaming.internal.fileTransferAgent.SnowflakeSQLException Request + * timeout Exception or Illegal State Exception i.e. 
connection is already shutdown etc + */ + public static HttpResponseContextDto executeWithRetries( + CloseableHttpClient httpClient, + HttpRequestBase httpRequest, + long retryTimeout, + long authTimeout, + int socketTimeout, + int maxRetries, + int injectSocketTimeout, + AtomicBoolean canceling, + boolean withoutCookies, + boolean includeRetryParameters, + boolean includeRequestGuid, + boolean retryHTTP403, + boolean unpackResponse, + ExecTimeTelemetryData execTimeTelemetryData) + throws SnowflakeSQLException { + return executeWithRetries( + httpClient, + httpRequest, + retryTimeout, + authTimeout, + socketTimeout, + maxRetries, + injectSocketTimeout, + canceling, + withoutCookies, + includeRetryParameters, + includeRequestGuid, + retryHTTP403, + false, + unpackResponse, + execTimeTelemetryData); + } + + /** + * Execute an HTTP request with retry logic. + * + * @param httpClient client object used to communicate with other machine + * @param httpRequest request object contains all the request information + * @param retryTimeout : retry timeout (in seconds) + * @param authTimeout : authenticator specific timeout (in seconds) + * @param socketTimeout : curl timeout (in ms) + * @param maxRetries : max retry count for the request + * @param injectSocketTimeout : simulate socket timeout + * @param canceling canceling flag + * @param withoutCookies whether the cookie spec should be set to IGNORE or not + * @param includeRetryParameters whether to include retry parameters in retried requests. Only + * needs to be true for JDBC statement execution (query requests to Snowflake server). 
+ * @param includeRequestGuid whether to include request_guid parameter + * @param retryHTTP403 whether to retry on HTTP 403 or not + * @param execTimeTelemetryData ExecTimeTelemetryData should be executed before and/or after the + * retry + * @return HttpResponseContextDto Object get from server or exception + * @throws net.snowflake.ingest.streaming.internal.fileTransferAgent.SnowflakeSQLException Request + * timeout Exception or Illegal State Exception i.e. connection is already shutdown etc + */ + public static HttpResponseContextDto executeWithRetries( + CloseableHttpClient httpClient, + HttpRequestBase httpRequest, + long retryTimeout, + long authTimeout, + int socketTimeout, + int maxRetries, + int injectSocketTimeout, + AtomicBoolean canceling, + boolean withoutCookies, + boolean includeRetryParameters, + boolean includeRequestGuid, + boolean retryHTTP403, + boolean noRetry, + boolean unpackResponse, + ExecTimeTelemetryData execTimeTelemetryData) + throws SnowflakeSQLException { + String requestIdStr = URLUtil.getRequestIdLogStr(httpRequest.getURI()); + String requestInfoScrubbed = SecretDetector.maskSASToken(httpRequest.toString()); + HttpExecutingContext context = + HttpExecutingContextBuilder.withRequest(requestIdStr, requestInfoScrubbed) + .retryTimeout(retryTimeout) + .authTimeout(authTimeout) + .origSocketTimeout(socketTimeout) + .maxRetries(maxRetries) + .injectSocketTimeout(injectSocketTimeout) + .canceling(canceling) + .withoutCookies(withoutCookies) + .includeRetryParameters(includeRetryParameters) + .includeRequestGuid(includeRequestGuid) + .retryHTTP403(retryHTTP403) + .noRetry(noRetry) + .unpackResponse(unpackResponse) + .loginRequest(JdbcHttpUtil.isNewRetryStrategyRequest(httpRequest)) + .build(); + return executeWithRetries(httpClient, httpRequest, context, execTimeTelemetryData, null); + } + + /** + * Execute an HTTP request with retry logic. 
+ * + * @param httpClient client object used to communicate with other machine + * @param httpRequest request object contains all the request information + * @param execTimeData ExecTimeTelemetryData should be executed before and/or after the retry + * @param retryManager RetryManager containing extra actions used during retries + * @return HttpResponseContextDto Object get from server or exception + * @throws net.snowflake.ingest.streaming.internal.fileTransferAgent.SnowflakeSQLException Request + * timeout Exception or Illegal State Exception i.e. connection is already shutdown etc + */ + public static HttpResponseContextDto executeWithRetries( + CloseableHttpClient httpClient, + HttpRequestBase httpRequest, + HttpExecutingContext httpExecutingContext, + ExecTimeTelemetryData execTimeData, + RetryContextManager retryManager) + throws SnowflakeSQLException { + Stopwatch networkComunnicationStapwatch = null; + Stopwatch requestReponseStopWatch = null; + HttpResponseContextDto responseDto = new HttpResponseContextDto(); + + if (logger.isDebugEnabled()) { + networkComunnicationStapwatch = new Stopwatch(); + networkComunnicationStapwatch.start(); + logger.debug( + "{}Executing rest request: {}, retry timeout: {}, socket timeout: {}, max retries: {}," + + " inject socket timeout: {}, canceling: {}, without cookies: {}, include retry" + + " parameters: {}, include request guid: {}, retry http 403: {}, no retry: {}", + httpExecutingContext.getRequestId(), + httpExecutingContext.getRequestInfoScrubbed(), + httpExecutingContext.getRetryTimeoutInMilliseconds(), + httpExecutingContext.getOrigSocketTimeout(), + httpExecutingContext.getMaxRetries(), + httpExecutingContext.isInjectSocketTimeout(), + httpExecutingContext.getCanceling(), + httpExecutingContext.isWithoutCookies(), + httpExecutingContext.isIncludeRetryParameters(), + httpExecutingContext.isIncludeRequestGuid(), + httpExecutingContext.isRetryHTTP403(), + httpExecutingContext.isNoRetry()); + } + if 
(httpExecutingContext.isLoginRequest()) { + logger.debug( + "{}Request is a login/auth request. Using new retry strategy", + httpExecutingContext.getRequestId()); + } + + RestRequest.setRequestConfig( + httpRequest, + httpExecutingContext.isWithoutCookies(), + httpExecutingContext.getInjectSocketTimeout(), + httpExecutingContext.getRequestId(), + httpExecutingContext.getAuthTimeoutInMilliseconds()); + + // try request till we get a good response or retry timeout + while (true) { + logger.debug( + "{}Retry count: {}, max retries: {}, retry timeout: {} s, backoff: {} ms. Attempting" + + " request: {}", + httpExecutingContext.getRequestId(), + httpExecutingContext.getRetryCount(), + httpExecutingContext.getMaxRetries(), + httpExecutingContext.getRetryTimeout(), + httpExecutingContext.getMinBackoffInMillis(), + httpExecutingContext.getRequestInfoScrubbed()); + try { + // update start time + httpExecutingContext.setStartTimePerRequest(System.currentTimeMillis()); + + RestRequest.setRequestURI( + httpRequest, + httpExecutingContext.getRequestId(), + httpExecutingContext.isIncludeRetryParameters(), + httpExecutingContext.isIncludeRequestGuid(), + httpExecutingContext.getRetryCount(), + httpExecutingContext.getLastStatusCodeForRetry(), + httpExecutingContext.getStartTime(), + httpExecutingContext.getRequestInfoScrubbed()); + + execTimeData.setHttpClientStart(); + CloseableHttpResponse response = httpClient.execute(httpRequest); + responseDto.setHttpResponse(response); + execTimeData.setHttpClientEnd(); + } catch (Exception ex) { + responseDto.setSavedEx(handlingNotRetryableException(ex, httpExecutingContext)); + } finally { + // Reset the socket timeout to its original value if it is not the + // very first iteration. 
+ if (httpExecutingContext.getInjectSocketTimeout() != 0 + && httpExecutingContext.getRetryCount() == 0) { + // test code path + httpRequest.setConfig( + JdbcHttpUtil.getDefaultRequestConfigWithSocketTimeout( + httpExecutingContext.getOrigSocketTimeout(), + httpExecutingContext.isWithoutCookies())); + } + } + boolean shouldSkipRetry = + shouldSkipRetryWithLoggedReason(httpRequest, responseDto, httpExecutingContext); + httpExecutingContext.setShouldRetry(!shouldSkipRetry); + + if (httpExecutingContext.isUnpackResponse() + && responseDto.getHttpResponse() != null + && responseDto.getHttpResponse().getStatusLine().getStatusCode() + == 200) { // todo extract getter for statusCode + processHttpResponse(httpExecutingContext, execTimeData, responseDto); + } + + if (!httpExecutingContext.isShouldRetry()) { + if (responseDto.getHttpResponse() == null) { + if (responseDto.getSavedEx() != null) { + logger.error( + "{}Returning null response. Cause: {}, request: {}", + httpExecutingContext.getRequestId(), + getRootCause(responseDto.getSavedEx()), + httpExecutingContext.getRequestInfoScrubbed()); + } else { + logger.error( + "{}Returning null response for request: {}", + httpExecutingContext.getRequestId(), + httpExecutingContext.getRequestInfoScrubbed()); + } + } else if (responseDto.getHttpResponse().getStatusLine().getStatusCode() != 200) { + logger.error( + "{}Error response: HTTP Response code: {}, request: {}", + httpExecutingContext.getRequestId(), + responseDto.getHttpResponse().getStatusLine().getStatusCode(), + httpExecutingContext.getRequestInfoScrubbed()); + responseDto.setSavedEx( + new SnowflakeSQLException( + SqlState.IO_ERROR, + ErrorCode.NETWORK_ERROR.getMessageCode(), + "HTTP status=" + + ((responseDto.getHttpResponse() != null) + ? 
responseDto.getHttpResponse().getStatusLine().getStatusCode() + : "null response"))); + } else if ((responseDto.getHttpResponse() == null + || responseDto.getHttpResponse().getStatusLine().getStatusCode() != 200)) { + sendTelemetryEvent( + httpRequest, + httpExecutingContext, + responseDto.getHttpResponse(), + responseDto.getSavedEx()); + } + break; + } else { + prepareRetry(httpRequest, httpExecutingContext, retryManager, responseDto); + } + } + + logger.debug( + "{}Execution of request {} took {} ms with total of {} retries", + httpExecutingContext.getRequestId(), + httpExecutingContext.getRequestInfoScrubbed(), + networkComunnicationStapwatch == null + ? "n/a" + : networkComunnicationStapwatch.elapsedMillis(), + httpExecutingContext.getRetryCount()); + + httpExecutingContext.resetRetryCount(); + if (logger.isDebugEnabled() && networkComunnicationStapwatch != null) { + networkComunnicationStapwatch.stop(); + } + if (responseDto.getSavedEx() != null) { + Exception savedEx = responseDto.getSavedEx(); + if (savedEx instanceof SnowflakeSQLException) { + throw (SnowflakeSQLException) savedEx; + } else { + throw new SnowflakeSQLException( + savedEx, + ErrorCode.NETWORK_ERROR, + "Exception encountered for HTTP request: " + savedEx.getMessage()); + } + } + return responseDto; + } + + private static void processHttpResponse( + HttpExecutingContext httpExecutingContext, + ExecTimeTelemetryData execTimeData, + HttpResponseContextDto responseDto) { + CloseableHttpResponse response = responseDto.getHttpResponse(); + try { + String responseText; + responseText = verifyAndUnpackResponse(response, execTimeData); + httpExecutingContext.setShouldRetry(false); + responseDto.setUnpackedCloseableHttpResponse(responseText); + } catch (IOException ex) { + boolean skipRetriesBecauseOf200 = httpExecutingContext.isSkipRetriesBecauseOf200(); + boolean retryReasonDifferentThan200 = + !httpExecutingContext.isShouldRetry() && skipRetriesBecauseOf200; + 
httpExecutingContext.setShouldRetry(retryReasonDifferentThan200); + responseDto.setSavedEx(ex); + } + } + + private static void updateRetryParameters( + URIBuilder builder, int retryCount, String lastStatusCodeForRetry, long startTime) { + builder.setParameter("retryCount", String.valueOf(retryCount)); + builder.setParameter("retryReason", lastStatusCodeForRetry); + builder.setParameter("clientStartTime", String.valueOf(startTime)); + } + + private static void prepareRetry( + HttpRequestBase httpRequest, + HttpExecutingContext httpExecutingContext, + RetryContextManager retryManager, + HttpResponseContextDto dto) + throws SnowflakeSQLException { + // Potentially retryable error + logRequestResult( + dto.getHttpResponse(), + httpExecutingContext.getRequestId(), + httpExecutingContext.getRequestInfoScrubbed(), + dto.getSavedEx()); + + // get the elapsed time for the last request + // elapsed in millisecond for last call, used for calculating the + // remaining amount of time to sleep: + // (backoffInMilli - elapsedMilliForLastCall) + long elapsedMilliForLastCall = + System.currentTimeMillis() - httpExecutingContext.getStartTimePerRequest(); + + if (httpExecutingContext.socketOrConnectTimeoutReached()) + /* socket timeout not reached */ { + /* connect timeout not reached */ + // check if this is a login-request + if (String.valueOf(httpRequest.getURI()).contains("login-request")) { + throw new SnowflakeSQLException( + ErrorCode.AUTHENTICATOR_REQUEST_TIMEOUT, + httpExecutingContext.getRetryCount(), + true, + httpExecutingContext.getElapsedMilliForTransientIssues() / 1000); + } + } + + // sleep for backoff - elapsed amount of time + sleepForBackoffAndPrepareNext(elapsedMilliForLastCall, httpExecutingContext); + + httpExecutingContext.incrementRetryCount(); + httpExecutingContext.setLastStatusCodeForRetry( + dto.getHttpResponse() == null + ? 
"0" + : String.valueOf(dto.getHttpResponse().getStatusLine().getStatusCode())); + // If the request failed with any other retry-able error and auth timeout is reached + // increase the retry count and throw special exception to renew the token before retrying. + + RetryContextManager.RetryHook retryManagerHook = null; + if (retryManager != null) { + retryManagerHook = retryManager.getRetryHook(); + retryManager + .getRetryContext() + .setElapsedTimeInMillis(httpExecutingContext.getElapsedMilliForTransientIssues()) + .setRetryTimeoutInMillis(httpExecutingContext.getRetryTimeoutInMilliseconds()); + } + + // Make sure that any authenticator specific info that needs to be + // updated gets updated before the next retry. Ex - OKTA OTT, JWT token + // Aim is to achieve this using RetryContextManager, but raising + // AUTHENTICATOR_REQUEST_TIMEOUT Exception is still supported as well. In both cases the + // retried request must be aware of the elapsed time not to exceed the timeout limit. + if (retryManagerHook == RetryContextManager.RetryHook.ALWAYS_BEFORE_RETRY) { + retryManager.executeRetryCallbacks(httpRequest); + } + + if (httpExecutingContext.getAuthTimeout() > 0 + && httpExecutingContext.getElapsedMilliForTransientIssues() + >= httpExecutingContext.getAuthTimeout()) { + throw new SnowflakeSQLException( + ErrorCode.AUTHENTICATOR_REQUEST_TIMEOUT, + httpExecutingContext.getRetryCount(), + false, + httpExecutingContext.getElapsedMilliForTransientIssues() / 1000); + } + + int numOfRetryToTriggerTelemetry = + TelemetryService.getInstance().getNumOfRetryToTriggerTelemetry(); + if (httpExecutingContext.getRetryCount() == numOfRetryToTriggerTelemetry) { + TelemetryService.getInstance() + .logHttpRequestTelemetryEvent( + String.format("HttpRequestRetry%dTimes", numOfRetryToTriggerTelemetry), + httpRequest, + httpExecutingContext.getInjectSocketTimeout(), + httpExecutingContext.getCanceling(), + httpExecutingContext.isWithoutCookies(), + 
httpExecutingContext.isIncludeRetryParameters(), + httpExecutingContext.isIncludeRequestGuid(), + dto.getHttpResponse(), + dto.getSavedEx(), + httpExecutingContext.getBreakRetryReason(), + httpExecutingContext.getRetryTimeout(), + httpExecutingContext.getRetryCount(), + SqlState.IO_ERROR, + ErrorCode.NETWORK_ERROR.getMessageCode()); + } + dto.setSavedEx(null); + httpExecutingContext.setSkipRetriesBecauseOf200(false); + + // release connection before retry + httpRequest.releaseConnection(); + } + + private static void sendTelemetryEvent( + HttpRequestBase httpRequest, + HttpExecutingContext httpExecutingContext, + CloseableHttpResponse response, + Exception savedEx) { + String eventName; + if (response == null) { + eventName = "NullResponseHttpError"; + } else { + if (response.getStatusLine() == null) { + eventName = "NullResponseStatusLine"; + } else { + eventName = String.format("HttpError%d", response.getStatusLine().getStatusCode()); + } + } + TelemetryService.getInstance() + .logHttpRequestTelemetryEvent( + eventName, + httpRequest, + httpExecutingContext.getInjectSocketTimeout(), + httpExecutingContext.getCanceling(), + httpExecutingContext.isWithoutCookies(), + httpExecutingContext.isIncludeRetryParameters(), + httpExecutingContext.isIncludeRequestGuid(), + response, + savedEx, + httpExecutingContext.getBreakRetryReason(), + httpExecutingContext.getRetryTimeout(), + httpExecutingContext.getRetryCount(), + null, + 0); + } + + private static void sleepForBackoffAndPrepareNext( + long elapsedMilliForLastCall, HttpExecutingContext context) { + if (context.getMinBackoffInMillis() > elapsedMilliForLastCall) { + try { + logger.debug( + "{}Retry request {}: sleeping for {} ms", + context.getRequestId(), + context.getRequestInfoScrubbed(), + context.getBackoffInMillis()); + Thread.sleep(context.getBackoffInMillis()); + } catch (InterruptedException ex1) { + logger.debug( + "{}Backoff sleep before retrying login got interrupted", context.getRequestId()); + } + 
context.increaseElapsedMilliForTransientIssues(context.getBackoffInMillis()); + context.setBackoffInMillis( + getNewBackoffInMilli( + context.getBackoffInMillis(), + context.isLoginRequest(), + context.getBackoff(), + context.getRetryCount(), + context.getRetryTimeoutInMilliseconds(), + context.getElapsedMilliForTransientIssues())); + } + } + + private static void logRequestResult( + CloseableHttpResponse response, + String requestIdStr, + String requestInfoScrubbed, + Exception savedEx) { + if (response != null) { + logger.debug( + "{}HTTP response not ok: status code: {}, request: {}", + requestIdStr, + response.getStatusLine().getStatusCode(), + requestInfoScrubbed); + } else if (savedEx != null) { + logger.debug( + "{}Null response for cause: {}, request: {}", + requestIdStr, + getRootCause(savedEx).getMessage(), + requestInfoScrubbed); + } else { + logger.debug("{}Null response for request: {}", requestIdStr, requestInfoScrubbed); + } + } + + private static void checkForDPoPNonceError(CloseableHttpResponse response) throws IOException { + String errorResponse = EntityUtils.toString(response.getEntity()); + if (!isNullOrEmpty(errorResponse)) { + ObjectMapper objectMapper = ObjectMapperFactory.getObjectMapper(); + JsonNode rootNode = objectMapper.readTree(errorResponse); + JsonNode errorNode = rootNode.get(ERROR_FIELD_NAME); + if (errorNode != null + && errorNode.isValueNode() + && errorNode.isTextual() + && errorNode.textValue().equals(ERROR_USE_DPOP_NONCE)) { + throw new SnowflakeUseDPoPNonceException( + response.getFirstHeader(DPOP_NONCE_HEADER_NAME).getValue()); + } + } + } + + static Exception handlingNotRetryableException( + Exception ex, HttpExecutingContext httpExecutingContext) throws SnowflakeSQLLoggedException { + Exception savedEx = null; + if (ex instanceof IllegalStateException) { + throw new SnowflakeSQLLoggedException( + ErrorCode.INVALID_STATE, ex, /* session= */ ex.getMessage()); + } else if (isExceptionInGroup(ex, sslExceptions) && 
!isProtocolVersionError(ex)) { + String formattedMsg = + ex.getMessage() + + "\n" + + "Verify that the hostnames and portnumbers in SYSTEM$ALLOWLIST are added to your" + + " firewall's allowed list.\n" + + "To troubleshoot your connection further, you can refer to this article:\n" + + "https://docs.snowflake.com/en/user-guide/client-connectivity-troubleshooting/overview"; + + throw new SnowflakeSQLLoggedException(ErrorCode.NETWORK_ERROR, ex, formattedMsg); + } else if (ex instanceof Exception) { + savedEx = ex; + // if the request took more than socket timeout log a warning + long currentMillis = System.currentTimeMillis(); + if ((currentMillis - httpExecutingContext.getStartTimePerRequest()) + > JdbcHttpUtil.getSocketTimeout().toMillis()) { + logger.warn( + "{}HTTP request took longer than socket timeout {} ms: {} ms", + httpExecutingContext.getRequestId(), + JdbcHttpUtil.getSocketTimeout().toMillis(), + (currentMillis - httpExecutingContext.getStartTimePerRequest())); + } + StringWriter sw = new StringWriter(); + savedEx.printStackTrace(new PrintWriter(sw)); + logger.debug( + "{}Exception encountered for: {}, {}, {}", + httpExecutingContext.getRequestId(), + httpExecutingContext.getRequestInfoScrubbed(), + ex.getLocalizedMessage(), + (ArgSupplier) sw::toString); + } + return ex; + } + + static boolean isExceptionInGroup(Exception e, Set> group) { + for (Class clazz : group) { + if (clazz.isInstance(e)) { + return true; + } + } + return false; + } + + static boolean isProtocolVersionError(Exception e) { + return e.getMessage() != null + && e.getMessage().contains("Received fatal alert: protocol_version"); + } + + private static boolean handleCertificateRevoked( + Exception savedEx, HttpExecutingContext httpExecutingContext, boolean skipRetrying) { + if (!skipRetrying && RestRequest.isCertificateRevoked(savedEx)) { + String msg = "Unknown reason"; + Throwable rootCause = RestRequest.getRootCause(savedEx); + msg = + rootCause.getMessage() != null && 
!rootCause.getMessage().isEmpty() + ? rootCause.getMessage() + : msg; + logger.debug( + "{}Error response not retryable, " + msg + ", request: {}", + httpExecutingContext.getRequestId(), + httpExecutingContext.getRequestInfoScrubbed()); + EventUtil.triggerBasicEvent( + Event.EventType.NETWORK_ERROR, + msg + ", Request: " + httpExecutingContext.getRequestInfoScrubbed(), + false); + + httpExecutingContext.setBreakRetryReason("certificate revoked error"); + httpExecutingContext.setBreakRetryEventName("HttpRequestRetryVertificateRevoked"); + httpExecutingContext.setShouldRetry(false); + return true; + } + return skipRetrying; + } + + private static boolean handleNonRetryableHttpCode( + HttpResponseContextDto dto, HttpExecutingContext httpExecutingContext, boolean skipRetrying) { + CloseableHttpResponse response = dto.getHttpResponse(); + if (!skipRetrying && isNonRetryableHTTPCode(response, httpExecutingContext.isRetryHTTP403())) { + String msg = "Unknown reason"; + if (response != null) { + logger.debug( + "{}HTTP response code for request {}: {}", + httpExecutingContext.getRequestId(), + httpExecutingContext.getRequestInfoScrubbed(), + response.getStatusLine().getStatusCode()); + msg = + "StatusCode: " + + response.getStatusLine().getStatusCode() + + ", Reason: " + + response.getStatusLine().getReasonPhrase(); + } else if (dto.getSavedEx() != null) // may be null. 
+ { + Throwable rootCause = RestRequest.getRootCause(dto.getSavedEx()); + msg = rootCause.getMessage(); + } + + if (response == null || response.getStatusLine().getStatusCode() != 200) { + logger.debug( + "{}Error response not retryable, " + msg + ", request: {}", + httpExecutingContext.getRequestId(), + httpExecutingContext.getRequestInfoScrubbed()); + EventUtil.triggerBasicEvent( + Event.EventType.NETWORK_ERROR, + msg + ", Request: " + httpExecutingContext.getRequestInfoScrubbed(), + false); + } + httpExecutingContext.setBreakRetryReason("status code does not need retry"); + httpExecutingContext.setShouldRetry(false); + httpExecutingContext.setSkipRetriesBecauseOf200( + response.getStatusLine().getStatusCode() == 200); + + try { + if (response == null || response.getStatusLine().getStatusCode() != 200) { + logger.error( + "Error executing request: {}", httpExecutingContext.getRequestInfoScrubbed()); + + if (response != null + && response.getStatusLine().getStatusCode() == 400 + && response.getEntity() != null) { + checkForDPoPNonceError(response); + } + + logResponseDetails(response, logger); + + if (response != null) { + EntityUtils.consume(response.getEntity()); + } + + // We throw here exception if timeout was reached for login + dto.setSavedEx( + new SnowflakeSQLException( + SqlState.IO_ERROR, + ErrorCode.NETWORK_ERROR.getMessageCode(), + "HTTP status=" + + ((response != null) + ? 
response.getStatusLine().getStatusCode() + : "null response"))); + } + } catch (IOException e) { + dto.setSavedEx( + new SnowflakeSQLException( + SqlState.IO_ERROR, + ErrorCode.NETWORK_ERROR.getMessageCode(), + "Exception details: " + e.getMessage())); + } + return true; + } + return skipRetrying; + } + + private static void logTelemetryEvent( + HttpRequestBase request, + CloseableHttpResponse response, + Exception savedEx, + HttpExecutingContext httpExecutingContext) { + TelemetryService.getInstance() + .logHttpRequestTelemetryEvent( + httpExecutingContext.getBreakRetryEventName(), + request, + httpExecutingContext.getInjectSocketTimeout(), + httpExecutingContext.getCanceling(), + httpExecutingContext.isWithoutCookies(), + httpExecutingContext.isIncludeRetryParameters(), + httpExecutingContext.isIncludeRequestGuid(), + response, + savedEx, + httpExecutingContext.getBreakRetryReason(), + httpExecutingContext.getRetryTimeout(), + httpExecutingContext.getRetryCount(), + SqlState.IO_ERROR, + ErrorCode.NETWORK_ERROR.getMessageCode()); + } + + private static boolean handleMaxRetriesExceeded( + HttpExecutingContext httpExecutingContext, boolean skipRetrying) { + if (!skipRetrying && httpExecutingContext.maxRetriesExceeded()) { + logger.error( + "{}Stop retrying as max retries have been reached for request: {}! Max retry count: {}", + httpExecutingContext.getRequestId(), + httpExecutingContext.getRequestInfoScrubbed(), + httpExecutingContext.getMaxRetries()); + + httpExecutingContext.setBreakRetryReason("max retries reached"); + httpExecutingContext.setBreakRetryEventName("HttpRequestRetryLimitExceeded"); + httpExecutingContext.setShouldRetry(false); + return true; + } + return skipRetrying; + } + + private static boolean handleElapsedTimeoutExceeded( + HttpExecutingContext httpExecutingContext, boolean skipRetrying) { + if (!skipRetrying && httpExecutingContext.getRetryTimeoutInMilliseconds() > 0) { + // Check for retry time-out. 
+ // increment total elapsed due to transient issues + long elapsedMilliForLastCall = + System.currentTimeMillis() - httpExecutingContext.getStartTimePerRequest(); + httpExecutingContext.increaseElapsedMilliForTransientIssues(elapsedMilliForLastCall); + + // check if the total elapsed time for transient issues has exceeded + // the retry timeout and we retry at least the min, if so, we will not + // retry + if (httpExecutingContext.elapsedTimeExceeded() && httpExecutingContext.moreThanMinRetries()) { + logger.error( + "{}Stop retrying since elapsed time due to network " + + "issues has reached timeout. " + + "Elapsed: {} ms, timeout: {} ms", + httpExecutingContext.getRequestId(), + httpExecutingContext.getElapsedMilliForTransientIssues(), + httpExecutingContext.getRetryTimeoutInMilliseconds()); + + httpExecutingContext.setBreakRetryReason("retry timeout"); + httpExecutingContext.setBreakRetryEventName("HttpRequestRetryTimeout"); + httpExecutingContext.setShouldRetry(false); + return true; + } + } + return skipRetrying; + } + + private static boolean handleCancelingSignal( + HttpExecutingContext httpExecutingContext, boolean skipRetrying) { + if (!skipRetrying + && httpExecutingContext.getCanceling() != null + && httpExecutingContext.getCanceling().get()) { + logger.debug( + "{}Stop retrying since canceling is requested", httpExecutingContext.getRequestId()); + httpExecutingContext.setBreakRetryReason("canceling is requested"); + httpExecutingContext.setShouldRetry(false); + return true; + } + return skipRetrying; + } + + private static boolean handleNoRetryFlag( + HttpExecutingContext httpExecutingContext, boolean skipRetrying) { + if (!skipRetrying && httpExecutingContext.isNoRetry()) { + logger.debug( + "{}HTTP retry disabled for this request. 
noRetry: {}", + httpExecutingContext.getRequestId(), + httpExecutingContext.isNoRetry()); + httpExecutingContext.setBreakRetryReason("retry is disabled"); + httpExecutingContext.resetRetryCount(); + httpExecutingContext.setShouldRetry(false); + return true; + } + return skipRetrying; + } + + private static boolean shouldSkipRetryWithLoggedReason( + HttpRequestBase request, + HttpResponseContextDto responseDto, + HttpExecutingContext httpExecutingContext) { + CloseableHttpResponse response = responseDto.getHttpResponse(); + Exception savedEx = responseDto.getSavedEx(); + List> conditions = + Arrays.asList( + skipRetrying -> handleNoRetryFlag(httpExecutingContext, skipRetrying), + skipRetrying -> handleCancelingSignal(httpExecutingContext, skipRetrying), + skipRetrying -> handleElapsedTimeoutExceeded(httpExecutingContext, skipRetrying), + skipRetrying -> handleMaxRetriesExceeded(httpExecutingContext, skipRetrying), + skipRetrying -> handleCertificateRevoked(savedEx, httpExecutingContext, skipRetrying), + skipRetrying -> + handleNonRetryableHttpCode(responseDto, httpExecutingContext, skipRetrying)); + + // Process each condition using Stream + boolean skipRetrying = + conditions.stream().reduce(Function::andThen).orElse(Function.identity()).apply(false); + + // Log telemetry + logTelemetryEvent(request, response, savedEx, httpExecutingContext); + + return skipRetrying; + } + + /** + * Inlined from SnowflakeUtil.logResponseDetails — cannot call JDBC's version because it expects + * net.snowflake.client.log.SFLogger, not our replicated SFLogger. 
+ */ + private static void logResponseDetails(HttpResponse response, SFLogger logger) { + if (response == null) { + logger.error("null response", false); + return; + } + + // log the response + if (response.getStatusLine() != null) { + logger.error("Response status line reason: {}", response.getStatusLine().getReasonPhrase()); + } + + // log each header from response + Header[] headers = response.getAllHeaders(); + if (headers != null) { + for (Header header : headers) { + logger.debug("Header name: {}, value: {}", header.getName(), header.getValue()); + } + } + + // log response + if (response.getEntity() != null) { + try { + StringWriter writer = new StringWriter(); + BufferedReader bufferedReader = + new BufferedReader(new InputStreamReader((response.getEntity().getContent()))); + IOUtils.copy(bufferedReader, writer); + logger.error("Response content: {}", writer.toString()); + } catch (IOException ex) { + logger.error("Failed to read content due to exception: " + "{}", ex.getMessage()); + } + } + } + + private static String verifyAndUnpackResponse( + CloseableHttpResponse response, ExecTimeTelemetryData execTimeData) throws IOException { + try (StringWriter writer = new StringWriter()) { + execTimeData.setResponseIOStreamStart(); + try (InputStream ins = response.getEntity().getContent()) { + IOUtils.copy(ins, writer, "UTF-8"); + } + + execTimeData.setResponseIOStreamEnd(); + return writer.toString(); + } finally { + IOUtils.closeQuietly(response); + } + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/RetryContext.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/RetryContext.java new file mode 100644 index 000000000..c2996ae73 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/RetryContext.java @@ -0,0 +1,44 @@ +/* + * Replicated from snowflake-jdbc (v3.25.1) + * Source: 
https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/jdbc/RetryContext.java + * + * Permitted differences: package declaration, @SnowflakeJdbcInternalApi annotation removed. + */ +package net.snowflake.ingest.streaming.internal.fileTransferAgent; + +/** RetryContext stores information about an ongoing request's retrying process. */ +public class RetryContext { + static final int SECONDS_TO_MILLIS_FACTOR = 1000; + private long elapsedTimeInMillis; + private long retryTimeoutInMillis; + private long retryCount; + + public RetryContext() {} + + public RetryContext setElapsedTimeInMillis(long elapsedTimeInMillis) { + this.elapsedTimeInMillis = elapsedTimeInMillis; + return this; + } + + public RetryContext setRetryTimeoutInMillis(long retryTimeoutInMillis) { + this.retryTimeoutInMillis = retryTimeoutInMillis; + return this; + } + + public RetryContext setRetryCount(long retryCount) { + this.retryCount = retryCount; + return this; + } + + private long getRemainingRetryTimeoutInMillis() { + return retryTimeoutInMillis - elapsedTimeInMillis; + } + + public long getRemainingRetryTimeoutInSeconds() { + return (getRemainingRetryTimeoutInMillis()) / SECONDS_TO_MILLIS_FACTOR; + } + + public long getRetryCount() { + return retryCount; + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/RetryContextManager.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/RetryContextManager.java new file mode 100644 index 000000000..60a5abe4e --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/RetryContextManager.java @@ -0,0 +1,90 @@ +/* + * Replicated from snowflake-jdbc (v3.25.1) + * Source: https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/jdbc/RetryContextManager.java + * + * Permitted differences: package declaration, import swaps for already-replicated classes, + * @SnowflakeJdbcInternalApi 
annotation removed. + */ +package net.snowflake.ingest.streaming.internal.fileTransferAgent; + +import java.util.ArrayList; +import java.util.List; +import org.apache.http.client.methods.HttpRequestBase; + +/** + * RetryContextManager lets you register logic (as callbacks) that will be re-executed during a + * retry of a request. + */ +public class RetryContextManager { + + // List of retry callbacks that will be executed in the order they were registered. + private final List< + ThrowingBiFunction> + retryCallbacks = new ArrayList<>(); + + // A RetryHook flag that can be used by client code to decide when (or if) callbacks should be + // executed. + private final RetryHook retryHook; + private RetryContext retryContext; + + /** Enumeration for different retry hook strategies. */ + public enum RetryHook { + /** Always execute the registered retry callbacks on every retry. */ + ALWAYS_BEFORE_RETRY, + } + + /** Default constructor using ALWAYS_BEFORE_RETRY as the default retry hook. */ + public RetryContextManager() { + this(RetryHook.ALWAYS_BEFORE_RETRY); + } + + /** + * Constructor that accepts a specific RetryHook. + * + * @param retryHook the retry hook strategy. + */ + public RetryContextManager(RetryHook retryHook) { + this.retryHook = retryHook; + this.retryContext = new RetryContext(); + } + + /** + * Registers a retry callback that will be executed on each retry. + * + * @param callback A RetryCallback encapsulating the logic to be replayed on retry. + * @return the current instance for fluent chaining. + */ + public RetryContextManager registerRetryCallback( + ThrowingBiFunction + callback) { + retryCallbacks.add(callback); + return this; + } + + /** + * Executes all registered retry callbacks in the order they were added, before reattempting the + * operation. + * + * @param requestToRetry the HTTP request to retry. + * @throws SnowflakeSQLException if an error occurs during callback execution. 
+ */ + public void executeRetryCallbacks(HttpRequestBase requestToRetry) throws SnowflakeSQLException { + for (ThrowingBiFunction + callback : retryCallbacks) { + retryContext = callback.apply(requestToRetry, retryContext); + } + } + + /** + * Returns the configured RetryHook. + * + * @return the retry hook. + */ + public RetryHook getRetryHook() { + return retryHook; + } + + public RetryContext getRetryContext() { + return retryContext; + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/S3HttpUtil.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/S3HttpUtil.java index 84507cff5..2477570b4 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/S3HttpUtil.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/S3HttpUtil.java @@ -21,8 +21,7 @@ import net.snowflake.ingest.utils.SFSessionProperty; public class S3HttpUtil { - private static final SFLogger logger = - SFLoggerFactory.getLogger(net.snowflake.ingest.utils.HttpUtil.class); + private static final SFLogger logger = SFLoggerFactory.getLogger(JdbcHttpUtil.class); /** * A static function to set S3 proxy params when there is a valid session @@ -32,13 +31,10 @@ public class S3HttpUtil { */ // Parameter uses JDBC's HttpClientSettingsKey because session.getHttpClientKey() returns it. // This path is only used when session != null (never from streaming ingest). - public static void setProxyForS3( - net.snowflake.client.core.HttpClientSettingsKey key, ClientConfiguration clientConfig) { + public static void setProxyForS3(HttpClientSettingsKey key, ClientConfiguration clientConfig) { if (key != null && key.usesProxy()) { clientConfig.setProxyProtocol( - key.getProxyHttpProtocol() == net.snowflake.client.core.HttpProtocol.HTTPS - ? Protocol.HTTPS - : Protocol.HTTP); + key.getProxyHttpProtocol() == HttpProtocol.HTTPS ? 
Protocol.HTTPS : Protocol.HTTP); clientConfig.setProxyHost(key.getProxyHost()); clientConfig.setProxyPort(key.getProxyPort()); clientConfig.setNonProxyHosts(key.getNonProxyHosts()); diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SFOCSPException.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SFOCSPException.java new file mode 100644 index 000000000..0513569b3 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SFOCSPException.java @@ -0,0 +1,41 @@ +/* + * Replicated from snowflake-jdbc (v3.25.1) + * Source: https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/core/SFOCSPException.java + * + * Permitted differences: package declaration, import swaps for already-replicated classes. + */ +package net.snowflake.ingest.streaming.internal.fileTransferAgent; + +import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.SFLogger; +import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.SFLoggerFactory; + +public class SFOCSPException extends Throwable { + private static final SFLogger logger = SFLoggerFactory.getLogger(SFOCSPException.class); + + private static final long serialVersionUID = 1L; + + private final OCSPErrorCode errorCode; + + public SFOCSPException(OCSPErrorCode errorCode, String errorMsg) { + this(errorCode, errorMsg, null); + } + + public SFOCSPException(OCSPErrorCode errorCode, String errorMsg, Throwable cause) { + super(errorMsg); + this.errorCode = errorCode; + if (cause != null) { + this.initCause(cause); + } + } + + public OCSPErrorCode getErrorCode() { + return errorCode; + } + + @Override + public String toString() { + return super.toString() + + (getErrorCode() != null ? ", errorCode = " + getErrorCode() : "") + + (getMessage() != null ? 
", errorMsg = " + getMessage() : ""); + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SFTrustManager.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SFTrustManager.java new file mode 100644 index 000000000..46791f1d7 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SFTrustManager.java @@ -0,0 +1,1689 @@ +/* + * Replicated from snowflake-jdbc (v3.25.1) + * Source: https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/core/SFTrustManager.java + * + * Permitted differences: package declaration, + * import swaps for already-replicated classes, + * @SnowflakeJdbcInternalApi annotations removed, + * HttpUtil.SFConnectionSocketFactory -> JdbcHttpUtil.SFConnectionSocketFactory, + * TelemetryService OOB call removed from OCSPTelemetryData (separate file). + */ +package net.snowflake.ingest.streaming.internal.fileTransferAgent; + +import static net.snowflake.ingest.streaming.internal.fileTransferAgent.StorageClientUtil.isNullOrEmpty; +import static net.snowflake.ingest.streaming.internal.fileTransferAgent.StorageClientUtil.systemGetEnv; +import static net.snowflake.ingest.streaming.internal.fileTransferAgent.StorageClientUtil.systemGetProperty; + +import com.amazonaws.Protocol; +import com.amazonaws.http.apache.SdkProxyRoutePlanner; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.JsonNodeType; +import com.fasterxml.jackson.databind.node.ObjectNode; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.io.OutputStream; +import java.math.BigInteger; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URL; +import java.security.InvalidKeyException; +import java.security.KeyStore; +import 
java.security.KeyStoreException; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.security.Signature; +import java.security.SignatureException; +import java.security.cert.CertificateEncodingException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.text.MessageFormat; +import java.text.SimpleDateFormat; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Date; +import java.util.Enumeration; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TimeZone; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509ExtendedTrustManager; +import javax.net.ssl.X509TrustManager; +import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.SFLogger; +import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.SFLoggerFactory; +import net.snowflake.ingest.utils.OCSPMode; +import net.snowflake.ingest.utils.SFPair; +import org.apache.commons.codec.binary.Base64; +import org.apache.commons.io.IOUtils; +import org.apache.http.HttpHost; +import org.apache.http.HttpStatus; +import org.apache.http.auth.AuthScope; +import org.apache.http.auth.Credentials; +import org.apache.http.auth.UsernamePasswordCredentials; +import org.apache.http.client.CredentialsProvider; +import org.apache.http.client.config.RequestConfig; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.config.Registry; +import org.apache.http.config.RegistryBuilder; +import org.apache.http.conn.socket.ConnectionSocketFactory; +import org.apache.http.entity.StringEntity; +import 
org.apache.http.impl.client.BasicCredentialsProvider; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.DefaultRedirectStrategy; +import org.apache.http.impl.client.HttpClientBuilder; +import org.apache.http.impl.conn.PoolingHttpClientConnectionManager; +import org.apache.http.ssl.SSLInitializationException; +import org.bouncycastle.asn1.ASN1Encodable; +import org.bouncycastle.asn1.ASN1Integer; +import org.bouncycastle.asn1.ASN1ObjectIdentifier; +import org.bouncycastle.asn1.ASN1OctetString; +import org.bouncycastle.asn1.DEROctetString; +import org.bouncycastle.asn1.DLSequence; +import org.bouncycastle.asn1.ocsp.CertID; +import org.bouncycastle.asn1.oiw.OIWObjectIdentifiers; +import org.bouncycastle.asn1.x509.AlgorithmIdentifier; +import org.bouncycastle.asn1.x509.Certificate; +import org.bouncycastle.asn1.x509.Extension; +import org.bouncycastle.asn1.x509.Extensions; +import org.bouncycastle.asn1.x509.GeneralName; +import org.bouncycastle.asn1.x509.TBSCertificate; +import org.bouncycastle.cert.X509CertificateHolder; +import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter; +import org.bouncycastle.cert.ocsp.BasicOCSPResp; +import org.bouncycastle.cert.ocsp.CertificateID; +import org.bouncycastle.cert.ocsp.CertificateStatus; +import org.bouncycastle.cert.ocsp.OCSPException; +import org.bouncycastle.cert.ocsp.OCSPReq; +import org.bouncycastle.cert.ocsp.OCSPReqBuilder; +import org.bouncycastle.cert.ocsp.OCSPResp; +import org.bouncycastle.cert.ocsp.RevokedStatus; +import org.bouncycastle.cert.ocsp.SingleResp; +import org.bouncycastle.operator.DigestCalculator; + +/** + * SFTrustManager is a composite of TrustManager of the default JVM TrustManager and Snowflake OCSP + * revocation status checker. Use this when initializing SSLContext object. + * + *

{@code
+ * TrustManager[] trustManagers = {new SFTrustManager()};
+ * SSLContext sslContext = SSLContext.getInstance("TLS");
+ * sslContext.init(null, trustManagers, null);
+ * }
+ */ +public class SFTrustManager extends X509ExtendedTrustManager { + /** Test System Parameters. Not used in the production */ + public static final String SF_OCSP_RESPONSE_CACHE_SERVER_URL = + "SF_OCSP_RESPONSE_CACHE_SERVER_URL"; + + public static final String SF_OCSP_RESPONSE_CACHE_SERVER_ENABLED = + "SF_OCSP_RESPONSE_CACHE_SERVER_ENABLED"; + public static final String SF_OCSP_TEST_INJECT_VALIDITY_ERROR = + "SF_OCSP_TEST_INJECT_VALIDITY_ERROR"; + public static final String SF_OCSP_TEST_INJECT_UNKNOWN_STATUS = + "SF_OCSP_TEST_INJECT_UNKNOWN_STATUS"; + public static final String SF_OCSP_TEST_RESPONDER_URL = "SF_OCSP_TEST_RESPONDER_URL"; + public static final String SF_OCSP_TEST_OCSP_RESPONSE_CACHE_SERVER_TIMEOUT = + "SF_OCSP_TEST_OCSP_RESPONSE_CACHE_SERVER_TIMEOUT"; + public static final String SF_OCSP_TEST_OCSP_RESPONDER_TIMEOUT = + "SF_OCSP_TEST_OCSP_RESPONDER_TIMEOUT"; + public static final String SF_OCSP_TEST_INVALID_SIGNING_CERT = + "SF_OCSP_TEST_INVALID_SIGNING_CERT"; + public static final String SF_OCSP_TEST_NO_OCSP_RESPONDER_URL = + "SF_OCSP_TEST_NO_OCSP_RESPONDER_URL"; + + /** OCSP response cache file name. Should be identical to other driver's cache file name. 
*/ + static final String CACHE_FILE_NAME = "ocsp_response_cache.json"; + + private static final SFLogger logger = SFLoggerFactory.getLogger(SFTrustManager.class); + private static final ASN1ObjectIdentifier OIDocsp = + new ASN1ObjectIdentifier("1.3.6.1.5.5.7.48.1").intern(); + private static final ASN1ObjectIdentifier SHA1RSA = + new ASN1ObjectIdentifier("1.2.840.113549.1.1.5").intern(); + private static final ASN1ObjectIdentifier SHA256RSA = + new ASN1ObjectIdentifier("1.2.840.113549.1.1.11").intern(); + private static final ASN1ObjectIdentifier SHA384RSA = + new ASN1ObjectIdentifier("1.2.840.113549.1.1.12").intern(); + private static final ASN1ObjectIdentifier SHA512RSA = + new ASN1ObjectIdentifier("1.2.840.113549.1.1.13").intern(); + + private static final String ALGORITHM_SHA1_NAME = "SHA-1"; + + /** Object mapper for JSON encoding and decoding */ + private static final ObjectMapper OBJECT_MAPPER = ObjectMapperFactory.getObjectMapper(); + + /** System property name to specify cache directory. */ + private static final String CACHE_DIR_PROP = "net.snowflake.jdbc.ocspResponseCacheDir"; + + /** Environment name to specify the cache directory. Used if system property not set. 
*/ + private static final String CACHE_DIR_ENV = "SF_OCSP_RESPONSE_CACHE_DIR"; + + /** OCSP response cache entry expiration time (s) */ + private static final long CACHE_EXPIRATION_IN_SECONDS = 432000L; + + /** OCSP response cache lock file expiration time (s) */ + private static final long CACHE_FILE_LOCK_EXPIRATION_IN_SECONDS = 60L; + + /** Default OCSP Cache server connection timeout */ + private static final int DEFAULT_OCSP_CACHE_SERVER_CONNECTION_TIMEOUT = 5000; + + /** Default OCSP responder connection timeout */ + private static final int DEFAULT_OCSP_RESPONDER_CONNECTION_TIMEOUT = 10000; + + /** Default OCSP Cache server host name prefix */ + private static final String DEFAULT_OCSP_CACHE_HOST_PREFIX = "http://ocsp.snowflakecomputing."; + + /** Default domain for OCSP cache host */ + private static final String DEFAULT_OCSP_CACHE_HOST_DOMAIN = "com"; + + /** OCSP response file cache directory */ + private static final FileCacheManager fileCacheManager; + + /** Tolerable validity date range ratio. */ + private static final float TOLERABLE_VALIDITY_RANGE_RATIO = 0.01f; + + /** Maximum clocktime skew (ms) */ + private static final long MAX_CLOCK_SKEW_IN_MILLISECONDS = 900000L; + + /** Minimum cache warm up time (ms) */ + private static final long MIN_CACHE_WARMUP_TIME_IN_MILLISECONDS = 18000000L; + + /** Initial sleeping time in retry (ms) */ + private static final long INITIAL_SLEEPING_TIME_IN_MILLISECONDS = 1000L; + + /** Maximum sleeping time in retry (ms) */ + private static final long MAX_SLEEPING_TIME_IN_MILLISECONDS = 16000L; + + /** Map from signature algorithm ASN1 object to the name. */ + private static final Map SIGNATURE_OID_TO_STRING = + new ConcurrentHashMap<>(); + + /** Map from OCSP response code to a string representation. 
*/ + private static final Map OCSP_RESPONSE_CODE_TO_STRING = + new ConcurrentHashMap<>(); + + private static final Object ROOT_CA_LOCK = new Object(); + + /** OCSP Response cache */ + private static final Map> OCSP_RESPONSE_CACHE = + new ConcurrentHashMap<>(); + + /** Date and timestamp format */ + private static final SimpleDateFormat DATE_FORMAT_UTC = + new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); + + /** OCSP Response Cache server Retry URL pattern */ + static String SF_OCSP_RESPONSE_CACHE_SERVER_RETRY_URL_PATTERN; + + /** OCSP response cache server URL. */ + static String SF_OCSP_RESPONSE_CACHE_SERVER_URL_VALUE; + + private static JcaX509CertificateConverter CONVERTER_X509 = new JcaX509CertificateConverter(); + + /** RootCA cache */ + private static Map ROOT_CA = new ConcurrentHashMap<>(); + + private static final AtomicBoolean WAS_CACHE_UPDATED = new AtomicBoolean(); + private static final AtomicBoolean WAS_CACHE_READ = new AtomicBoolean(); + + /** OCSP HTTP client */ + private static Map ocspCacheServerClient = + new ConcurrentHashMap<>(); + + /** OCSP event types */ + public static String SF_OCSP_EVENT_TYPE_REVOKED_CERTIFICATE_ERROR = "RevokedCertificateError"; + + public static String SF_OCSP_EVENT_TYPE_VALIDATION_ERROR = "OCSPValidationError"; + + static { + // init OCSP response cache file manager + fileCacheManager = + FileCacheManager.builder() + .setCacheDirectorySystemProperty(CACHE_DIR_PROP) + .setCacheDirectoryEnvironmentVariable(CACHE_DIR_ENV) + .setBaseCacheFileName(CACHE_FILE_NAME) + .setCacheFileLockExpirationInSeconds(CACHE_FILE_LOCK_EXPIRATION_IN_SECONDS) + .setOnlyOwnerPermissions(false) + .build(); + } + + static { + SIGNATURE_OID_TO_STRING.put(SHA1RSA, "SHA1withRSA"); + SIGNATURE_OID_TO_STRING.put(SHA256RSA, "SHA256withRSA"); + SIGNATURE_OID_TO_STRING.put(SHA384RSA, "SHA384withRSA"); + SIGNATURE_OID_TO_STRING.put(SHA512RSA, "SHA512withRSA"); + } + + static { + OCSP_RESPONSE_CODE_TO_STRING.put(OCSPResp.SUCCESSFUL, "successful"); + 
OCSP_RESPONSE_CODE_TO_STRING.put(OCSPResp.MALFORMED_REQUEST, "malformedRequest"); + OCSP_RESPONSE_CODE_TO_STRING.put(OCSPResp.INTERNAL_ERROR, "internalError"); + OCSP_RESPONSE_CODE_TO_STRING.put(OCSPResp.TRY_LATER, "tryLater"); + OCSP_RESPONSE_CODE_TO_STRING.put(OCSPResp.SIG_REQUIRED, "sigRequired"); + OCSP_RESPONSE_CODE_TO_STRING.put(OCSPResp.UNAUTHORIZED, "unauthorized"); + } + + static { + DATE_FORMAT_UTC.setTimeZone(TimeZone.getTimeZone("UTC")); + } + + /** The default JVM Trust manager. */ + private final X509TrustManager trustManager; + + /** The default JVM Extended Trust Manager */ + private final X509ExtendedTrustManager exTrustManager; + + OCSPCacheServer ocspCacheServer = new OCSPCacheServer(); + + /** OCSP mode */ + private OCSPMode ocspMode; + + private static HttpClientSettingsKey proxySettingsKey; + + /** + * Constructor with the cache file. If not specified, the default cachefile is used. + * + * @param key HttpClientSettingsKey + * @param cacheFile cache file. + */ + SFTrustManager(HttpClientSettingsKey key, File cacheFile) { + this.ocspMode = key.getOcspMode(); + this.proxySettingsKey = key; + this.trustManager = getTrustManager(TrustManagerFactory.getDefaultAlgorithm()); + + if (trustManager instanceof X509ExtendedTrustManager) { + this.exTrustManager = (X509ExtendedTrustManager) trustManager; + } else { + logger.debug("Standard X509TrustManager is used instead of X509ExtendedTrustManager."); + this.exTrustManager = null; + } + + checkNewOCSPEndpointAvailability(); + + if (cacheFile != null) { + fileCacheManager.overrideCacheFile(cacheFile); + } + if (!WAS_CACHE_READ.getAndSet(true)) { + // read cache file once + JsonNode res = fileCacheManager.readCacheFile(); + readJsonStoreCache(res); + } + + logger.debug( + "Initializing trust manager with OCSP mode: {}, cache file: {}", ocspMode, cacheFile); + } + + /** Deletes OCSP response cache file from disk. 
*/ + public static void deleteCache() { + fileCacheManager.deleteCacheFile(); + } + + public static void cleanTestSystemParameters() { + System.clearProperty(SF_OCSP_RESPONSE_CACHE_SERVER_URL); + System.clearProperty(SF_OCSP_RESPONSE_CACHE_SERVER_ENABLED); + System.clearProperty(SF_OCSP_TEST_INJECT_VALIDITY_ERROR); + System.clearProperty(SF_OCSP_TEST_INJECT_UNKNOWN_STATUS); + System.clearProperty(SF_OCSP_TEST_RESPONDER_URL); + System.clearProperty(SF_OCSP_TEST_OCSP_RESPONDER_TIMEOUT); + System.clearProperty(SF_OCSP_TEST_OCSP_RESPONSE_CACHE_SERVER_TIMEOUT); + System.clearProperty(SF_OCSP_TEST_INVALID_SIGNING_CERT); + System.clearProperty(SF_OCSP_TEST_NO_OCSP_RESPONDER_URL); + } + + /** + * Reset OCSP Cache server URL + * + * @param ocspCacheServerUrl OCSP Cache server URL + */ + static void resetOCSPResponseCacherServerURL(String ocspCacheServerUrl) throws IOException { + if (ocspCacheServerUrl == null || SF_OCSP_RESPONSE_CACHE_SERVER_RETRY_URL_PATTERN != null) { + return; + } + SF_OCSP_RESPONSE_CACHE_SERVER_URL_VALUE = ocspCacheServerUrl; + if (!SF_OCSP_RESPONSE_CACHE_SERVER_URL_VALUE.startsWith(DEFAULT_OCSP_CACHE_HOST_PREFIX)) { + URL url = new URL(SF_OCSP_RESPONSE_CACHE_SERVER_URL_VALUE); + if (url.getPort() > 0) { + SF_OCSP_RESPONSE_CACHE_SERVER_RETRY_URL_PATTERN = + String.format( + "%s://%s:%d/retry/%s", url.getProtocol(), url.getHost(), url.getPort(), "%s/%s"); + } else { + SF_OCSP_RESPONSE_CACHE_SERVER_RETRY_URL_PATTERN = + String.format("%s://%s/retry/%s", url.getProtocol(), url.getHost(), "%s/%s"); + } + logger.debug( + "Reset OCSP response cache server URL to: {}", + SF_OCSP_RESPONSE_CACHE_SERVER_RETRY_URL_PATTERN); + } + } + + static void setOCSPResponseCacheServerURL(String serverURL) { + String ocspCacheUrl = systemGetProperty(SF_OCSP_RESPONSE_CACHE_SERVER_URL); + if (ocspCacheUrl != null) { + SF_OCSP_RESPONSE_CACHE_SERVER_URL_VALUE = ocspCacheUrl; + } + try { + ocspCacheUrl = systemGetEnv(SF_OCSP_RESPONSE_CACHE_SERVER_URL); + if (ocspCacheUrl != null) 
{ + SF_OCSP_RESPONSE_CACHE_SERVER_URL_VALUE = ocspCacheUrl; + } + } catch (Throwable ex) { + logger.debug( + "Failed to get environment variable " + SF_OCSP_RESPONSE_CACHE_SERVER_URL + ". Ignored", + true); + } + if (SF_OCSP_RESPONSE_CACHE_SERVER_URL_VALUE == null) { + String topLevelDomain = DEFAULT_OCSP_CACHE_HOST_DOMAIN; + try { + URL url = new URL(serverURL); + int domainIndex = url.getHost().lastIndexOf(".") + 1; + topLevelDomain = url.getHost().substring(domainIndex); + } catch (Exception e) { + logger.debug("Exception while setting top level domain (for OCSP)", e); + } + SF_OCSP_RESPONSE_CACHE_SERVER_URL_VALUE = + String.format("%s%s/%s", DEFAULT_OCSP_CACHE_HOST_PREFIX, topLevelDomain, CACHE_FILE_NAME); + } + logger.debug("Set OCSP response cache server to: {}", SF_OCSP_RESPONSE_CACHE_SERVER_URL_VALUE); + } + + private static boolean useOCSPResponseCacheServer() { + String ocspCacheServerEnabled = systemGetProperty(SF_OCSP_RESPONSE_CACHE_SERVER_ENABLED); + if (Boolean.FALSE.toString().equalsIgnoreCase(ocspCacheServerEnabled)) { + logger.debug("No OCSP Response Cache Server is used.", false); + return false; + } + try { + ocspCacheServerEnabled = systemGetEnv(SF_OCSP_RESPONSE_CACHE_SERVER_ENABLED); + if (Boolean.FALSE.toString().equalsIgnoreCase(ocspCacheServerEnabled)) { + logger.debug("No OCSP Response Cache Server is used.", false); + return false; + } + } catch (Throwable ex) { + logger.debug( + "Failed to get environment variable " + + SF_OCSP_RESPONSE_CACHE_SERVER_ENABLED + + ". 
Ignored", + false); + } + return true; + } + + /** + * Convert cache key to base64 encoded cert id + * + * @param ocsp_cache_key Cache key to encode + */ + private static String encodeCacheKey(OcspResponseCacheKey ocsp_cache_key) { + try { + DigestCalculator digest = new SHA1DigestCalculator(); + AlgorithmIdentifier algo = digest.getAlgorithmIdentifier(); + ASN1OctetString nameHash = ASN1OctetString.getInstance(ocsp_cache_key.nameHash); + ASN1OctetString keyHash = ASN1OctetString.getInstance(ocsp_cache_key.keyHash); + ASN1Integer snumber = new ASN1Integer(ocsp_cache_key.serialNumber); + CertID cid = new CertID(algo, nameHash, keyHash, snumber); + return Base64.encodeBase64String(cid.toASN1Primitive().getEncoded()); + } catch (Exception ex) { + logger.debug("Failed to encode cache key to base64 encoded cert id", false); + } + return null; + } + + /** + * CertificateID to string + * + * @param certificateID CertificateID + * @return a string representation of CertificateID + */ + private static String CertificateIDToString(CertificateID certificateID) { + return String.format( + "CertID. 
NameHash: %s, KeyHash: %s, Serial Number: %s", + HexUtil.byteToHexString(certificateID.getIssuerNameHash()), + HexUtil.byteToHexString(certificateID.getIssuerKeyHash()), + MessageFormat.format("{0,number,#}", certificateID.getSerialNumber())); + } + + /** + * Decodes OCSP Response Cache key from JSON + * + * @param elem A JSON element + * @return OcspResponseCacheKey object + */ + private static SFPair> decodeCacheFromJSON( + Map.Entry elem) throws IOException { + long currentTimeSecond = new Date().getTime() / 1000; + byte[] certIdDer = Base64.decodeBase64(elem.getKey()); + DLSequence rawCertId = (DLSequence) ASN1ObjectIdentifier.fromByteArray(certIdDer); + ASN1Encodable[] rawCertIdArray = rawCertId.toArray(); + byte[] issuerNameHashDer = ((DEROctetString) rawCertIdArray[1]).getEncoded(); + byte[] issuerKeyHashDer = ((DEROctetString) rawCertIdArray[2]).getEncoded(); + BigInteger serialNumber = ((ASN1Integer) rawCertIdArray[3]).getValue(); + + OcspResponseCacheKey k = + new OcspResponseCacheKey(issuerNameHashDer, issuerKeyHashDer, serialNumber); + + JsonNode ocspRespBase64 = elem.getValue(); + if (!ocspRespBase64.isArray() || ocspRespBase64.size() != 2) { + logger.debug("Invalid cache file format. 
Ignored", false); + return null; + } + long producedAt = ocspRespBase64.get(0).asLong(); + String ocspResp = ocspRespBase64.get(1).asText(); + + if (currentTimeSecond - CACHE_EXPIRATION_IN_SECONDS <= producedAt) { + // add cache + return SFPair.of(k, SFPair.of(producedAt, ocspResp)); + } else { + // delete cache + return SFPair.of(k, SFPair.of(producedAt, null)); + } + } + + /** + * Encode OCSP Response Cache to JSON + * + * @return JSON object + */ + private static ObjectNode encodeCacheToJSON() { + try { + ObjectNode out = OBJECT_MAPPER.createObjectNode(); + for (Map.Entry> elem : + OCSP_RESPONSE_CACHE.entrySet()) { + OcspResponseCacheKey key = elem.getKey(); + SFPair value0 = elem.getValue(); + long currentTimeSecond = value0.left; + + DigestCalculator digest = new SHA1DigestCalculator(); + AlgorithmIdentifier algo = digest.getAlgorithmIdentifier(); + ASN1OctetString nameHash = ASN1OctetString.getInstance(key.nameHash); + ASN1OctetString keyHash = ASN1OctetString.getInstance(key.keyHash); + ASN1Integer serialNumber = new ASN1Integer(key.serialNumber); + CertID cid = new CertID(algo, nameHash, keyHash, serialNumber); + ArrayNode vout = OBJECT_MAPPER.createArrayNode(); + vout.add(currentTimeSecond); + vout.add(value0.right); + out.set(Base64.encodeBase64String(cid.toASN1Primitive().getEncoded()), vout); + } + return out; + } catch (IOException ex) { + logger.debug("Failed to encode ASN1 object.", false); + } + return null; + } + + private static synchronized void readJsonStoreCache(JsonNode m) { + if (m == null || !m.getNodeType().equals(JsonNodeType.OBJECT)) { + logger.debug("Invalid cache file format.", false); + return; + } + try { + for (Iterator> itr = m.fields(); itr.hasNext(); ) { + SFPair> ky = decodeCacheFromJSON(itr.next()); + if (ky != null && ky.right != null && ky.right.right != null) { + // valid range. 
cache the result in memory + OCSP_RESPONSE_CACHE.put(ky.left, ky.right); + WAS_CACHE_UPDATED.set(true); + } else if (ky != null && OCSP_RESPONSE_CACHE.containsKey(ky.left)) { + // delete it from the cache if no OCSP response is back. + OCSP_RESPONSE_CACHE.remove(ky.left); + WAS_CACHE_UPDATED.set(true); + } + } + } catch (IOException ex) { + logger.debug("Failed to decode the cache file", false); + } + } + + /** + * Verifies the signature of the data + * + * @param cert a certificate for public key. + * @param sig signature in a byte array. + * @param data data in a byte array. + * @param idf algorithm identifier object. + * @throws CertificateException raises if the verification fails. + */ + private static void verifySignature( + X509CertificateHolder cert, byte[] sig, byte[] data, AlgorithmIdentifier idf) + throws CertificateException { + try { + String algorithm = SIGNATURE_OID_TO_STRING.get(idf.getAlgorithm()); + if (algorithm == null) { + throw new NoSuchAlgorithmException( + String.format("Unsupported signature OID. OID: %s", idf)); + } + Signature signer = Signature.getInstance(algorithm); + + X509Certificate c = CONVERTER_X509.getCertificate(cert); + signer.initVerify(c.getPublicKey()); + signer.update(data); + if (!signer.verify(sig)) { + throw new CertificateEncodingException( + String.format( + "Failed to verify the signature. 
Potentially the " + + "data was not generated by by the cert, %s", + cert.getSubject())); + } + } catch (NoSuchAlgorithmException | InvalidKeyException | SignatureException ex) { + throw new CertificateEncodingException("Failed to verify the signature.", ex); + } + } + + /** + * Gets HttpClient object + * + * @return HttpClient + */ + private static CloseableHttpClient getHttpClient(int timeout) { + RequestConfig config = + RequestConfig.custom() + .setConnectTimeout(timeout) + .setConnectionRequestTimeout(timeout) + .setSocketTimeout(timeout) + .build(); + + Registry registry = + RegistryBuilder.create() + .register("http", new JdbcHttpUtil.SFConnectionSocketFactory()) + .build(); + + // Build a connection manager with enough connections + PoolingHttpClientConnectionManager connectionManager = + new PoolingHttpClientConnectionManager(registry); + connectionManager.setMaxTotal(1); + connectionManager.setDefaultMaxPerRoute(10); + + HttpClientBuilder httpClientBuilder = + HttpClientBuilder.create() + .setDefaultRequestConfig(config) + .setConnectionManager(connectionManager) + // Support JVM proxy settings + .useSystemProperties() + .setRedirectStrategy(new DefaultRedirectStrategy()) + .disableCookieManagement(); + + if (proxySettingsKey.usesProxy()) { + // use the custom proxy properties + HttpHost proxy = + new HttpHost(proxySettingsKey.getProxyHost(), proxySettingsKey.getProxyPort()); + SdkProxyRoutePlanner sdkProxyRoutePlanner = + new SdkProxyRoutePlanner( + proxySettingsKey.getProxyHost(), + proxySettingsKey.getProxyPort(), + Protocol.HTTP, + proxySettingsKey.getNonProxyHosts()); + httpClientBuilder = httpClientBuilder.setProxy(proxy).setRoutePlanner(sdkProxyRoutePlanner); + if (!isNullOrEmpty(proxySettingsKey.getProxyUser()) + && !isNullOrEmpty(proxySettingsKey.getProxyPassword())) { + Credentials credentials = + new UsernamePasswordCredentials( + proxySettingsKey.getProxyUser(), proxySettingsKey.getProxyPassword()); + AuthScope authScope = + new 
AuthScope(proxySettingsKey.getProxyHost(), proxySettingsKey.getProxyPort()); + CredentialsProvider credentialsProvider = new BasicCredentialsProvider(); + credentialsProvider.setCredentials(authScope, credentials); + httpClientBuilder = httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider); + } + } + // using the default HTTP client + return httpClientBuilder.build(); + } + + private static long maxLong(long v1, long v2) { + return Math.max(v1, v2); + } + + /** + * Calculates the tolerable validity time beyond the next update. + * + *

Sometimes CA's OCSP response update is delayed beyond the clock skew as the update is not + * populated to all OCSP servers for certain period. + * + * @param thisUpdate the last update + * @param nextUpdate the next update + * @return the tolerable validity beyond the next update. + */ + private static long calculateTolerableValidity(Date thisUpdate, Date nextUpdate) { + return maxLong( + (long) + ((float) (nextUpdate.getTime() - thisUpdate.getTime()) + * TOLERABLE_VALIDITY_RANGE_RATIO), + MIN_CACHE_WARMUP_TIME_IN_MILLISECONDS); + } + + /** + * Checks the validity + * + * @param currentTime the current time + * @param thisUpdate the last update timestamp + * @param nextUpdate the next update timestamp + * @return true if valid or false + */ + private static boolean isValidityRange(Date currentTime, Date thisUpdate, Date nextUpdate) { + if (checkOCSPResponseValidityErrorParameter()) { + return false; // test + } + long tolerableValidity = calculateTolerableValidity(thisUpdate, nextUpdate); + return thisUpdate.getTime() - MAX_CLOCK_SKEW_IN_MILLISECONDS <= currentTime.getTime() + && currentTime.getTime() <= nextUpdate.getTime() + tolerableValidity; + } + + private static boolean checkOCSPResponseValidityErrorParameter() { + String injectValidityError = systemGetProperty(SF_OCSP_TEST_INJECT_VALIDITY_ERROR); + return Boolean.TRUE.toString().equalsIgnoreCase(injectValidityError); + } + + /** + * Is the test parameter enabled? 
+ * + * @param key the test parameter + * @return true if enabled otherwise false + */ + private boolean isEnabledSystemTestParameter(String key) { + return Boolean.TRUE.toString().equalsIgnoreCase(systemGetProperty(key)); + } + + /** fail open mode current state */ + private boolean isOCSPFailOpen() { + return ocspMode == OCSPMode.FAIL_OPEN; + } + + private void checkNewOCSPEndpointAvailability() { + String new_ocsp_ept; + try { + new_ocsp_ept = systemGetEnv("SF_OCSP_ACTIVATE_NEW_ENDPOINT"); + } catch (Throwable ex) { + logger.debug( + "Could not get environment variable to check for New OCSP Endpoint Availability", false); + new_ocsp_ept = systemGetProperty("net.snowflake.jdbc.ocsp_activate_new_endpoint"); + } + ocspCacheServer.new_endpoint_enabled = new_ocsp_ept != null; + } + + /** + * Get TrustManager for the algorithm. This is mainly used to get the JVM default trust manager + * and cache all of the root CA. + * + * @param algorithm algorithm. + * @return TrustManager object. + */ + private X509TrustManager getTrustManager(String algorithm) { + try { + TrustManagerFactory factory = TrustManagerFactory.getInstance(algorithm); + factory.init((KeyStore) null); + X509TrustManager ret = null; + for (TrustManager tm : factory.getTrustManagers()) { + // Multiple TrustManager may be attached. We just need X509 Trust + // Manager here. + if (tm instanceof X509TrustManager) { + ret = (X509TrustManager) tm; + break; + } + } + if (ret == null) { + return null; + } + synchronized (ROOT_CA_LOCK) { + // cache root CA certificates for later use. 
+ if (ROOT_CA.isEmpty()) { + for (X509Certificate cert : ret.getAcceptedIssuers()) { + Certificate bcCert = Certificate.getInstance(cert.getEncoded()); + ROOT_CA.put(bcCert.getSubject().hashCode(), bcCert); + } + } + } + return ret; + } catch (NoSuchAlgorithmException | KeyStoreException | CertificateEncodingException ex) { + throw new SSLInitializationException(ex.getMessage(), ex); + } + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + // default behavior + trustManager.checkClientTrusted(chain, authType); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + trustManager.checkServerTrusted(chain, authType); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, java.net.Socket socket) + throws CertificateException { + if (exTrustManager != null) { + exTrustManager.checkClientTrusted(chain, authType, socket); + } else { + trustManager.checkClientTrusted(chain, authType); + } + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine sslEngine) + throws CertificateException { + if (exTrustManager != null) { + exTrustManager.checkClientTrusted(chain, authType, sslEngine); + } else { + trustManager.checkClientTrusted(chain, authType); + } + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, java.net.Socket socket) + throws CertificateException { + if (exTrustManager != null) { + exTrustManager.checkServerTrusted(chain, authType, socket); + } else { + trustManager.checkServerTrusted(chain, authType); + } + String host = socket.getInetAddress().getHostName(); + this.validateRevocationStatus(chain, host); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine sslEngine) + throws CertificateException { + if (exTrustManager != null) { + 
exTrustManager.checkServerTrusted(chain, authType, sslEngine); + } else { + trustManager.checkServerTrusted(chain, authType); + } + this.validateRevocationStatus(chain, sslEngine.getPeerHost()); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return trustManager.getAcceptedIssuers(); + } + + /** + * Certificate Revocation checks + * + * @param chain chain of certificates attached. + * @param peerHost Hostname of the server + * @throws CertificateException if any certificate validation fails + */ + void validateRevocationStatus(X509Certificate[] chain, String peerHost) + throws CertificateException { + final List bcChain = convertToBouncyCastleCertificate(chain); + final List> pairIssuerSubjectList = + getPairIssuerSubject(bcChain); + + if (peerHost.startsWith("ocspssd")) { + return; + } + + if (ocspCacheServer.new_endpoint_enabled) { + ocspCacheServer.resetOCSPResponseCacheServer(peerHost); + } + + boolean isCached = isCached(pairIssuerSubjectList); + if (useOCSPResponseCacheServer() && !isCached) { + if (!ocspCacheServer.new_endpoint_enabled) { + logger.debug( + "Downloading OCSP response cache from the server. URL: {}", + SF_OCSP_RESPONSE_CACHE_SERVER_URL_VALUE); + } else { + logger.debug( + "Downloading OCSP response cache from the server. URL: {}", + ocspCacheServer.SF_OCSP_RESPONSE_CACHE_SERVER); + } + try { + readOcspResponseCacheServer(); + } catch (SFOCSPException ex) { + logger.debug( + "Error downloading OCSP Response from cache server : {}." + + "OCSP Responses will be fetched directly from the CA OCSP" + + "Responder ", + ex.getMessage()); + } + // if the cache is downloaded from the server, it should be written + // to the file cache at all times. 
+ } + executeRevocationStatusChecks(pairIssuerSubjectList, peerHost); + if (WAS_CACHE_UPDATED.getAndSet(false)) { + JsonNode input = encodeCacheToJSON(); + fileCacheManager.writeCacheFile(input); + } + } + + /** + * Executes the revocation status checks for all chained certificates + * + * @param pairIssuerSubjectList a list of pair of issuer and subject certificates. + * @throws CertificateException raises if any error occurs. + */ + private void executeRevocationStatusChecks( + List> pairIssuerSubjectList, String peerHost) + throws CertificateException { + long currentTimeSecond = new Date().getTime() / 1000L; + for (SFPair pairIssuerSubject : pairIssuerSubjectList) { + executeOneRevocationStatusCheck(pairIssuerSubject, currentTimeSecond, peerHost); + } + } + + private String generateFailOpenLog(String logData) { + return "OCSP responder didn't respond correctly. Assuming certificate is " + + "not revoked. Details: " + + logData; + } + + /** + * Executes a single revocation status check + * + * @param pairIssuerSubject a pair of issuer and subject certificate + * @param currentTimeSecond the current timestamp + * @throws CertificateException if certificate exception is raised. 
+ */ + private void executeOneRevocationStatusCheck( + SFPair pairIssuerSubject, long currentTimeSecond, String peerHost) + throws CertificateException { + OCSPReq req; + OcspResponseCacheKey keyOcspResponse; + try { + req = createRequest(pairIssuerSubject); + CertID cid = req.getRequestList()[0].getCertID().toASN1Primitive(); + keyOcspResponse = + new OcspResponseCacheKey( + cid.getIssuerNameHash().getEncoded(), + cid.getIssuerKeyHash().getEncoded(), + cid.getSerialNumber().getValue()); + } catch (IOException ex) { + throw new CertificateException(ex.getMessage(), ex); + } + + long sleepTime = INITIAL_SLEEPING_TIME_IN_MILLISECONDS; + DecorrelatedJitterBackoff backoff = + new DecorrelatedJitterBackoff(sleepTime, MAX_SLEEPING_TIME_IN_MILLISECONDS); + CertificateException error; + boolean success = false; + String ocspLog; + OCSPTelemetryData telemetryData = new OCSPTelemetryData(); + telemetryData.setSfcPeerHost(peerHost); + telemetryData.setCertId(encodeCacheKey(keyOcspResponse)); + telemetryData.setCacheEnabled(useOCSPResponseCacheServer()); + telemetryData.setOCSPMode(ocspMode); + Throwable cause = null; + try { + final int maxRetryCounter = isOCSPFailOpen() ? 
1 : 2; + for (int retry = 0; retry < maxRetryCounter; ++retry) { + try { + SFPair value0 = OCSP_RESPONSE_CACHE.get(keyOcspResponse); + OCSPResp ocspResp; + try { + try { + if (value0 == null) { + telemetryData.setCacheHit(false); + ocspResp = + fetchOcspResponse( + pairIssuerSubject, + req, + encodeCacheKey(keyOcspResponse), + peerHost, + telemetryData); + + OCSP_RESPONSE_CACHE.put( + keyOcspResponse, SFPair.of(currentTimeSecond, ocspResponseToB64(ocspResp))); + WAS_CACHE_UPDATED.set(true); + value0 = SFPair.of(currentTimeSecond, ocspResponseToB64(ocspResp)); + } else { + telemetryData.setCacheHit(true); + } + } catch (Throwable ex) { + logger.debug( + "Exception occurred while trying to fetch OCSP Response - {}", ex.getMessage()); + throw new SFOCSPException( + OCSPErrorCode.OCSP_RESPONSE_FETCH_FAILURE, + "Exception occurred while trying to fetch OCSP Response", + ex); + } + + logger.debug( + "Validating. {}", CertificateIDToString(req.getRequestList()[0].getCertID())); + try { + validateRevocationStatusMain(pairIssuerSubject, value0.right); + success = true; + break; + } catch (SFOCSPException ex) { + if (ex.getErrorCode() != OCSPErrorCode.REVOCATION_CHECK_FAILURE) { + throw ex; + } + throw new CertificateException(ex.getMessage(), ex); + } + } catch (SFOCSPException ex) { + if (ex.getErrorCode() == OCSPErrorCode.CERTIFICATE_STATUS_REVOKED) { + throw ex; + } else { + throw new CertificateException(ex.getMessage(), ex); + } + } + } catch (CertificateException ex) { + WAS_CACHE_UPDATED.set(OCSP_RESPONSE_CACHE.remove(keyOcspResponse) != null); + if (WAS_CACHE_UPDATED.get()) { + logger.debug("Deleting the invalid OCSP cache.", false); + } + + cause = ex; + logger.debug( + "Retrying {}/{} after sleeping {} ms", retry + 1, maxRetryCounter, sleepTime); + try { + if (retry + 1 < maxRetryCounter) { + Thread.sleep(sleepTime); + sleepTime = backoff.nextSleepTime(sleepTime); + } + } catch (InterruptedException ex0) { // nop + } + } + } + } catch (SFOCSPException ex) { + // 
Revoked Certificate + error = new CertificateException(ex); + ocspLog = + telemetryData.generateTelemetry(SF_OCSP_EVENT_TYPE_REVOKED_CERTIFICATE_ERROR, error); + logger.error(ocspLog, false); + throw error; + } + + if (!success) { + if (cause != null) // cause is set in the above catch block + { + error = + new CertificateException( + "Certificate Revocation check failed. Could not retrieve OCSP Response.", cause); + logger.debug(cause.getMessage(), false); + } else { + error = + new CertificateException( + "Certificate Revocation check failed. Could not retrieve OCSP Response."); + logger.debug(error.getMessage(), false); + } + + ocspLog = telemetryData.generateTelemetry(SF_OCSP_EVENT_TYPE_VALIDATION_ERROR, error); + if (isOCSPFailOpen()) { + // Log includes fail-open warning. + logger.debug(generateFailOpenLog(ocspLog), false); + } else { + // still not success, raise an error. + logger.debug(ocspLog, false); + throw error; + } + } + } + + /** + * Is OCSP Response cached? + * + * @param pairIssuerSubjectList a list of pair of issuer and subject certificates + * @return true if all of OCSP response are cached else false + */ + private boolean isCached(List> pairIssuerSubjectList) { + long currentTimeSecond = new Date().getTime() / 1000L; + boolean isCached = true; + try { + for (SFPair pairIssuerSubject : pairIssuerSubjectList) { + OCSPReq req = createRequest(pairIssuerSubject); + CertificateID certificateId = req.getRequestList()[0].getCertID(); + logger.debug(CertificateIDToString(certificateId), false); + CertID cid = certificateId.toASN1Primitive(); + OcspResponseCacheKey k = + new OcspResponseCacheKey( + cid.getIssuerNameHash().getEncoded(), + cid.getIssuerKeyHash().getEncoded(), + cid.getSerialNumber().getValue()); + + SFPair res = OCSP_RESPONSE_CACHE.get(k); + if (res == null) { + logger.debug("Not all OCSP responses for the certificate is in the cache.", false); + isCached = false; + break; + } else if (currentTimeSecond - CACHE_EXPIRATION_IN_SECONDS > 
res.left) { + logger.debug("Cache for CertID expired.", false); + isCached = false; + break; + } else { + try { + validateRevocationStatusMain(pairIssuerSubject, res.right); + } catch (SFOCSPException ex) { + logger.debug( + "Cache includes invalid OCSPResponse. " + + "Will download the OCSP cache from Snowflake OCSP server", + false); + isCached = false; + } + } + } + } catch (IOException ex) { + logger.debug("Failed to encode CertID.", false); + } + return isCached; + } + + /** Reads the OCSP response cache from the server. */ + private void readOcspResponseCacheServer() throws SFOCSPException { + String ocspCacheServerInUse; + + if (ocspCacheServer.new_endpoint_enabled) { + ocspCacheServerInUse = ocspCacheServer.SF_OCSP_RESPONSE_CACHE_SERVER; + } else { + ocspCacheServerInUse = SF_OCSP_RESPONSE_CACHE_SERVER_URL_VALUE; + } + + CloseableHttpResponse response = null; + CloseableHttpClient httpClient = + ocspCacheServerClient.computeIfAbsent( + getOCSPCacheServerConnectionTimeout(), + k -> getHttpClient(getOCSPCacheServerConnectionTimeout())); + try { + URI uri = new URI(ocspCacheServerInUse); + HttpGet get = new HttpGet(uri); + response = httpClient.execute(get); + if (response == null || response.getStatusLine().getStatusCode() != HttpStatus.SC_OK) { + throw new IOException( + String.format( + "Failed to get the OCSP response from the OCSP " + "cache server: HTTP: %d", + response != null ? response.getStatusLine().getStatusCode() : -1)); + } + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + IOUtils.copy(response.getEntity().getContent(), out); + JsonNode m = OBJECT_MAPPER.readTree(out.toByteArray()); + out.close(); + readJsonStoreCache(m); + logger.debug("Successfully downloaded OCSP cache from the server.", false); + } catch (IOException ex) { + logger.debug( + "Failed to read the OCSP response cache from the server. 
" + "Server: {}, Err: {}", + ocspCacheServerInUse, + ex); + } catch (URISyntaxException ex) { + logger.debug("Indicate that a string could not be parsed as a URI reference.", false); + throw new SFOCSPException( + OCSPErrorCode.INVALID_CACHE_SERVER_URL, "Invalid OCSP Cache Server URL used", ex); + } finally { + IOUtils.closeQuietly(response); + } + } + + private int getOCSPCacheServerConnectionTimeout() { + int timeout = DEFAULT_OCSP_CACHE_SERVER_CONNECTION_TIMEOUT; + if (systemGetProperty(SF_OCSP_TEST_OCSP_RESPONSE_CACHE_SERVER_TIMEOUT) != null) { + try { + timeout = + Integer.parseInt(systemGetProperty(SF_OCSP_TEST_OCSP_RESPONSE_CACHE_SERVER_TIMEOUT)); + } catch (Exception ex) { + // nop + } + } + return timeout; + } + + /** + * Fetches OCSP response from OCSP server + * + * @param pairIssuerSubject a pair of issuer and subject certificates + * @param req OCSP Request object + * @return OCSP Response object + * @throws CertificateEncodingException if any other error occurs + */ + private OCSPResp fetchOcspResponse( + SFPair pairIssuerSubject, + OCSPReq req, + String cid_enc, + String hname, + OCSPTelemetryData telemetryData) + throws CertificateEncodingException { + CloseableHttpResponse response = null; + try { + byte[] ocspReqDer = req.getEncoded(); + String ocspReqDerBase64 = Base64.encodeBase64String(ocspReqDer); + Set ocspUrls = getOcspUrls(pairIssuerSubject.right); + checkExistOCSPURL(ocspUrls); + String ocspUrlStr = ocspUrls.iterator().next(); // first one + ocspUrlStr = overrideOCSPURL(ocspUrlStr); + telemetryData.setOcspUrl(ocspUrlStr); + telemetryData.setOcspReq(ocspReqDerBase64); + + URL url; + String path = ""; + if (!ocspCacheServer.new_endpoint_enabled) { + String urlEncodedOCSPReq = URLUtil.urlEncode(ocspReqDerBase64); + if (SF_OCSP_RESPONSE_CACHE_SERVER_RETRY_URL_PATTERN != null) { + URL ocspUrl = new URL(ocspUrlStr); + if (!isNullOrEmpty(ocspUrl.getPath())) { + path = ocspUrl.getPath(); + } + if (ocspUrl.getPort() > 0) { + url = + new URL( + 
String.format( + SF_OCSP_RESPONSE_CACHE_SERVER_RETRY_URL_PATTERN, + ocspUrl.getHost() + ":" + ocspUrl.getPort() + path, + urlEncodedOCSPReq)); + } else { + url = + new URL( + String.format( + SF_OCSP_RESPONSE_CACHE_SERVER_RETRY_URL_PATTERN, + ocspUrl.getHost() + path, + urlEncodedOCSPReq)); + } + + } else { + url = new URL(String.format("%s/%s", ocspUrlStr, urlEncodedOCSPReq)); + } + logger.debug("Not hit cache. Fetching OCSP response from CA OCSP server. {}", url); + } else { + url = new URL(ocspCacheServer.SF_OCSP_RESPONSE_RETRY_URL); + logger.debug( + "Not hit cache. Fetching OCSP response from Snowflake OCSP Response Fetcher. {}", url); + } + + long sleepTime = INITIAL_SLEEPING_TIME_IN_MILLISECONDS; + DecorrelatedJitterBackoff backoff = + new DecorrelatedJitterBackoff(sleepTime, MAX_SLEEPING_TIME_IN_MILLISECONDS); + boolean success = false; + + final int maxRetryCounter = isOCSPFailOpen() ? 1 : 2; + Exception savedEx = null; + CloseableHttpClient httpClient = + ocspCacheServerClient.computeIfAbsent( + getOCSPResponderConnectionTimeout(), + k -> getHttpClient(getOCSPResponderConnectionTimeout())); + + for (int retry = 0; retry < maxRetryCounter; ++retry) { + try { + if (!ocspCacheServer.new_endpoint_enabled) { + HttpGet get = new HttpGet(url.toString()); + response = httpClient.execute(get); + } else { + HttpPost post = new HttpPost(url.toString()); + post.setHeader("Content-Type", "application/json"); + OCSPPostReqData postReqData = + new OCSPPostReqData(ocspUrlStr, ocspReqDerBase64, cid_enc, hname); + String json_payload = OBJECT_MAPPER.writeValueAsString(postReqData); + post.setEntity(new StringEntity(json_payload, "utf-8")); + response = httpClient.execute(post); + } + success = + response != null && response.getStatusLine().getStatusCode() == HttpStatus.SC_OK; + if (success) { + break; + } + } catch (IOException ex) { + logger.debug("Failed to reach out OCSP responder: {}", ex.getMessage()); + savedEx = ex; + } + IOUtils.closeQuietly(response); + + 
logger.debug("Retrying {}/{} after sleeping {} ms", retry + 1, maxRetryCounter, sleepTime); + try { + if (retry + 1 < maxRetryCounter) { + Thread.sleep(sleepTime); + sleepTime = backoff.nextSleepTime(sleepTime); + } + } catch (InterruptedException ex0) { // nop + } + } + if (!success) { + throw new CertificateEncodingException( + String.format( + "Failed to get OCSP response. StatusCode: %d, URL: %s", + response == null ? null : response.getStatusLine().getStatusCode(), ocspUrlStr), + savedEx); + } + ByteArrayOutputStream out = new ByteArrayOutputStream(); + IOUtils.copy(response.getEntity().getContent(), out); + OCSPResp ocspResp = new OCSPResp(out.toByteArray()); + out.close(); + if (ocspResp.getStatus() != OCSPResp.SUCCESSFUL) { + throw new CertificateEncodingException( + String.format( + "Failed to get OCSP response. Status: %s", + OCSP_RESPONSE_CODE_TO_STRING.get(ocspResp.getStatus()))); + } + + return ocspResp; + } catch (IOException ex) { + throw new CertificateEncodingException("Failed to encode object.", ex); + } finally { + IOUtils.closeQuietly(response); + } + } + + private void checkExistOCSPURL(Set ocspUrls) throws CertificateEncodingException { + if (ocspUrls.size() == 0 || isEnabledSystemTestParameter(SF_OCSP_TEST_NO_OCSP_RESPONDER_URL)) { + throw new CertificateEncodingException( + "No OCSP Responder URL is attached to the certificate.", + new SFOCSPException( + OCSPErrorCode.NO_OCSP_URL_ATTACHED, + "No OCSP Responder URL is attached to the certificate.")); + } + } + + private int getOCSPResponderConnectionTimeout() { + int timeout = DEFAULT_OCSP_RESPONDER_CONNECTION_TIMEOUT; + if (systemGetProperty(SF_OCSP_TEST_OCSP_RESPONDER_TIMEOUT) != null) { + try { + timeout = Integer.parseInt(systemGetProperty(SF_OCSP_TEST_OCSP_RESPONDER_TIMEOUT)); + } catch (Exception ex) { + // nop + } + } + return timeout; + } + + private String overrideOCSPURL(String ocspURL) { + String ocspURLInput = systemGetProperty(SF_OCSP_TEST_RESPONDER_URL); + if (ocspURLInput != 
null) { + logger.debug("Overriding OCSP url to: {}", ocspURLInput); + return ocspURLInput; + } + logger.debug("Overriding OCSP url to: {}", ocspURL); + return ocspURL; + } + + /** + * Validates the certificate revocation status + * + * @param pairIssuerSubject a pair of issuer and subject certificates + * @param ocspRespB64 Base64 encoded OCSP Response object + * @throws SFOCSPException raises if any other error occurs + */ + private void validateRevocationStatusMain( + SFPair pairIssuerSubject, String ocspRespB64) + throws SFOCSPException { + try { + OCSPResp ocspResp = b64ToOCSPResp(ocspRespB64); + if (ocspResp == null) { + throw new SFOCSPException( + OCSPErrorCode.INVALID_OCSP_RESPONSE, "OCSP response is null. The content is invalid."); + } + Date currentTime = new Date(); + BasicOCSPResp basicOcspResp = (BasicOCSPResp) (ocspResp.getResponseObject()); + X509CertificateHolder[] attachedCerts = basicOcspResp.getCerts(); + X509CertificateHolder signVerifyCert; + checkInvalidSigningCertTestParameter(); + if (attachedCerts.length > 0) { + logger.debug( + "Certificate is attached for verification. " + + "Verifying it by the issuer certificate.", + false); + signVerifyCert = attachedCerts[0]; + if (currentTime.after(signVerifyCert.getNotAfter()) + || currentTime.before(signVerifyCert.getNotBefore())) { + throw new SFOCSPException( + OCSPErrorCode.EXPIRED_OCSP_SIGNING_CERTIFICATE, + String.format( + "Cert attached to " + + "OCSP Response is invalid." 
+ + "Current time - %s" + + "Certificate not before time - %s" + + "Certificate not after time - %s", + currentTime, signVerifyCert.getNotBefore(), signVerifyCert.getNotAfter())); + } + try { + verifySignature( + new X509CertificateHolder(pairIssuerSubject.left.getEncoded()), + signVerifyCert.getSignature(), + CONVERTER_X509.getCertificate(signVerifyCert).getTBSCertificate(), + signVerifyCert.getSignatureAlgorithm()); + } catch (CertificateException ex) { + logger.debug("OCSP Signing Certificate signature verification failed", false); + throw new SFOCSPException( + OCSPErrorCode.INVALID_CERTIFICATE_SIGNATURE, + "OCSP Signing Certificate signature verification failed", + ex); + } + logger.debug("Verifying OCSP signature by the attached certificate public key.", false); + } else { + logger.debug( + "Certificate is NOT attached for verification. " + + "Verifying OCSP signature by the issuer public key.", + false); + signVerifyCert = new X509CertificateHolder(pairIssuerSubject.left.getEncoded()); + } + try { + verifySignature( + signVerifyCert, + basicOcspResp.getSignature(), + basicOcspResp.getTBSResponseData(), + basicOcspResp.getSignatureAlgorithmID()); + } catch (CertificateException ex) { + logger.debug("OCSP signature verification failed", false); + throw new SFOCSPException( + OCSPErrorCode.INVALID_OCSP_RESPONSE_SIGNATURE, + "OCSP signature verification failed", + ex); + } + + validateBasicOcspResponse(currentTime, basicOcspResp); + } catch (IOException | OCSPException ex) { + throw new SFOCSPException( + OCSPErrorCode.REVOCATION_CHECK_FAILURE, "Failed to check revocation status.", ex); + } + } + + private void checkInvalidSigningCertTestParameter() throws SFOCSPException { + if (isEnabledSystemTestParameter(SF_OCSP_TEST_INVALID_SIGNING_CERT)) { + throw new SFOCSPException( + OCSPErrorCode.EXPIRED_OCSP_SIGNING_CERTIFICATE, + "Cert attached to OCSP Response is invalid"); + } + } + + /** + * Validates OCSP Basic OCSP response. 
+ * + * @param currentTime the current timestamp. + * @param basicOcspResp BasicOcspResponse data. + * @throws SFOCSPException raises if any failure occurs. + */ + private void validateBasicOcspResponse(Date currentTime, BasicOCSPResp basicOcspResp) + throws SFOCSPException { + for (SingleResp singleResps : basicOcspResp.getResponses()) { + checkCertUnknownTestParameter(); + CertificateStatus certStatus = singleResps.getCertStatus(); + if (certStatus != CertificateStatus.GOOD) { + if (certStatus instanceof RevokedStatus) { + RevokedStatus status = (RevokedStatus) certStatus; + int reason; + try { + reason = status.getRevocationReason(); + } catch (IllegalStateException ex) { + reason = -1; + } + Date revocationTime = status.getRevocationTime(); + throw new SFOCSPException( + OCSPErrorCode.CERTIFICATE_STATUS_REVOKED, + String.format( + "The certificate has been revoked. Reason: %d, Time: %s", + reason, DATE_FORMAT_UTC.format(revocationTime))); + } else { + // Unknown status + throw new SFOCSPException( + OCSPErrorCode.CERTIFICATE_STATUS_UNKNOWN, + "Failed to validate the certificate for UNKNOWN reason."); + } + } + + Date thisUpdate = singleResps.getThisUpdate(); + Date nextUpdate = singleResps.getNextUpdate(); + logger.debug( + "Current Time: {}, This Update: {}, Next Update: {}", + currentTime, + thisUpdate, + nextUpdate); + if (!isValidityRange(currentTime, thisUpdate, nextUpdate)) { + throw new SFOCSPException( + OCSPErrorCode.INVALID_OCSP_RESPONSE_VALIDITY, + String.format( + "The OCSP response validity is out of range: " + + "Current Time: %s, This Update: %s, Next Update: %s", + DATE_FORMAT_UTC.format(currentTime), + DATE_FORMAT_UTC.format(thisUpdate), + DATE_FORMAT_UTC.format(nextUpdate))); + } + } + logger.debug("OK. 
Verified the certificate revocation status.", false); + } + + private void checkCertUnknownTestParameter() throws SFOCSPException { + if (isEnabledSystemTestParameter(SF_OCSP_TEST_INJECT_UNKNOWN_STATUS)) { + throw new SFOCSPException( + OCSPErrorCode.CERTIFICATE_STATUS_UNKNOWN, + "Failed to validate the certificate for UNKNOWN reason."); + } + } + + /** + * Creates a OCSP Request + * + * @param pairIssuerSubject a pair of issuer and subject certificates + * @return OCSPReq object + */ + private OCSPReq createRequest(SFPair pairIssuerSubject) + throws IOException { + Certificate issuer = pairIssuerSubject.left; + Certificate subject = pairIssuerSubject.right; + OCSPReqBuilder gen = new OCSPReqBuilder(); + try { + DigestCalculator digest = new SHA1DigestCalculator(); + X509CertificateHolder certHolder = new X509CertificateHolder(issuer.getEncoded()); + CertificateID certId = + new CertificateID(digest, certHolder, subject.getSerialNumber().getValue()); + gen.addRequest(certId); + return gen.build(); + } catch (OCSPException ex) { + throw new IOException("Failed to build a OCSPReq.", ex); + } + } + + /** + * Converts X509Certificate to Bouncy Castle Certificate + * + * @param chain an array of X509Certificate + * @return a list of Bouncy Castle Certificate + */ + private List convertToBouncyCastleCertificate(X509Certificate[] chain) + throws CertificateEncodingException { + final List bcChain = new ArrayList<>(); + for (X509Certificate cert : chain) { + bcChain.add(Certificate.getInstance(cert.getEncoded())); + } + return bcChain; + } + + /** + * Creates a pair of Issuer and Subject certificates + * + * @param bcChain a list of bouncy castle Certificate + * @return a list of pair of Issuer and Subject certificates + */ + private List> getPairIssuerSubject(List bcChain) + throws CertificateException { + List> pairIssuerSubject = new ArrayList<>(); + for (int i = 0, len = bcChain.size(); i < len; ++i) { + Certificate bcCert = bcChain.get(i); + if 
(bcCert.getIssuer().equals(bcCert.getSubject())) { + continue; // skipping ROOT CA + } + if (i < len - 1) { + // Check if the root certificate has been found and stop going down the chain. + Certificate issuer = ROOT_CA.get(bcCert.getIssuer().hashCode()); + if (issuer != null) { + logger.debug( + "A trusted root certificate found: %s, stopping chain traversal here", + bcCert.getIssuer().toString()); + pairIssuerSubject.add(SFPair.of(issuer, bcChain.get(i))); + break; + } + pairIssuerSubject.add(SFPair.of(bcChain.get(i + 1), bcChain.get(i))); + } else { + // no root CA certificate is attached in the certificate chain, so + // getting one from the root CA from JVM. + Certificate issuer = ROOT_CA.get(bcCert.getIssuer().hashCode()); + if (issuer == null) { + throw new CertificateException( + "Failed to find the root CA.", + new SFOCSPException(OCSPErrorCode.NO_ROOTCA_FOUND, "Failed to find the root CA.")); + } + pairIssuerSubject.add(SFPair.of(issuer, bcChain.get(i))); + } + } + return pairIssuerSubject; + } + + /** + * Gets OCSP URLs associated with the certificate. 
+ * + * @param bcCert Bouncy Castle Certificate + * @return a set of OCSP URLs + */ + private Set getOcspUrls(Certificate bcCert) throws IOException { + TBSCertificate bcTbsCert = bcCert.getTBSCertificate(); + Extensions bcExts = bcTbsCert.getExtensions(); + if (bcExts == null) { + throw new IOException("Failed to get Tbs Certificate."); + } + + Set ocsp = new HashSet<>(); + for (Enumeration en = bcExts.oids(); en.hasMoreElements(); ) { + ASN1ObjectIdentifier oid = (ASN1ObjectIdentifier) en.nextElement(); + Extension bcExt = bcExts.getExtension(oid); + if (Extension.authorityInfoAccess.equals(bcExt.getExtnId())) { + // OCSP URLS are included in authorityInfoAccess + DLSequence seq = (DLSequence) bcExt.getParsedValue(); + for (ASN1Encodable asn : seq) { + ASN1Encodable[] pairOfAsn = ((DLSequence) asn).toArray(); + if (pairOfAsn.length == 2) { + ASN1ObjectIdentifier key = (ASN1ObjectIdentifier) pairOfAsn[0]; + if (OIDocsp.equals(key)) { + // ensure OCSP and not CRL + GeneralName gn = GeneralName.getInstance(pairOfAsn[1]); + ocsp.add(gn.getName().toString()); + } + } + } + } + } + return ocsp; + } + + /** OCSP Response Utils */ + private String ocspResponseToB64(OCSPResp ocspResp) { + if (ocspResp == null) { + return null; + } + try { + return Base64.encodeBase64String(ocspResp.getEncoded()); + } catch (Throwable ex) { + logger.debug("Could not convert OCSP Response to Base64", false); + return null; + } + } + + private OCSPResp b64ToOCSPResp(String ocspRespB64) { + try { + return new OCSPResp(Base64.decodeBase64(ocspRespB64)); + } catch (Throwable ex) { + logger.debug("Could not cover OCSP Response from Base64 to OCSPResp object", false); + return null; + } + } + + static class OCSPCacheServer { + String SF_OCSP_RESPONSE_CACHE_SERVER; + String SF_OCSP_RESPONSE_RETRY_URL; + boolean new_endpoint_enabled; + + void resetOCSPResponseCacheServer(String host) { + String ocspCacheServerUrl; + if (host.toLowerCase().contains(".global.snowflakecomputing.")) { + 
ocspCacheServerUrl = + String.format("https://ocspssd%s/%s", host.substring(host.indexOf('-')), "ocsp"); + } else if (host.toLowerCase().contains(".snowflakecomputing.")) { + ocspCacheServerUrl = + String.format("https://ocspssd%s/%s", host.substring(host.indexOf('.')), "ocsp"); + } else { + String topLevelDomain = host.substring(host.lastIndexOf(".") + 1); + ocspCacheServerUrl = + String.format("https://ocspssd.snowflakecomputing.%s/ocsp", topLevelDomain); + } + SF_OCSP_RESPONSE_CACHE_SERVER = String.format("%s/%s", ocspCacheServerUrl, "fetch"); + SF_OCSP_RESPONSE_RETRY_URL = String.format("%s/%s", ocspCacheServerUrl, "retry"); + } + } + + private static class OCSPPostReqData { + private String ocsp_url; + private String ocsp_req; + private String cert_id_enc; + private String hostname; + + OCSPPostReqData(String ocsp_url, String ocsp_req, String cert_id_enc, String hname) { + this.ocsp_url = ocsp_url; + this.ocsp_req = ocsp_req; + this.cert_id_enc = cert_id_enc; + this.hostname = hname; + } + } + + /** OCSP response cache key object */ + static class OcspResponseCacheKey { + final byte[] nameHash; + final byte[] keyHash; + final BigInteger serialNumber; + + OcspResponseCacheKey(byte[] nameHash, byte[] keyHash, BigInteger serialNumber) { + this.nameHash = nameHash; + this.keyHash = keyHash; + this.serialNumber = serialNumber; + } + + public int hashCode() { + int ret = Arrays.hashCode(this.nameHash) * 37; + ret = ret * 10 + Arrays.hashCode(this.keyHash) * 37; + ret = ret * 10 + this.serialNumber.hashCode(); + return ret; + } + + public boolean equals(Object obj) { + if (!(obj instanceof OcspResponseCacheKey)) { + return false; + } + OcspResponseCacheKey target = (OcspResponseCacheKey) obj; + return Arrays.equals(this.nameHash, target.nameHash) + && Arrays.equals(this.keyHash, target.keyHash) + && this.serialNumber.equals(target.serialNumber); + } + + public String toString() { + return String.format( + "OcspResponseCacheKey: NameHash: %s, KeyHash: %s, 
SerialNumber: %s", + HexUtil.byteToHexString(nameHash), + HexUtil.byteToHexString(keyHash), + serialNumber.toString()); + } + } + + /** SHA1 Digest Calculator used in OCSP Req. */ + static class SHA1DigestCalculator implements DigestCalculator { + private ByteArrayOutputStream bOut = new ByteArrayOutputStream(); + + public AlgorithmIdentifier getAlgorithmIdentifier() { + return new AlgorithmIdentifier(OIWObjectIdentifiers.idSHA1); + } + + public OutputStream getOutputStream() { + return bOut; + } + + public byte[] getDigest() { + byte[] bytes = bOut.toByteArray(); + bOut.reset(); + try { + MessageDigest messageDigest = MessageDigest.getInstance(ALGORITHM_SHA1_NAME); + return messageDigest.digest(bytes); + } catch (NoSuchAlgorithmException ex) { + String errMsg = + String.format( + "Failed to instantiate the algorithm: %s. err=%s", + ALGORITHM_SHA1_NAME, ex.getMessage()); + logger.error(errMsg, false); + throw new RuntimeException(errMsg); + } + } + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeAzureClient.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeAzureClient.java index 6e83c25c6..a8a0bd577 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeAzureClient.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeAzureClient.java @@ -10,8 +10,8 @@ */ package net.snowflake.ingest.streaming.internal.fileTransferAgent; -import static net.snowflake.client.core.HttpUtil.setSessionlessProxyForAzure; import static net.snowflake.ingest.streaming.internal.fileTransferAgent.ErrorCode.CLOUD_STORAGE_CREDENTIALS_EXPIRED; +import static net.snowflake.ingest.streaming.internal.fileTransferAgent.JdbcHttpUtil.setSessionlessProxyForAzure; import static net.snowflake.ingest.streaming.internal.fileTransferAgent.StorageClientUtil.systemGetProperty; import com.fasterxml.jackson.core.JsonFactory; @@ -80,8 +80,7 @@ 
private SnowflakeAzureClient() {} * required to decrypt/encrypt content in stage */ public static SnowflakeAzureClient createSnowflakeAzureClient( - StageInfo stage, RemoteStoreFileEncryptionMaterial encMat) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + StageInfo stage, RemoteStoreFileEncryptionMaterial encMat) throws SnowflakeSQLException { logger.debug( "Initializing Snowflake Azure client with encryption: {}", encMat != null ? "true" : "false"); @@ -102,9 +101,7 @@ public static SnowflakeAzureClient createSnowflakeAzureClient( * @throws IllegalArgumentException when invalid credentials are used */ private void setupAzureClient(StageInfo stage, RemoteStoreFileEncryptionMaterial encMat) - throws IllegalArgumentException, - SnowflakeSQLException, - net.snowflake.client.jdbc.SnowflakeSQLException { + throws IllegalArgumentException, SnowflakeSQLException { // Save the client creation parameters so that we can reuse them, // to reset the Azure client. 
this.stageInfo = stage; @@ -189,8 +186,7 @@ public int getEncryptionKeySize() { * @throws SnowflakeSQLException failure to renew the client */ @Override - public void renew(Map stageCredentials) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + public void renew(Map stageCredentials) throws SnowflakeSQLException { logger.debug("Renewing the Azure client"); stageInfo.setCredentials(stageCredentials); setupAzureClient(stageInfo, encMat); @@ -302,7 +298,7 @@ public void download( String stageRegion, String presignedUrl, String queryId) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { Stopwatch stopwatch = new Stopwatch(); stopwatch.start(); String localFilePath = localLocation + localFileSep + destFileName; @@ -416,7 +412,7 @@ public InputStream downloadToStream( String stageRegion, String presignedUrl, String queryId) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { logger.debug( "Staring download of file from Azure stage path: {} to input stream", stageFilePath); Stopwatch stopwatch = new Stopwatch(); @@ -528,7 +524,7 @@ public void upload( String stageRegion, String presignedUrl, String queryId) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { logger.info( StorageHelper.getStartUploadLog( "Azure", uploadFromStream, inputStream, fileBackedOutputStream, srcFile, destFileName)); @@ -649,7 +645,7 @@ public void upload( @Override public void handleStorageException( Exception ex, int retryCount, String operation, String command, String queryId) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { handleAzureException(ex, retryCount, operation, command, this, queryId); } @@ -755,7 +751,7 @@ private static void handleAzureException( String command, SnowflakeAzureClient azClient, 
String queryId) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { // no need to retry if it is invalid key exception if (ex.getCause() instanceof InvalidKeyException) { diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeFileTransferAgent.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeFileTransferAgent.java index 21f82c23b..0a4faf5c1 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeFileTransferAgent.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeFileTransferAgent.java @@ -638,9 +638,9 @@ public static void uploadWithoutConnection(SnowflakeFileTransferConfig config) t String streamingIngestClientKey = config.getStreamingIngestClientKey(); // Create HttpClient key - net.snowflake.client.core.HttpClientSettingsKey key = - net.snowflake.client.jdbc.SnowflakeUtil.convertProxyPropertiesToHttpClientKey( - net.snowflake.client.core.OCSPMode.FAIL_OPEN, proxyProperties); + HttpClientSettingsKey key = + StorageClientUtil.convertProxyPropertiesToHttpClientKey( + net.snowflake.ingest.utils.OCSPMode.FAIL_OPEN, proxyProperties); StageInfo stageInfo = metadata.getStageInfo(); stageInfo.setProxyProperties(proxyProperties); @@ -860,7 +860,7 @@ private static void pushFileToRemoteStoreWithPresignedUrl( FileCompressionType compressionType, SnowflakeStorageClient initialClient, int networkTimeoutInMilli, - net.snowflake.client.core.HttpClientSettingsKey ocspModeAndProxyKey, + HttpClientSettingsKey ocspModeAndProxyKey, int parallel, File srcFile, boolean uploadFromStream, diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeGCSClient.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeGCSClient.java index 66909b130..a2a17ffda 100644 --- 
a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeGCSClient.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeGCSClient.java @@ -38,10 +38,7 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; -import net.snowflake.client.core.ExecTimeTelemetryData; -import net.snowflake.client.core.HttpResponseContextDto; -import net.snowflake.client.core.HttpUtil; -import net.snowflake.client.jdbc.RestRequest; +// ExecTimeTelemetryData, HttpResponseContextDto, RestRequest all in same package import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.ArgSupplier; import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.SFLogger; import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.SFLoggerFactory; @@ -84,8 +81,7 @@ private SnowflakeGCSClient() {} * required to decrypt/encrypt content in stage */ public static SnowflakeGCSClient createSnowflakeGCSClient( - StageInfo stage, RemoteStoreFileEncryptionMaterial encMat) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + StageInfo stage, RemoteStoreFileEncryptionMaterial encMat) throws SnowflakeSQLException { logger.debug( "Initializing Snowflake GCS client with encryption: {}", encMat != null ? 
"true" : "false"); SnowflakeGCSClient sfGcsClient = new SnowflakeGCSClient(); @@ -140,8 +136,7 @@ public boolean requirePresignedUrl() { } @Override - public void renew(Map stageCredentials) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + public void renew(Map stageCredentials) throws SnowflakeSQLException { logger.debug("Renewing the Snowflake GCS client"); stageInfo.setCredentials(stageCredentials); setupGCSClient(stageInfo, encMat); @@ -199,7 +194,7 @@ public void download( String stageRegion, String presignedUrl, String queryId) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { String localFilePath = localLocation + localFileSep + destFileName; logger.debug( "Staring download of file from GCS stage path: {} to {}", stageFilePath, localFilePath); @@ -221,7 +216,8 @@ public void download( logger.debug("Fetching result: {}", scrubPresignedUrl(presignedUrl)); - CloseableHttpClient httpClient = HttpUtil.getHttpClientWithoutDecompression(null, null); + CloseableHttpClient httpClient = + JdbcHttpUtil.getHttpClientWithoutDecompression(null, null); // Get the file on storage using the presigned url HttpResponseContextDto responseDto = @@ -381,7 +377,7 @@ public InputStream downloadToStream( String stageRegion, String presignedUrl, String queryId) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { logger.debug("Staring download of file from GCS stage path: {} to input stream", stageFilePath); int retryCount = 0; Stopwatch stopwatch = new Stopwatch(); @@ -402,7 +398,8 @@ public InputStream downloadToStream( logger.debug("Fetching result: {}", scrubPresignedUrl(presignedUrl)); - CloseableHttpClient httpClient = HttpUtil.getHttpClientWithoutDecompression(null, null); + CloseableHttpClient httpClient = + JdbcHttpUtil.getHttpClientWithoutDecompression(null, null); // Put the file on storage using the presigned url 
HttpResponse response = @@ -550,7 +547,7 @@ public InputStream downloadToStream( @Override public void uploadWithPresignedUrlWithoutConnection( int networkTimeoutInMilli, - net.snowflake.client.core.HttpClientSettingsKey ocspModeAndProxyKey, + HttpClientSettingsKey ocspModeAndProxyKey, int parallelism, boolean uploadFromStream, String remoteStorageLocation, @@ -562,7 +559,7 @@ public void uploadWithPresignedUrlWithoutConnection( String stageRegion, String presignedUrl, String queryId) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { logger.info( StorageHelper.getStartUploadLog( "GCS", uploadFromStream, inputStream, fileBackedOutputStream, srcFile, destFileName)); @@ -602,7 +599,7 @@ public void uploadWithPresignedUrlWithoutConnection( uploadWithPresignedUrl( networkTimeoutInMilli, - (int) HttpUtil.getSocketTimeout().toMillis(), + (int) JdbcHttpUtil.getSocketTimeout().toMillis(), meta.getContentEncoding(), meta.getUserMetadata(), uploadStreamInfo.left, @@ -663,7 +660,7 @@ public void upload( String stageRegion, String presignedUrl, String queryId) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { logger.info( StorageHelper.getStartUploadLog( "GCS", uploadFromStream, inputStream, fileBackedOutputStream, srcFile, destFileName)); @@ -814,7 +811,7 @@ private void uploadWithDownScopedToken( long contentLength, InputStream content, String queryId) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { logger.debug("Uploading file {} to bucket {}", destFileName, remoteStorageLocation); try { this.gcsAccessStrategy.uploadWithDownScopedToken( @@ -863,9 +860,9 @@ private void uploadWithPresignedUrl( Map metadata, InputStream content, String presignedUrl, - net.snowflake.client.core.HttpClientSettingsKey ocspAndProxyKey, + HttpClientSettingsKey ocspAndProxyKey, String queryId) - throws 
SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { try { URIBuilder uriBuilder = new URIBuilder(presignedUrl); @@ -889,7 +886,7 @@ private void uploadWithPresignedUrl( InputStreamEntity contentEntity = new InputStreamEntity(content, -1); httpRequest.setEntity(contentEntity); - CloseableHttpClient httpClient = HttpUtil.getHttpClient(ocspAndProxyKey, null); + CloseableHttpClient httpClient = JdbcHttpUtil.getHttpClient(ocspAndProxyKey, null); // Put the file on storage using the presigned url HttpResponse response = @@ -1037,7 +1034,7 @@ private SFPair createUploadStream( @Override public void handleStorageException( Exception ex, int retryCount, String operation, String command, String queryId) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { // no need to retry if it is invalid key exception if (ex.getCause() instanceof InvalidKeyException) { // Most likely cause is that the unlimited strength policy files are not installed @@ -1172,9 +1169,7 @@ public String getDigestMetadata(StorageObjectMetadata meta) { * @throws IllegalArgumentException when invalid credentials are used */ private void setupGCSClient(StageInfo stage, RemoteStoreFileEncryptionMaterial encMat) - throws IllegalArgumentException, - SnowflakeSQLException, - net.snowflake.client.jdbc.SnowflakeSQLException { + throws IllegalArgumentException, SnowflakeSQLException { // Save the client creation parameters so that we can reuse them, // to reset the GCS client. 
this.stageInfo = stage; diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeMutableProxyRoutePlanner.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeMutableProxyRoutePlanner.java new file mode 100644 index 000000000..301b83592 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeMutableProxyRoutePlanner.java @@ -0,0 +1,92 @@ +/* + * Replicated from snowflake-jdbc (v3.25.1) + * Source: https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/core/SnowflakeMutableProxyRoutePlanner.java + * + * Permitted differences: package declaration, import swaps for already-replicated classes. + */ +package net.snowflake.ingest.streaming.internal.fileTransferAgent; + +import com.amazonaws.Protocol; +import com.amazonaws.http.apache.SdkProxyRoutePlanner; +import java.io.Serializable; +import org.apache.http.HttpException; +import org.apache.http.HttpHost; +import org.apache.http.HttpRequest; +import org.apache.http.conn.routing.HttpRoute; +import org.apache.http.conn.routing.HttpRoutePlanner; +import org.apache.http.protocol.HttpContext; + +/** + * This class defines a ProxyRoutePlanner (used for creating HttpClients) that has the ability to + * change the nonProxyHosts setting. 
+ */ +public class SnowflakeMutableProxyRoutePlanner implements HttpRoutePlanner, Serializable { + + private SdkProxyRoutePlanner proxyRoutePlanner = null; + private String host; + private int proxyPort; + private String nonProxyHosts; + private HttpProtocol protocol; + + /** + * @deprecated Use {@link #SnowflakeMutableProxyRoutePlanner(String, int, HttpProtocol, String)} + * instead + * @param host host + * @param proxyPort proxy port + * @param proxyProtocol proxy protocol + * @param nonProxyHosts non-proxy hosts + */ + @Deprecated + public SnowflakeMutableProxyRoutePlanner( + String host, int proxyPort, Protocol proxyProtocol, String nonProxyHosts) { + this(host, proxyPort, toSnowflakeProtocol(proxyProtocol), nonProxyHosts); + } + + /** + * @param host host + * @param proxyPort proxy port + * @param proxyProtocol proxy protocol + * @param nonProxyHosts non-proxy hosts + */ + public SnowflakeMutableProxyRoutePlanner( + String host, int proxyPort, HttpProtocol proxyProtocol, String nonProxyHosts) { + proxyRoutePlanner = + new SdkProxyRoutePlanner(host, proxyPort, toAwsProtocol(proxyProtocol), nonProxyHosts); + this.host = host; + this.proxyPort = proxyPort; + this.nonProxyHosts = nonProxyHosts; + this.protocol = proxyProtocol; + } + + /** + * Set non-proxy hosts + * + * @param nonProxyHosts non-proxy hosts + */ + public void setNonProxyHosts(String nonProxyHosts) { + this.nonProxyHosts = nonProxyHosts; + proxyRoutePlanner = + new SdkProxyRoutePlanner(host, proxyPort, toAwsProtocol(protocol), nonProxyHosts); + } + + /** + * @return non-proxy hosts string + */ + public String getNonProxyHosts() { + return nonProxyHosts; + } + + @Override + public HttpRoute determineRoute(HttpHost target, HttpRequest request, HttpContext context) + throws HttpException { + return proxyRoutePlanner.determineRoute(target, request, context); + } + + private static Protocol toAwsProtocol(HttpProtocol protocol) { + return protocol == HttpProtocol.HTTP ? 
Protocol.HTTP : Protocol.HTTPS; + } + + private static HttpProtocol toSnowflakeProtocol(Protocol protocol) { + return protocol == Protocol.HTTP ? HttpProtocol.HTTP : HttpProtocol.HTTPS; + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeS3Client.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeS3Client.java index 4bede51dd..dec49b0fd 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeS3Client.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeS3Client.java @@ -110,7 +110,7 @@ public SnowflakeS3Client( String stageEndPoint, boolean isClientSideEncrypted, boolean useS3RegionalUrl) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { logger.debug( "Initializing Snowflake S3 client with encryption: {}, client side encrypted: {}", encMat != null, @@ -134,7 +134,7 @@ private void setupSnowflakeS3Client( String stageRegion, String stageEndPoint, boolean isClientSideEncrypted) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { // Save the client creation parameters so that we can reuse them, // to reset the AWS client. 
We won't save the awsCredentials since // we will be refreshing that, every time we reset the AWS client @@ -263,8 +263,7 @@ public int getEncryptionKeySize() { * @throws SnowflakeSQLException if any error occurs */ @Override - public void renew(Map stageCredentials) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + public void renew(Map stageCredentials) throws SnowflakeSQLException { logger.debug("Renewing the Snowflake S3 client"); // We renew the client with fresh credentials and with its original parameters setupSnowflakeS3Client( @@ -324,7 +323,7 @@ public void download( String stageRegion, String presignedUrl, String queryId) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { Stopwatch stopwatch = new Stopwatch(); stopwatch.start(); String localFilePath = localLocation + localFileSep + destFileName; @@ -445,7 +444,7 @@ public InputStream downloadToStream( String stageRegion, String presignedUrl, String queryId) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { logger.debug("Staring download of file from S3 stage path: {} to input stream", stageFilePath); Stopwatch stopwatch = new Stopwatch(); stopwatch.start(); @@ -542,7 +541,7 @@ public void upload( String stageRegion, String presignedUrl, String queryId) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { logger.info( StorageHelper.getStartUploadLog( "S3", uploadFromStream, inputStream, fileBackedOutputStream, srcFile, destFileName)); @@ -763,7 +762,7 @@ private SFPair createUploadStream( @Override public void handleStorageException( Exception ex, int retryCount, String operation, String command, String queryId) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { handleS3Exception(ex, retryCount, operation, command, this, 
queryId); } @@ -774,7 +773,7 @@ private static void handleS3Exception( String command, SnowflakeS3Client s3Client, String queryId) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { // no need to retry if it is invalid key exception if (ex.getCause() instanceof InvalidKeyException) { // Most likely cause is that the unlimited strength policy files are not installed diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeStorageClient.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeStorageClient.java index e50355f00..655a6b859 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeStorageClient.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeStorageClient.java @@ -58,8 +58,7 @@ default boolean requirePresignedUrl() { * @param stageCredentials a Map (as returned by GS) which contains the new credential properties * @throws SnowflakeSQLException failure to renew the storage client */ - void renew(Map stageCredentials) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException; + void renew(Map stageCredentials) throws SnowflakeSQLException; /** shuts down the client */ void shutdown(); @@ -103,7 +102,7 @@ default void download( String stageFilePath, String stageRegion, String presignedUrl) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { download( command, localLocation, @@ -140,7 +139,7 @@ void download( String stageRegion, String presignedUrl, String queryId) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException; + throws SnowflakeSQLException; /** * @deprecated @@ -153,7 +152,7 @@ default InputStream downloadToStream( String stageFilePath, String stageRegion, String presignedUrl) - throws SnowflakeSQLException, 
net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { return downloadToStream( command, parallelism, @@ -185,7 +184,7 @@ InputStream downloadToStream( String stageRegion, String presignedUrl, String queryId) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException; + throws SnowflakeSQLException; /** * @deprecated @@ -203,7 +202,7 @@ default void upload( StorageObjectMetadata meta, String stageRegion, String presignedUrl) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { upload( command, parallelism, @@ -249,7 +248,7 @@ void upload( String stageRegion, String presignedUrl, String queryId) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException; + throws SnowflakeSQLException; /** * @deprecated @@ -257,7 +256,7 @@ void upload( @Deprecated default void uploadWithPresignedUrlWithoutConnection( int networkTimeoutInMilli, - net.snowflake.client.core.HttpClientSettingsKey ocspModeAndProxyKey, + HttpClientSettingsKey ocspModeAndProxyKey, int parallelism, boolean uploadFromStream, String remoteStorageLocation, @@ -268,7 +267,7 @@ default void uploadWithPresignedUrlWithoutConnection( StorageObjectMetadata meta, String stageRegion, String presignedUrl) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { uploadWithPresignedUrlWithoutConnection( networkTimeoutInMilli, ocspModeAndProxyKey, @@ -292,7 +291,7 @@ default void uploadWithPresignedUrlWithoutConnection( */ default void uploadWithPresignedUrlWithoutConnection( int networkTimeoutInMilli, - net.snowflake.client.core.HttpClientSettingsKey ocspModeAndProxyKey, + HttpClientSettingsKey ocspModeAndProxyKey, int parallelism, boolean uploadFromStream, String remoteStorageLocation, @@ -304,7 +303,7 @@ default void uploadWithPresignedUrlWithoutConnection( String stageRegion, String presignedUrl, String queryId) - throws 
SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { if (!requirePresignedUrl()) { throw new SnowflakeSQLLoggedException( queryId, @@ -322,8 +321,7 @@ default void uploadWithPresignedUrlWithoutConnection( */ @Deprecated default void handleStorageException( - Exception ex, int retryCount, String operation, String command) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + Exception ex, int retryCount, String operation, String command) throws SnowflakeSQLException { handleStorageException(ex, retryCount, operation, command, null); } @@ -340,7 +338,7 @@ default void handleStorageException( */ void handleStorageException( Exception ex, int retryCount, String operation, String command, String queryId) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException; + throws SnowflakeSQLException; /** * Returns the material descriptor key diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeUseDPoPNonceException.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeUseDPoPNonceException.java new file mode 100644 index 000000000..dd901214a --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/SnowflakeUseDPoPNonceException.java @@ -0,0 +1,20 @@ +/* + * Replicated from snowflake-jdbc (v3.25.1) + * Source: https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/jdbc/SnowflakeUseDPoPNonceException.java + * + * Permitted differences: package declaration, @SnowflakeJdbcInternalApi annotation removed. 
+ */ +package net.snowflake.ingest.streaming.internal.fileTransferAgent; + +public class SnowflakeUseDPoPNonceException extends RuntimeException { + + private final String nonce; + + public SnowflakeUseDPoPNonceException(String nonce) { + this.nonce = nonce; + } + + public String getNonce() { + return nonce; + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/StorageClientFactory.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/StorageClientFactory.java index e1c540297..728133c0f 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/StorageClientFactory.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/StorageClientFactory.java @@ -52,7 +52,7 @@ public static StorageClientFactory getFactory() { */ public SnowflakeStorageClient createClient( StageInfo stage, int parallel, RemoteStoreFileEncryptionMaterial encMat) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { logger.debug("Creating storage client. Client type: {}", stage.getStageType().name()); switch (stage.getStageType()) { @@ -106,7 +106,7 @@ private SnowflakeS3Client createS3Client( String stageEndPoint, boolean isClientSideEncrypted, boolean useS3RegionalUrl) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { final int S3_TRANSFER_MAX_RETRIES = 3; logger.debug("Creating S3 client with encryption: {}", (encMat == null ? 
"no" : "yes")); @@ -188,8 +188,7 @@ public StorageObjectMetadata createStorageMetadataObj(StageInfo.StageType stageT * @return the SnowflakeS3Client instance created */ private SnowflakeAzureClient createAzureClient( - StageInfo stage, RemoteStoreFileEncryptionMaterial encMat) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + StageInfo stage, RemoteStoreFileEncryptionMaterial encMat) throws SnowflakeSQLException { logger.debug("Creating Azure client with encryption: {}", (encMat == null ? "no" : "yes")); SnowflakeAzureClient azureClient; @@ -213,8 +212,7 @@ private SnowflakeAzureClient createAzureClient( * @return the SnowflakeGCSClient instance created */ private SnowflakeGCSClient createGCSClient( - StageInfo stage, RemoteStoreFileEncryptionMaterial encMat) - throws SnowflakeSQLException, net.snowflake.client.jdbc.SnowflakeSQLException { + StageInfo stage, RemoteStoreFileEncryptionMaterial encMat) throws SnowflakeSQLException { logger.debug("Creating GCS client with encryption: {}", (encMat == null ? "no" : "yes")); SnowflakeGCSClient gcsClient; diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/StorageClientUtil.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/StorageClientUtil.java index 5ef4c408e..67f35d61e 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/StorageClientUtil.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/StorageClientUtil.java @@ -92,6 +92,19 @@ static boolean isNullOrEmpty(String str) { return str == null || str.isEmpty(); } + /** Replicated from SnowflakeUtil.systemGetEnv */ + static String systemGetEnv(String name) { + try { + return System.getenv(name); + } catch (SecurityException ex) { + logger.debug( + "Failed to get environment variable {}. 
Security exception raised: {}", + name, + ex.getMessage()); + return null; + } + } + /** * Replicated from SnowflakeUtil.isWindows, which delegates to Constants.getOS(). The OS detection * logic from Constants is inlined here to avoid replicating the full Constants class. @@ -116,7 +129,7 @@ static boolean isWindows() { * SnowflakeSQLException here temporarily until Step 5 replaces it. */ static HttpClientSettingsKey convertProxyPropertiesToHttpClientKey(OCSPMode mode, Properties info) - throws net.snowflake.client.jdbc.SnowflakeSQLException { + throws SnowflakeSQLException { if (info != null && info.size() > 0 && info.getProperty(SFSessionProperty.USE_PROXY.getPropertyKey()) != null) { @@ -129,9 +142,8 @@ static HttpClientSettingsKey convertProxyPropertiesToHttpClientKey(OCSPMode mode proxyPort = Integer.parseInt(info.getProperty(SFSessionProperty.PROXY_PORT.getPropertyKey())); } catch (NumberFormatException | NullPointerException e) { - throw new net.snowflake.client.jdbc.SnowflakeSQLException( - net.snowflake.client.jdbc.ErrorCode.INVALID_PROXY_PROPERTIES, - "Could not parse port number"); + throw new SnowflakeSQLException( + ErrorCode.INVALID_PROXY_PROPERTIES, "Could not parse port number"); } String proxyUser = info.getProperty(SFSessionProperty.PROXY_USER.getPropertyKey()); String proxyPassword = info.getProperty(SFSessionProperty.PROXY_PASSWORD.getPropertyKey()); @@ -232,12 +244,10 @@ static void throwJCEMissingError(String operation, Exception ex, String queryId) * Replicated from SnowflakeFileTransferAgent.throwNoSpaceLeftError. 
Source: * https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/jdbc/SnowflakeFileTransferAgent.java * - * @deprecated use {@link #throwNoSpaceLeftError(net.snowflake.client.core.SFSession, String, - * Exception, String)} + * @deprecated use {@link #throwNoSpaceLeftError(Object, String, Exception, String)} */ @Deprecated - static void throwNoSpaceLeftError( - net.snowflake.client.core.SFSession session, String operation, Exception ex) + static void throwNoSpaceLeftError(Object session, String operation, Exception ex) throws SnowflakeSQLLoggedException { throwNoSpaceLeftError(session, operation, ex, null); } @@ -246,8 +256,7 @@ static void throwNoSpaceLeftError( * Replicated from SnowflakeFileTransferAgent.throwNoSpaceLeftError. Source: * https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/jdbc/SnowflakeFileTransferAgent.java */ - static void throwNoSpaceLeftError( - net.snowflake.client.core.SFSession session, String operation, Exception ex, String queryId) + static void throwNoSpaceLeftError(Object session, String operation, Exception ex, String queryId) throws SnowflakeSQLLoggedException { String exMessage = getRootCause(ex).getMessage(); if (exMessage != null && exMessage.equals(NO_SPACE_LEFT_ON_DEVICE_ERR)) { diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/ThrowingBiFunction.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/ThrowingBiFunction.java new file mode 100644 index 000000000..d9745268b --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/ThrowingBiFunction.java @@ -0,0 +1,12 @@ +/* + * Replicated from snowflake-jdbc (v3.25.1) + * Source: https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/util/ThrowingBiFunction.java + * + * Permitted differences: package declaration, @SnowflakeJdbcInternalApi annotation removed. 
+ */ +package net.snowflake.ingest.streaming.internal.fileTransferAgent; + +@FunctionalInterface +public interface ThrowingBiFunction { + R apply(A a, B b) throws T; +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/URLUtil.java b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/URLUtil.java new file mode 100644 index 000000000..db2c3391d --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/fileTransferAgent/URLUtil.java @@ -0,0 +1,82 @@ +/* + * Replicated from snowflake-jdbc (v3.25.1) + * Source: https://github.com/snowflakedb/snowflake-jdbc/blob/v3.25.1/src/main/java/net/snowflake/client/core/URLUtil.java + * + * Permitted differences: package declaration, import swaps for already-replicated classes, + * @SnowflakeJdbcInternalApi annotation removed. + * SFSession.SF_QUERY_REQUEST_ID inlined as a local constant (value: "requestId"). + */ +package net.snowflake.ingest.streaming.internal.fileTransferAgent; + +import java.io.UnsupportedEncodingException; +import java.net.MalformedURLException; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URL; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.regex.PatternSyntaxException; +import javax.annotation.Nullable; +import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.SFLogger; +import net.snowflake.ingest.streaming.internal.fileTransferAgent.log.SFLoggerFactory; +import org.apache.http.NameValuePair; +import org.apache.http.client.utils.URLEncodedUtils; + +public class URLUtil { + + private static final SFLogger logger = SFLoggerFactory.getLogger(URLUtil.class); + + // Inlined from SFSession.SF_QUERY_REQUEST_ID + private static final String SF_QUERY_REQUEST_ID = "requestId"; + + static final String validURLPattern = + 
"^http(s?)\\:\\/\\/[0-9a-zA-Z]([-.\\w]*[0-9a-zA-Z@:])*(:(0-9)*)*(\\/?)([a-zA-Z0-9\\-\\.\\?\\,\\&\\(\\)\\/\\\\\\+&%\\$#_=@]*)?$"; + static final Pattern pattern = Pattern.compile(validURLPattern); + + public static boolean isValidURL(String url) { + try { + Matcher matcher = pattern.matcher(url); + return matcher.find(); + } catch (PatternSyntaxException pex) { + logger.debug("The URL REGEX is invalid. Falling back to basic sanity test"); + try { + new URL(url).toURI(); + return true; + } catch (MalformedURLException mex) { + logger.debug("The URL " + url + ", is invalid"); + return false; + } catch (URISyntaxException uex) { + logger.debug("The URL " + url + ", is invalid"); + return false; + } + } + } + + @Nullable + public static String urlEncode(String target) throws UnsupportedEncodingException { + String encodedTarget; + try { + encodedTarget = URLEncoder.encode(target, StandardCharsets.UTF_8.toString()); + } catch (UnsupportedEncodingException uex) { + logger.debug("The string to be encoded- " + target + ", is invalid"); + return null; + } + return encodedTarget; + } + + public static String getRequestId(URI uri) { + return URLEncodedUtils.parse(uri, StandardCharsets.UTF_8).stream() + .filter(p -> p.getName().equals(SF_QUERY_REQUEST_ID)) + .findFirst() + .map(NameValuePair::getValue) + .orElse(null); + } + + public static String getRequestIdLogStr(URI uri) { + String requestId = getRequestId(uri); + + return requestId == null ? "" : "[requestId=" + requestId + "] "; + } +}