diff --git a/README.rst b/README.rst index d561e621..6a7b354a 100644 --- a/README.rst +++ b/README.rst @@ -60,7 +60,7 @@ Extra packages: +---------------+---------------------------------------+------------------+ | Pandas | ``pip install PyAthena[Pandas]`` | >=1.3.0 | +---------------+---------------------------------------+------------------+ -| Arrow | ``pip install PyAthena[Arrow]`` | >=7.0.0 | +| Arrow | ``pip install PyAthena[Arrow]`` | >=10.0.0 | +---------------+---------------------------------------+------------------+ | fastparquet | ``pip install PyAthena[fastparquet]`` | >=0.4.0 | +---------------+---------------------------------------+------------------+ diff --git a/docs/arrow.rst b/docs/arrow.rst index b6599897..9231eccb 100644 --- a/docs/arrow.rst +++ b/docs/arrow.rst @@ -252,6 +252,52 @@ Try adding an alias to the SELECTed column, such as ``SELECT 1 AS name``. pyathena.error.OperationalError: SYNTAX_ERROR: line 1:1: Column name not specified at position 1 +S3 Timeout Configuration +~~~~~~~~~~~~~~~~~~~~~~~~~ + +ArrowCursor supports configuring S3 connection and request timeouts through the ``connect_timeout`` and ``request_timeout`` parameters. +These parameters are particularly useful when experiencing timeout errors due to: + +- Role assumption with AWS STS (cross-account access) +- High network latency between your environment and S3 +- Connecting from regions far from the S3 bucket + +By default, PyArrow uses the AWS SDK default timeouts (typically 1 second for connection, 3 seconds for requests). +You can increase these values to accommodate slower authentication or network conditions. + +.. 
code:: python + + from pyathena import connect + from pyathena.arrow.cursor import ArrowCursor + + # Configure higher timeouts for role assumption scenarios + cursor = connect( + s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=ArrowCursor, + cursor_kwargs={ + "connect_timeout": 10.0, # Socket connection timeout in seconds + "request_timeout": 30.0 # Request timeout in seconds + } + ).cursor() + +.. code:: python + + from pyathena import connect + from pyathena.arrow.cursor import ArrowCursor + + cursor = connect( + s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2" + ).cursor(ArrowCursor, connect_timeout=10.0, request_timeout=30.0) + +The timeout parameters accept float values in seconds and apply to all S3 operations performed by the cursor, +including HeadObject and GetObject operations when retrieving query results. + +.. note:: + + These timeout parameters require PyArrow >= 10.0.0, which added support for configuring S3FileSystem timeouts. According to the PyArrow ``S3FileSystem`` documentation, ``request_timeout`` is a socket read timeout honored only on Windows and macOS, so it may have no effect on other platforms. + .. _async-arrow-cursor: AsyncArrowCursor @@ -426,6 +472,18 @@ As with AsyncArrowCursor, the UNLOAD option is also available. region_name="us-west-2", cursor_class=AsyncArrowCursor).cursor(unload=True) +AsyncArrowCursor also supports S3 timeout configuration using the same ``connect_timeout`` and ``request_timeout`` parameters as ArrowCursor. + +.. code:: python + + from pyathena import connect + from pyathena.arrow.async_cursor import AsyncArrowCursor + + cursor = connect( + s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2" + ).cursor(AsyncArrowCursor, connect_timeout=10.0, request_timeout=30.0) + .. _`pyarrow.Table object`: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html .. _`official unload documentation`: https://docs.aws.amazon.com/athena/latest/ug/unload.html .. 
_`future object`: https://docs.python.org/3/library/concurrent.futures.html#future-objects diff --git a/pyathena/arrow/async_cursor.py b/pyathena/arrow/async_cursor.py index 9c76f4f7..1e9304b8 100644 --- a/pyathena/arrow/async_cursor.py +++ b/pyathena/arrow/async_cursor.py @@ -71,8 +71,42 @@ def __init__( unload: bool = False, result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, + connect_timeout: Optional[float] = None, + request_timeout: Optional[float] = None, **kwargs, ) -> None: + """Initialize an AsyncArrowCursor. + + Args: + s3_staging_dir: S3 location for query results. + schema_name: Default schema name. + catalog_name: Default catalog name. + work_group: Athena workgroup name. + poll_interval: Query status polling interval in seconds. + encryption_option: S3 encryption option (SSE_S3, SSE_KMS, CSE_KMS). + kms_key: KMS key ARN for encryption. + kill_on_interrupt: Cancel running query on keyboard interrupt. + max_workers: Maximum number of workers for concurrent execution. + arraysize: Number of rows to fetch per batch. + unload: Enable UNLOAD for high-performance Parquet output. + result_reuse_enable: Enable Athena query result reuse. + result_reuse_minutes: Minutes to reuse cached results. + connect_timeout: Socket connection timeout in seconds for S3 operations. + Defaults to AWS SDK default (typically 1 second) if not specified. + request_timeout: Request timeout in seconds for S3 operations. + Defaults to AWS SDK default (typically 3 seconds) if not specified. + Increase this value if you experience timeout errors when using + role assumption with STS or have high latency to S3. + **kwargs: Additional connection parameters. + + Example: + >>> # Use higher timeouts for role assumption scenarios + >>> cursor = connection.cursor( + ... AsyncArrowCursor, + ... connect_timeout=10.0, + ... request_timeout=30.0 + ... 
) + """ super().__init__( s3_staging_dir=s3_staging_dir, schema_name=schema_name, @@ -89,6 +123,8 @@ def __init__( **kwargs, ) self._unload = unload + self._connect_timeout = connect_timeout + self._request_timeout = request_timeout @staticmethod def get_default_converter( @@ -125,6 +161,8 @@ def _collect_result_set( retry_config=self._retry_config, unload=self._unload, unload_location=unload_location, + connect_timeout=self._connect_timeout, + request_timeout=self._request_timeout, **kwargs, ) diff --git a/pyathena/arrow/cursor.py b/pyathena/arrow/cursor.py index e3057cd2..cb3cd54d 100644 --- a/pyathena/arrow/cursor.py +++ b/pyathena/arrow/cursor.py @@ -63,8 +63,41 @@ def __init__( result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, on_start_query_execution: Optional[Callable[[str], None]] = None, + connect_timeout: Optional[float] = None, + request_timeout: Optional[float] = None, **kwargs, ) -> None: + """Initialize an ArrowCursor. + + Args: + s3_staging_dir: S3 location for query results. + schema_name: Default schema name. + catalog_name: Default catalog name. + work_group: Athena workgroup name. + poll_interval: Query status polling interval in seconds. + encryption_option: S3 encryption option (SSE_S3, SSE_KMS, CSE_KMS). + kms_key: KMS key ARN for encryption. + kill_on_interrupt: Cancel running query on keyboard interrupt. + unload: Enable UNLOAD for high-performance Parquet output. + result_reuse_enable: Enable Athena query result reuse. + result_reuse_minutes: Minutes to reuse cached results. + on_start_query_execution: Callback invoked when query starts. + connect_timeout: Socket connection timeout in seconds for S3 operations. + Defaults to AWS SDK default (typically 1 second) if not specified. + request_timeout: Request timeout in seconds for S3 operations. + Defaults to AWS SDK default (typically 3 seconds) if not specified. 
+ Increase this value if you experience timeout errors when using + role assumption with STS or have high latency to S3. + **kwargs: Additional connection parameters. + + Example: + >>> # Use higher timeouts for role assumption scenarios + >>> cursor = connection.cursor( + ... ArrowCursor, + ... connect_timeout=10, + ... request_timeout=30 + ... ) + """ super().__init__( s3_staging_dir=s3_staging_dir, schema_name=schema_name, @@ -80,6 +113,8 @@ def __init__( ) self._unload = unload self._on_start_query_execution = on_start_query_execution + self._connect_timeout = connect_timeout + self._request_timeout = request_timeout self._query_id: Optional[str] = None self._result_set: Optional[AthenaArrowResultSet] = None @@ -205,6 +240,8 @@ def execute( retry_config=self._retry_config, unload=self._unload, unload_location=unload_location, + connect_timeout=self._connect_timeout, + request_timeout=self._request_timeout, **kwargs, ) else: diff --git a/pyathena/arrow/result_set.py b/pyathena/arrow/result_set.py index 5e4e27c0..40b03125 100644 --- a/pyathena/arrow/result_set.py +++ b/pyathena/arrow/result_set.py @@ -94,6 +94,8 @@ def __init__( block_size: Optional[int] = None, unload: bool = False, unload_location: Optional[str] = None, + connect_timeout: Optional[float] = None, + request_timeout: Optional[float] = None, **kwargs, ) -> None: super().__init__( @@ -108,6 +110,8 @@ def __init__( self._block_size = block_size if block_size else self.DEFAULT_BLOCK_SIZE self._unload = unload self._unload_location = unload_location + self._connect_timeout = connect_timeout + self._request_timeout = request_timeout self._kwargs = kwargs self._fs = self.__s3_file_system() if self.state == AthenaQueryExecution.STATE_SUCCEEDED and self.output_location: @@ -122,6 +126,14 @@ def __s3_file_system(self): from pyarrow import fs connection = self.connection + + # Build timeout parameters dict + timeout_kwargs = {} + if self._connect_timeout is not None: + timeout_kwargs["connect_timeout"] = 
self._connect_timeout + if self._request_timeout is not None: + timeout_kwargs["request_timeout"] = self._request_timeout + if "role_arn" in connection._kwargs and connection._kwargs["role_arn"]: external_id = connection._kwargs.get("external_id") fs = fs.S3FileSystem( @@ -130,6 +142,7 @@ def __s3_file_system(self): external_id="" if external_id is None else external_id, load_frequency=connection._kwargs["duration_seconds"], region=connection.region_name, + **timeout_kwargs, ) elif connection.profile_name: profile = connection.session._session.full_config["profiles"][connection.profile_name] @@ -138,6 +151,7 @@ def __s3_file_system(self): secret_key=profile.get("aws_secret_access_key", None), session_token=profile.get("aws_session_token", None), region=connection.region_name, + **timeout_kwargs, ) else: # Try explicit credentials first @@ -151,6 +165,7 @@ def __s3_file_system(self): secret_key=explicit_secret_key, session_token=connection._kwargs.get("aws_session_token"), region=connection.region_name, + **timeout_kwargs, ) else: # Fall back to dynamic credentials from boto3 session @@ -163,13 +178,14 @@ def __s3_file_system(self): secret_key=credentials.secret_key, session_token=credentials.token, region=connection.region_name, + **timeout_kwargs, ) else: # Fall back to default (no explicit credentials) - fs = fs.S3FileSystem(region=connection.region_name) + fs = fs.S3FileSystem(region=connection.region_name, **timeout_kwargs) except Exception: # Fall back to default if credential retrieval fails - fs = fs.S3FileSystem(region=connection.region_name) + fs = fs.S3FileSystem(region=connection.region_name, **timeout_kwargs) return fs diff --git a/pyproject.toml b/pyproject.toml index 82d16ada..f1a84ac2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ awsathena = "pyathena.sqlalchemy.base:AthenaDialect" [project.optional-dependencies] sqlalchemy = ["sqlalchemy>=1.0.0"] pandas = ["pandas>=1.3.0"] -arrow = ["pyarrow>=7.0.0"] +arrow = ["pyarrow>=10.0.0"] 
fastparquet = ["fastparquet>=0.4.0"] [dependency-groups] @@ -50,7 +50,7 @@ dev = [ "sqlalchemy>=1.0.0", "pandas>=1.3.0", "numpy>=1.26.0", - "pyarrow>=7.0.0", + "pyarrow>=10.0.0", "fastparquet>=0.4.0", "Jinja2>=3.1.0", "mypy>=0.900", diff --git a/tests/pyathena/arrow/test_cursor.py b/tests/pyathena/arrow/test_cursor.py index 1755d5b6..046232ec 100644 --- a/tests/pyathena/arrow/test_cursor.py +++ b/tests/pyathena/arrow/test_cursor.py @@ -642,3 +642,46 @@ def test_callback(query_id: str): assert len(callback_results) == 1 assert callback_results[0] == arrow_cursor.query_id assert arrow_cursor.query_id is not None + + @pytest.mark.parametrize( + "arrow_cursor", + [ + { + "cursor_kwargs": { + "connect_timeout": 10, + "request_timeout": 30, + } + } + ], + indirect=["arrow_cursor"], + ) + def test_timeout_parameters(self, arrow_cursor): + """Test that timeout parameters are correctly passed to ArrowCursor and result set.""" + # Verify timeout parameters are set on cursor + assert arrow_cursor._connect_timeout == 10 + assert arrow_cursor._request_timeout == 30 + + # Execute a simple query to create a result set + arrow_cursor.execute("SELECT 1") + + # Verify timeout parameters are passed to result set + assert arrow_cursor.result_set._connect_timeout == 10 + assert arrow_cursor.result_set._request_timeout == 30 + + @pytest.mark.parametrize( + "arrow_cursor", + [{"cursor_kwargs": {"connect_timeout": 5.5, "request_timeout": 15.5}}], + indirect=["arrow_cursor"], + ) + def test_timeout_parameters_float(self, arrow_cursor): + """Test that timeout parameters accept float values.""" + # Verify float timeout parameters are set on cursor + assert arrow_cursor._connect_timeout == 5.5 + assert arrow_cursor._request_timeout == 15.5 + + # Execute a simple query to create a result set + arrow_cursor.execute("SELECT 1") + + # Verify float timeout parameters are passed to result set + assert arrow_cursor.result_set._connect_timeout == 5.5 + assert arrow_cursor.result_set._request_timeout 
== 15.5