Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ Extra packages:
+---------------+---------------------------------------+------------------+
| Pandas | ``pip install PyAthena[Pandas]`` | >=1.3.0 |
+---------------+---------------------------------------+------------------+
| Arrow | ``pip install PyAthena[Arrow]`` | >=7.0.0 |
| Arrow | ``pip install PyAthena[Arrow]`` | >=10.0.0 |
+---------------+---------------------------------------+------------------+
| fastparquet | ``pip install PyAthena[fastparquet]`` | >=0.4.0 |
+---------------+---------------------------------------+------------------+
Expand Down
58 changes: 58 additions & 0 deletions docs/arrow.rst
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,52 @@ Try adding an alias to the SELECTed column, such as ``SELECT 1 AS name``.

pyathena.error.OperationalError: SYNTAX_ERROR: line 1:1: Column name not specified at position 1

S3 Timeout Configuration
~~~~~~~~~~~~~~~~~~~~~~~~~

ArrowCursor supports configuring S3 connection and request timeouts through ``connect_timeout`` and ``request_timeout`` parameters.
These parameters are particularly useful when experiencing timeout errors due to:

- Role assumption with AWS STS (cross-account access)
- High network latency between your environment and S3
- Connecting from regions far from the S3 bucket

By default, PyArrow uses AWS SDK default timeouts (typically 1 second for connection, 3 seconds for requests).
You can increase these values to accommodate slower authentication or network conditions.

.. code:: python

from pyathena import connect
from pyathena.arrow.cursor import ArrowCursor

# Configure higher timeouts for role assumption scenarios
cursor = connect(
s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/",
region_name="us-west-2",
cursor_class=ArrowCursor,
cursor_kwargs={
"connect_timeout": 10.0, # Socket connection timeout in seconds
"request_timeout": 30.0 # Request timeout in seconds
}
).cursor()

.. code:: python

from pyathena import connect
from pyathena.arrow.cursor import ArrowCursor

cursor = connect(
s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/",
region_name="us-west-2"
).cursor(ArrowCursor, connect_timeout=10.0, request_timeout=30.0)

The timeout parameters accept float values in seconds and apply to all S3 operations performed by the cursor,
including HeadObject and GetObject operations when retrieving query results.

.. note::

These timeout parameters require PyArrow >= 10.0.0, which added support for configuring S3FileSystem timeouts.

.. _async-arrow-cursor:

AsyncArrowCursor
Expand Down Expand Up @@ -426,6 +472,18 @@ As with AsyncArrowCursor, the UNLOAD option is also available.
region_name="us-west-2",
cursor_class=AsyncArrowCursor).cursor(unload=True)

AsyncArrowCursor also supports S3 timeout configuration using the same ``connect_timeout`` and ``request_timeout`` parameters as ArrowCursor.

.. code:: python

from pyathena import connect
from pyathena.arrow.async_cursor import AsyncArrowCursor

cursor = connect(
s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/",
region_name="us-west-2"
).cursor(AsyncArrowCursor, connect_timeout=10.0, request_timeout=30.0)

