diff --git a/cpp/src/arrow/filesystem/s3fs.cc b/cpp/src/arrow/filesystem/s3fs.cc index c6b821f5deb..cab49e3f3ad 100644 --- a/cpp/src/arrow/filesystem/s3fs.cc +++ b/cpp/src/arrow/filesystem/s3fs.cc @@ -179,29 +179,94 @@ bool S3ProxyOptions::Equals(const S3ProxyOptions& other) const { username == other.username && password == other.password); } +// ----------------------------------------------------------------------- +// Custom comparison for AWS retry strategies +// To add a new strategy, add it to the AwsRetryStrategyVariant and +// add a new specialization to the AwsRetryStrategyEquality struct +using AwsRetryStrategyVariant = + std::variant<std::shared_ptr<Aws::Client::DefaultRetryStrategy>, + std::shared_ptr<Aws::Client::StandardRetryStrategy>>; + +struct AwsRetryStrategyEquality { + bool operator()(const std::shared_ptr<Aws::Client::DefaultRetryStrategy>& lhs, + const std::shared_ptr<Aws::Client::DefaultRetryStrategy>& rhs) const { + if (!lhs && !rhs) return true; + if (!lhs || !rhs) return false; + + return lhs->GetMaxAttempts() == rhs->GetMaxAttempts(); + } + + bool operator()(const std::shared_ptr<Aws::Client::StandardRetryStrategy>& lhs, + const std::shared_ptr<Aws::Client::StandardRetryStrategy>& rhs) const { + if (!lhs && !rhs) return true; + if (!lhs || !rhs) return false; + + return lhs->GetMaxAttempts() == rhs->GetMaxAttempts(); + } + + // Template function for same unknown RetryStrategy type - returns true if same pointer + template <typename T> + bool operator()(const std::shared_ptr<T>& lhs, const std::shared_ptr<T>& rhs) const { + if (!lhs && !rhs) return true; + if (!lhs || !rhs) return false; + + return lhs.get() == rhs.get(); + } + + // Template function for different RetryStrategy types - returns false for different + // types + template <typename T, typename U> + bool operator()(const std::shared_ptr<T>& lhs, const std::shared_ptr<U>& rhs) const { + return false; + } +}; + // ----------------------------------------------------------------------- // AwsRetryStrategy implementation class AwsRetryStrategy : public S3RetryStrategy { public: - explicit AwsRetryStrategy(std::shared_ptr<Aws::Client::RetryStrategy> retry_strategy) + explicit AwsRetryStrategy(AwsRetryStrategyVariant retry_strategy) : retry_strategy_(std::move(retry_strategy)) {} bool 
ShouldRetry(const AWSErrorDetail& detail, int64_t attempted_retries) override { Aws::Client::AWSError<Aws::Client::CoreErrors> error = DetailToError(detail); - return retry_strategy_->ShouldRetry( - error, static_cast<long>(attempted_retries)); // NOLINT: runtime/int + return std::visit( + [&](const auto& strategy) { + return strategy->ShouldRetry( + error, static_cast<long>(attempted_retries)); // NOLINT: runtime/int + }, + retry_strategy_); } int64_t CalculateDelayBeforeNextRetry(const AWSErrorDetail& detail, int64_t attempted_retries) override { Aws::Client::AWSError<Aws::Client::CoreErrors> error = DetailToError(detail); - return retry_strategy_->CalculateDelayBeforeNextRetry( - error, static_cast<long>(attempted_retries)); // NOLINT: runtime/int + return std::visit( + [&](const auto& strategy) { + return strategy->CalculateDelayBeforeNextRetry( + error, static_cast<long>(attempted_retries)); // NOLINT: runtime/int + }, + retry_strategy_); } + bool Equals(const S3RetryStrategy& other) const override { + auto other_aws = dynamic_cast<const AwsRetryStrategy*>(&other); + if (!other_aws) { + return false; + } + + return std::visit( + [](const auto& lhs, const auto& rhs) { + return AwsRetryStrategyEquality()(lhs, rhs); + }, + retry_strategy_, other_aws->retry_strategy_); + } + + protected: + AwsRetryStrategyVariant retry_strategy_; + private: - std::shared_ptr<Aws::Client::RetryStrategy> retry_strategy_; static Aws::Client::AWSError<Aws::Client::CoreErrors> DetailToError( const S3RetryStrategy::AWSErrorDetail& detail) { auto exception_name = ToAwsString(detail.exception_name); @@ -426,6 +491,12 @@ bool S3Options::Equals(const S3Options& other) const { default_metadata_size ? (other.default_metadata && other.default_metadata->Equals(*default_metadata)) : (!other.default_metadata || other.default_metadata->size() == 0); + + // Compare retry strategies + const bool retry_strategy_equals = retry_strategy && other.retry_strategy + ? 
retry_strategy->Equals(*other.retry_strategy) + : (!retry_strategy && !other.retry_strategy); + return (smart_defaults == other.smart_defaults && region == other.region && connect_timeout == other.connect_timeout && request_timeout == other.request_timeout && @@ -442,7 +513,7 @@ bool S3Options::Equals(const S3Options& other) const { tls_ca_dir_path == other.tls_ca_dir_path && tls_verify_certificates == other.tls_verify_certificates && sse_customer_key == other.sse_customer_key && default_metadata_equals && - GetAccessKey() == other.GetAccessKey() && + retry_strategy_equals && GetAccessKey() == other.GetAccessKey() && GetSecretKey() == other.GetSecretKey() && GetSessionToken() == other.GetSessionToken()); } diff --git a/cpp/src/arrow/filesystem/s3fs.h b/cpp/src/arrow/filesystem/s3fs.h index 158d70a93fc..2ada7dbc61f 100644 --- a/cpp/src/arrow/filesystem/s3fs.h +++ b/cpp/src/arrow/filesystem/s3fs.h @@ -86,6 +86,11 @@ class ARROW_EXPORT S3RetryStrategy { /// Returns the time in milliseconds the S3 client should sleep for until retrying. virtual int64_t CalculateDelayBeforeNextRetry(const AWSErrorDetail& error, int64_t attempted_retries) = 0; + /// Returns true if this retry strategy is equal to another retry strategy. + /// By default, it returns true if the two objects are of the same type. + virtual bool Equals(const S3RetryStrategy& other) const { + return typeid(*this) == typeid(other); + } /// Returns a stock AWS Default retry strategy. 
static std::shared_ptr<S3RetryStrategy> GetAwsDefaultRetryStrategy( int64_t max_attempts); diff --git a/cpp/src/arrow/filesystem/s3fs_test.cc b/cpp/src/arrow/filesystem/s3fs_test.cc index f7c125c8964..73fbeee366e 100644 --- a/cpp/src/arrow/filesystem/s3fs_test.cc +++ b/cpp/src/arrow/filesystem/s3fs_test.cc @@ -414,6 +414,109 @@ TEST_F(S3OptionsTest, FromAssumeRole) { options = S3Options::FromAssumeRole("my_role_arn", "session", "id", 42, sts_client); } +TEST_F(S3OptionsTest, RetryStrategyEquals) { + // Test DefaultRetryStrategy equality + auto default_strategy1 = S3RetryStrategy::GetAwsDefaultRetryStrategy(3); + auto default_strategy2 = S3RetryStrategy::GetAwsDefaultRetryStrategy(3); + auto default_strategy3 = S3RetryStrategy::GetAwsDefaultRetryStrategy(5); + + ASSERT_TRUE(default_strategy1->Equals(*default_strategy2)); + ASSERT_FALSE(default_strategy1->Equals(*default_strategy3)); + + // Test StandardRetryStrategy equality + auto standard_strategy1 = S3RetryStrategy::GetAwsStandardRetryStrategy(3); + auto standard_strategy2 = S3RetryStrategy::GetAwsStandardRetryStrategy(3); + auto standard_strategy3 = S3RetryStrategy::GetAwsStandardRetryStrategy(5); + + ASSERT_TRUE(standard_strategy1->Equals(*standard_strategy2)); + ASSERT_FALSE(standard_strategy1->Equals(*standard_strategy3)); + + // Test different strategy types + ASSERT_FALSE(default_strategy1->Equals(*standard_strategy1)); + ASSERT_FALSE(standard_strategy1->Equals(*default_strategy1)); +} + +TEST_F(S3OptionsTest, RetryStrategyInS3Options) { + // Test S3Options with null retry strategy + S3Options options_null = S3Options::Defaults(); + ASSERT_EQ(options_null.retry_strategy, nullptr); + + // Test S3Options with DefaultRetryStrategy - different max_attempts + S3Options options_default_3 = S3Options::Defaults(); + options_default_3.retry_strategy = S3RetryStrategy::GetAwsDefaultRetryStrategy(3); + ASSERT_NE(options_default_3.retry_strategy, nullptr); + + S3Options options_default_5 = S3Options::Defaults(); + 
options_default_5.retry_strategy = S3RetryStrategy::GetAwsDefaultRetryStrategy(5); + ASSERT_NE(options_default_5.retry_strategy, nullptr); + + // Test S3Options with StandardRetryStrategy - different max_attempts + S3Options options_standard_3 = S3Options::Defaults(); + options_standard_3.retry_strategy = S3RetryStrategy::GetAwsStandardRetryStrategy(3); + ASSERT_NE(options_standard_3.retry_strategy, nullptr); + + S3Options options_standard_5 = S3Options::Defaults(); + options_standard_5.retry_strategy = S3RetryStrategy::GetAwsStandardRetryStrategy(5); + ASSERT_NE(options_standard_5.retry_strategy, nullptr); + + // Test equality: same strategy type and max_attempts should be equal + S3Options options_default_3_copy = S3Options::Defaults(); + options_default_3_copy.retry_strategy = S3RetryStrategy::GetAwsDefaultRetryStrategy(3); + ASSERT_TRUE(options_default_3.Equals(options_default_3_copy)); + + S3Options options_standard_5_copy = S3Options::Defaults(); + options_standard_5_copy.retry_strategy = + S3RetryStrategy::GetAwsStandardRetryStrategy(5); + ASSERT_TRUE(options_standard_5.Equals(options_standard_5_copy)); + + // Test inequality: different max_attempts should not be equal + ASSERT_FALSE(options_default_3.Equals(options_default_5)); + ASSERT_FALSE(options_standard_3.Equals(options_standard_5)); + + // Test inequality: different strategy types should not be equal + ASSERT_FALSE(options_default_3.Equals(options_standard_3)); + ASSERT_FALSE(options_standard_5.Equals(options_default_5)); + + // Test inequality: null vs non-null retry strategy should not be equal + ASSERT_FALSE(options_null.Equals(options_default_3)); + ASSERT_FALSE(options_default_3.Equals(options_null)); +} + +TEST_F(S3OptionsTest, RetryStrategyInS3FileSystem) { + // Test S3FileSystem with null retry strategy + S3Options options_null = S3Options::Defaults(); + ASSERT_OK_AND_ASSIGN(auto fs_null, S3FileSystem::Make(options_null)); + ASSERT_EQ(fs_null->options().retry_strategy, nullptr); + + // Test 
S3FileSystem with DefaultRetryStrategy + S3Options options_default = S3Options::Defaults(); + options_default.retry_strategy = S3RetryStrategy::GetAwsDefaultRetryStrategy(3); + ASSERT_OK_AND_ASSIGN(auto fs_default, S3FileSystem::Make(options_default)); + ASSERT_NE(fs_default->options().retry_strategy, nullptr); + ASSERT_TRUE( + fs_default->options().retry_strategy->Equals(*options_default.retry_strategy)); + + // Test that same default strategy but different max_attempts create different file + // systems + S3Options options_default_5 = S3Options::Defaults(); + options_default_5.retry_strategy = S3RetryStrategy::GetAwsDefaultRetryStrategy(5); + ASSERT_OK_AND_ASSIGN(auto fs_default_5, S3FileSystem::Make(options_default_5)); + ASSERT_FALSE(fs_default->Equals(*fs_default_5)); + + // Test S3FileSystem with StandardRetryStrategy + S3Options options_standard = S3Options::Defaults(); + options_standard.retry_strategy = S3RetryStrategy::GetAwsStandardRetryStrategy(5); + ASSERT_OK_AND_ASSIGN(auto fs_standard, S3FileSystem::Make(options_standard)); + ASSERT_NE(fs_standard->options().retry_strategy, nullptr); + ASSERT_TRUE( + fs_standard->options().retry_strategy->Equals(*options_standard.retry_strategy)); + + // Test that different retry strategies create different file systems + ASSERT_FALSE(fs_null->Equals(*fs_default)); + ASSERT_FALSE(fs_default->Equals(*fs_standard)); + ASSERT_FALSE(fs_null->Equals(*fs_standard)); +} + //////////////////////////////////////////////////////////////////////////// // Region resolution test diff --git a/python/pyarrow/_s3fs.pyx b/python/pyarrow/_s3fs.pyx index 6317bd3785f..bfcd23c6ce4 100644 --- a/python/pyarrow/_s3fs.pyx +++ b/python/pyarrow/_s3fs.pyx @@ -120,6 +120,9 @@ class S3RetryStrategy: def __init__(self, max_attempts=3): self.max_attempts = max_attempts + def __reduce__(self): + return (self.__class__, (self.max_attempts,)) + class AwsStandardS3RetryStrategy(S3RetryStrategy): """ @@ -281,6 +284,7 @@ cdef class 
S3FileSystem(FileSystem): cdef: CS3FileSystem* s3fs + object _retry_strategy def __init__(self, *, access_key=None, secret_key=None, session_token=None, bint anonymous=False, region=None, request_timeout=None, @@ -412,9 +416,11 @@ cdef class S3FileSystem(FileSystem): if isinstance(retry_strategy, AwsStandardS3RetryStrategy): options.value().retry_strategy = CS3RetryStrategy.GetAwsStandardRetryStrategy( retry_strategy.max_attempts) + self._retry_strategy = retry_strategy elif isinstance(retry_strategy, AwsDefaultS3RetryStrategy): options.value().retry_strategy = CS3RetryStrategy.GetAwsDefaultRetryStrategy( retry_strategy.max_attempts) + self._retry_strategy = retry_strategy else: raise ValueError(f'Invalid retry_strategy {retry_strategy!r}') if tls_ca_file_path is not None: @@ -470,6 +476,7 @@ cdef class S3FileSystem(FileSystem): allow_bucket_creation=opts.allow_bucket_creation, allow_bucket_deletion=opts.allow_bucket_deletion, check_directory_existence_before_creation=opts.check_directory_existence_before_creation, + retry_strategy=self._retry_strategy, default_metadata=pyarrow_wrap_metadata(opts.default_metadata), proxy_options={'scheme': frombytes(opts.proxy_options.scheme), 'host': frombytes(opts.proxy_options.host), @@ -489,3 +496,10 @@ cdef class S3FileSystem(FileSystem): The AWS region this filesystem connects to. """ return frombytes(self.s3fs.region()) + + @property + def retry_strategy(self): + """ + The retry strategy currently configured for this S3 filesystem. 
+ """ + return self._retry_strategy diff --git a/python/pyarrow/tests/test_fs.py b/python/pyarrow/tests/test_fs.py index 376398baa07..214e0fced05 100644 --- a/python/pyarrow/tests/test_fs.py +++ b/python/pyarrow/tests/test_fs.py @@ -1226,14 +1226,34 @@ def test_s3_options(pickle_module): assert isinstance(fs, S3FileSystem) assert pickle_module.loads(pickle_module.dumps(fs)) == fs - # Note that the retry strategy won't survive pickling for now - fs = S3FileSystem( + # Test S3FileSystem with different retry strategies + # They are equal only when the retry strategy is from the same class + # and the same parameters + fs_std_5 = S3FileSystem( retry_strategy=AwsStandardS3RetryStrategy(max_attempts=5)) - assert isinstance(fs, S3FileSystem) - - fs = S3FileSystem( + assert isinstance(fs_std_5, S3FileSystem) + assert pickle_module.loads(pickle_module.dumps(fs_std_5)) == fs_std_5 + assert fs_std_5.retry_strategy.max_attempts == 5 + + fs_std_10 = S3FileSystem( + retry_strategy=AwsStandardS3RetryStrategy(max_attempts=10)) + assert isinstance(fs_std_10, S3FileSystem) + assert pickle_module.loads(pickle_module.dumps(fs_std_10)) == fs_std_10 + assert fs_std_10.retry_strategy.max_attempts == 10 + assert fs_std_10 != fs_std_5 + + fs_std_10_2 = S3FileSystem( + retry_strategy=AwsStandardS3RetryStrategy(max_attempts=10)) + assert isinstance(fs_std_10_2, S3FileSystem) + assert pickle_module.loads(pickle_module.dumps(fs_std_10_2)) == fs_std_10_2 + assert fs_std_10_2 == fs_std_10 + + fs_def_5 = S3FileSystem( retry_strategy=AwsDefaultS3RetryStrategy(max_attempts=5)) - assert isinstance(fs, S3FileSystem) + assert isinstance(fs_def_5, S3FileSystem) + assert pickle_module.loads(pickle_module.dumps(fs_def_5)) == fs_def_5 + assert fs_def_5.retry_strategy.max_attempts == 5 + assert fs_def_5 != fs_std_5 fs2 = S3FileSystem(role_arn='role') assert isinstance(fs2, S3FileSystem)