.. _`pyarrow.Table object`: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html
.. _`official unload documentation`: https://docs.aws.amazon.com/athena/latest/ug/unload.html
.. _`future object`: https://docs.python.org/3/library/concurrent.futures.html#future-objects
38 changes: 38 additions & 0 deletions pyathena/arrow/async_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,42 @@ def __init__(
unload: bool = False,
result_reuse_enable: bool = False,
result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES,
connect_timeout: Optional[float] = None,
request_timeout: Optional[float] = None,
**kwargs,
) -> None:
"""Initialize an AsyncArrowCursor.

Args:
s3_staging_dir: S3 location for query results.
schema_name: Default schema name.
catalog_name: Default catalog name.
work_group: Athena workgroup name.
poll_interval: Query status polling interval in seconds.
encryption_option: S3 encryption option (SSE_S3, SSE_KMS, CSE_KMS).
kms_key: KMS key ARN for encryption.
kill_on_interrupt: Cancel running query on keyboard interrupt.
max_workers: Maximum number of workers for concurrent execution.
arraysize: Number of rows to fetch per batch.
unload: Enable UNLOAD for high-performance Parquet output.
result_reuse_enable: Enable Athena query result reuse.
result_reuse_minutes: Minutes to reuse cached results.
connect_timeout: Socket connection timeout in seconds for S3 operations.
Defaults to AWS SDK default (typically 1 second) if not specified.
request_timeout: Request timeout in seconds for S3 operations.
Defaults to AWS SDK default (typically 3 seconds) if not specified.
Increase this value if you experience timeout errors when using
role assumption with STS or have high latency to S3.
**kwargs: Additional connection parameters.

Example:
>>> # Use higher timeouts for role assumption scenarios
>>> cursor = connection.cursor(
... AsyncArrowCursor,
... connect_timeout=10.0,
... request_timeout=30.0
... )
"""
super().__init__(
s3_staging_dir=s3_staging_dir,
schema_name=schema_name,
Expand All @@ -89,6 +123,8 @@ def __init__(
**kwargs,
)
self._unload = unload
self._connect_timeout = connect_timeout
self._request_timeout = request_timeout

@staticmethod
def get_default_converter(
Expand Down Expand Up @@ -125,6 +161,8 @@ def _collect_result_set(
retry_config=self._retry_config,
unload=self._unload,
unload_location=unload_location,
connect_timeout=self._connect_timeout,
request_timeout=self._request_timeout,
**kwargs,
)

Expand Down
37 changes: 37 additions & 0 deletions pyathena/arrow/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,41 @@ def __init__(
result_reuse_enable: bool = False,
result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES,
on_start_query_execution: Optional[Callable[[str], None]] = None,
connect_timeout: Optional[float] = None,
request_timeout: Optional[float] = None,
**kwargs,
) -> None:
"""Initialize an ArrowCursor.

Args:
s3_staging_dir: S3 location for query results.
schema_name: Default schema name.
catalog_name: Default catalog name.
work_group: Athena workgroup name.
poll_interval: Query status polling interval in seconds.
encryption_option: S3 encryption option (SSE_S3, SSE_KMS, CSE_KMS).
kms_key: KMS key ARN for encryption.
kill_on_interrupt: Cancel running query on keyboard interrupt.
unload: Enable UNLOAD for high-performance Parquet output.
result_reuse_enable: Enable Athena query result reuse.
result_reuse_minutes: Minutes to reuse cached results.
on_start_query_execution: Callback invoked when query starts.
connect_timeout: Socket connection timeout in seconds for S3 operations.
Defaults to AWS SDK default (typically 1 second) if not specified.
request_timeout: Request timeout in seconds for S3 operations.
Defaults to AWS SDK default (typically 3 seconds) if not specified.
Increase this value if you experience timeout errors when using
role assumption with STS or have high latency to S3.
**kwargs: Additional connection parameters.

Example:
>>> # Use higher timeouts for role assumption scenarios
>>> cursor = connection.cursor(
... ArrowCursor,
... connect_timeout=10,
... request_timeout=30
... )
"""
super().__init__(
s3_staging_dir=s3_staging_dir,
schema_name=schema_name,
Expand All @@ -80,6 +113,8 @@ def __init__(
)
self._unload = unload
self._on_start_query_execution = on_start_query_execution
self._connect_timeout = connect_timeout
self._request_timeout = request_timeout
self._query_id: Optional[str] = None
self._result_set: Optional[AthenaArrowResultSet] = None

Expand Down Expand Up @@ -205,6 +240,8 @@ def execute(
retry_config=self._retry_config,
unload=self._unload,
unload_location=unload_location,
connect_timeout=self._connect_timeout,
request_timeout=self._request_timeout,
**kwargs,
)
else:
Expand Down
20 changes: 18 additions & 2 deletions pyathena/arrow/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def __init__(
block_size: Optional[int] = None,
unload: bool = False,
unload_location: Optional[str] = None,
connect_timeout: Optional[float] = None,
request_timeout: Optional[float] = None,
**kwargs,
) -> None:
super().__init__(
Expand All @@ -108,6 +110,8 @@ def __init__(
self._block_size = block_size if block_size else self.DEFAULT_BLOCK_SIZE
self._unload = unload
self._unload_location = unload_location
self._connect_timeout = connect_timeout
self._request_timeout = request_timeout
self._kwargs = kwargs
self._fs = self.__s3_file_system()
if self.state == AthenaQueryExecution.STATE_SUCCEEDED and self.output_location:
Expand All @@ -122,6 +126,14 @@ def __s3_file_system(self):
from pyarrow import fs

connection = self.connection

# Build timeout parameters dict
timeout_kwargs = {}
if self._connect_timeout is not None:
timeout_kwargs["connect_timeout"] = self._connect_timeout
if self._request_timeout is not None:
timeout_kwargs["request_timeout"] = self._request_timeout

if "role_arn" in connection._kwargs and connection._kwargs["role_arn"]:
external_id = connection._kwargs.get("external_id")
fs = fs.S3FileSystem(
Expand All @@ -130,6 +142,7 @@ def __s3_file_system(self):
external_id="" if external_id is None else external_id,
load_frequency=connection._kwargs["duration_seconds"],
region=connection.region_name,
**timeout_kwargs,
)
elif connection.profile_name:
profile = connection.session._session.full_config["profiles"][connection.profile_name]
Expand All @@ -138,6 +151,7 @@ def __s3_file_system(self):
secret_key=profile.get("aws_secret_access_key", None),
session_token=profile.get("aws_session_token", None),
region=connection.region_name,
**timeout_kwargs,
)
else:
# Try explicit credentials first
Expand All @@ -151,6 +165,7 @@ def __s3_file_system(self):
secret_key=explicit_secret_key,
session_token=connection._kwargs.get("aws_session_token"),
region=connection.region_name,
**timeout_kwargs,
)
else:
# Fall back to dynamic credentials from boto3 session
Expand All @@ -163,13 +178,14 @@ def __s3_file_system(self):
secret_key=credentials.secret_key,
session_token=credentials.token,
region=connection.region_name,
**timeout_kwargs,
)
else:
# Fall back to default (no explicit credentials)
fs = fs.S3FileSystem(region=connection.region_name)
fs = fs.S3FileSystem(region=connection.region_name, **timeout_kwargs)
except Exception:
# Fall back to default if credential retrieval fails
fs = fs.S3FileSystem(region=connection.region_name)
fs = fs.S3FileSystem(region=connection.region_name, **timeout_kwargs)

return fs

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ awsathena = "pyathena.sqlalchemy.base:AthenaDialect"
[project.optional-dependencies]
sqlalchemy = ["sqlalchemy>=1.0.0"]
pandas = ["pandas>=1.3.0"]
arrow = ["pyarrow>=7.0.0"]
arrow = ["pyarrow>=10.0.0"]
fastparquet = ["fastparquet>=0.4.0"]

[dependency-groups]
dev = [
"sqlalchemy>=1.0.0",
"pandas>=1.3.0",
"numpy>=1.26.0",
"pyarrow>=7.0.0",
"pyarrow>=10.0.0",
"fastparquet>=0.4.0",
"Jinja2>=3.1.0",
"mypy>=0.900",
Expand Down
43 changes: 43 additions & 0 deletions tests/pyathena/arrow/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,3 +642,46 @@ def test_callback(query_id: str):
assert len(callback_results) == 1
assert callback_results[0] == arrow_cursor.query_id
assert arrow_cursor.query_id is not None

@pytest.mark.parametrize(
"arrow_cursor",
[
{
"cursor_kwargs": {
"connect_timeout": 10,
"request_timeout": 30,
}
}
],
indirect=["arrow_cursor"],
)
def test_timeout_parameters(self, arrow_cursor):
"""Test that timeout parameters are correctly passed to ArrowCursor and result set."""
# Verify timeout parameters are set on cursor
assert arrow_cursor._connect_timeout == 10
assert arrow_cursor._request_timeout == 30

# Execute a simple query to create a result set
arrow_cursor.execute("SELECT 1")

# Verify timeout parameters are passed to result set
assert arrow_cursor.result_set._connect_timeout == 10
assert arrow_cursor.result_set._request_timeout == 30

@pytest.mark.parametrize(
"arrow_cursor",
[{"cursor_kwargs": {"connect_timeout": 5.5, "request_timeout": 15.5}}],
indirect=["arrow_cursor"],
)
def test_timeout_parameters_float(self, arrow_cursor):
"""Test that timeout parameters accept float values."""
# Verify float timeout parameters are set on cursor
assert arrow_cursor._connect_timeout == 5.5
assert arrow_cursor._request_timeout == 15.5

# Execute a simple query to create a result set
arrow_cursor.execute("SELECT 1")

# Verify float timeout parameters are passed to result set
assert arrow_cursor.result_set._connect_timeout == 5.5
assert arrow_cursor.result_set._request_timeout == 15.5