From d2f5c2c8e5155f8958ce1ec91e2539b308a96104 Mon Sep 17 00:00:00 2001 From: teodor-delibasic_data Date: Wed, 11 Mar 2026 09:29:46 +0000 Subject: [PATCH 01/13] Initial commit Signed-off-by: teodor-delibasic_data --- python/NEXT_CHANGELOG.md | 20 + python/examples/async_example_arrow.py | 200 +++++++ python/examples/sync_example_arrow.py | 191 +++++++ python/pyproject.toml | 4 + python/rust/Cargo.toml | 5 +- python/rust/src/arrow.rs | 687 +++++++++++++++++++++++++ python/rust/src/async_wrapper.rs | 52 ++ python/rust/src/lib.rs | 8 + python/rust/src/sync_wrapper.rs | 61 +++ python/tests/test_arrow.py | 309 +++++++++++ python/zerobus/__init__.py | 6 +- python/zerobus/sdk/aio/__init__.py | 2 + python/zerobus/sdk/aio/zerobus_sdk.py | 101 +++- python/zerobus/sdk/shared/arrow.py | 90 ++++ python/zerobus/sdk/sync/__init__.py | 2 + python/zerobus/sdk/sync/zerobus_sdk.py | 120 ++++- 16 files changed, 1854 insertions(+), 4 deletions(-) create mode 100644 python/examples/async_example_arrow.py create mode 100644 python/examples/sync_example_arrow.py create mode 100644 python/rust/src/arrow.rs create mode 100644 python/tests/test_arrow.py create mode 100644 python/zerobus/sdk/shared/arrow.py diff --git a/python/NEXT_CHANGELOG.md b/python/NEXT_CHANGELOG.md index d5e2f40..d735d2b 100644 --- a/python/NEXT_CHANGELOG.md +++ b/python/NEXT_CHANGELOG.md @@ -6,15 +6,35 @@ ### New Features and Improvements +- **Arrow Flight Support (Experimental)**: Added support for ingesting `pyarrow.RecordBatch` and `pyarrow.Table` objects via Arrow Flight protocol + - **Note**: Arrow Flight is not yet supported by default from the Zerobus server side. 
+ - New `ZerobusArrowStream` class (sync in `zerobus.sdk.sync`, async in `zerobus.sdk.aio`) with `ingest_batch()`, `wait_for_offset()`, `flush()`, `close()`, `get_unacked_batches()` methods + - New `ArrowStreamConfigurationOptions` for configuring Arrow streams (max inflight batches, recovery, timeouts) + - New `create_arrow_stream()` and `recreate_arrow_stream()` methods on both sync and async `ZerobusSdk` + - Accepts both `pyarrow.RecordBatch` and `pyarrow.Table` (Tables are combined to a single batch internally) + - Arrow is opt-in: install via `pip install databricks-zerobus-ingest-sdk[arrow]` (requires `pyarrow>=14.0.0`) + - Arrow types gated behind `_core.arrow` submodule — not loaded unless pyarrow is installed + - Available from both `zerobus.sdk.sync` and `zerobus.sdk.aio`, and re-exported from top-level `zerobus` package + ### Bug Fixes ### Documentation ### Internal Changes +- Bumped Rust SDK dependency to v1.0.1 with `arrow-flight` feature +- Added `arrow-ipc`, `arrow-schema`, `arrow-array` (v56.2.0) Rust dependencies for IPC serialization +- Added PyO3 arrow module (`arrow.rs`) with `ArrowStreamConfigurationOptions`, `ZerobusArrowStream`, `AsyncZerobusArrowStream` pyclasses +- Added Python-side serialization helpers in `zerobus.sdk.shared.arrow` (`_serialize_schema`, `_serialize_batch`, `_deserialize_batch`) + ### Breaking Changes ### Deprecations ### API Changes +- Added `create_arrow_stream(table_name, schema, client_id, client_secret, options=None, headers_provider=None)` to sync and async `ZerobusSdk` +- Added `recreate_arrow_stream(old_stream)` to sync and async `ZerobusSdk` +- Added `ZerobusArrowStream` class (sync and async variants) with methods: `ingest_batch()`, `wait_for_offset()`, `flush()`, `close()`, `get_unacked_batches()`, properties: `is_closed`, `table_name` +- Added `ArrowStreamConfigurationOptions` class with fields: `max_inflight_batches`, `recovery`, `recovery_timeout_ms`, `recovery_backoff_ms`, `recovery_retries`, 
`server_lack_of_ack_timeout_ms`, `flush_timeout_ms`, `connection_timeout_ms` +- Added optional dependency: `pyarrow>=14.0.0` via `pip install databricks-zerobus-ingest-sdk[arrow]` diff --git a/python/examples/async_example_arrow.py b/python/examples/async_example_arrow.py new file mode 100644 index 0000000..23225ce --- /dev/null +++ b/python/examples/async_example_arrow.py @@ -0,0 +1,200 @@ +""" +Asynchronous Ingestion Example - Arrow Flight Mode + +This example demonstrates record ingestion using the asynchronous API with Arrow Flight. + +Record Type Mode: Arrow (RecordBatch) + - Records are sent as pyarrow RecordBatches + - Uses Arrow Flight protocol for columnar data transfer + - Best for structured/columnar data, DataFrames, Parquet workflows + +Requirements: + pip install databricks-zerobus-ingest-sdk[arrow] + +Note: Arrow Flight support is experimental and not yet supported for production use. +""" + +import asyncio +import logging +import os +import time + +import pyarrow as pa + +from zerobus.sdk.aio import ZerobusSdk +from zerobus.sdk.shared.arrow import ArrowStreamConfigurationOptions + +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +# Configuration - update these with your values +SERVER_ENDPOINT = os.getenv("ZEROBUS_SERVER_ENDPOINT", "https://your-shard-id.zerobus.region.cloud.databricks.com") +UNITY_CATALOG_ENDPOINT = os.getenv("DATABRICKS_WORKSPACE_URL", "https://your-workspace.cloud.databricks.com") +TABLE_NAME = os.getenv("ZEROBUS_TABLE_NAME", "catalog.schema.table") + +# For OAuth authentication +CLIENT_ID = os.getenv("DATABRICKS_CLIENT_ID", "your-oauth-client-id") +CLIENT_SECRET = os.getenv("DATABRICKS_CLIENT_SECRET", "your-oauth-client-secret") + +# Number of batches to ingest +NUM_BATCHES = 10 +ROWS_PER_BATCH = 100 + +# Define the Arrow schema +SCHEMA = pa.schema([ + ("device_name", pa.large_utf8()), + ("temp", pa.int32()), + 
("humidity", pa.int64()), +]) + + +def create_sample_batch(batch_index): + """ + Creates a sample RecordBatch with air quality data. + + Returns a pyarrow.RecordBatch with ROWS_PER_BATCH rows. + """ + return pa.record_batch( + { + "device_name": [f"sensor-{(batch_index * ROWS_PER_BATCH + i) % 10}" for i in range(ROWS_PER_BATCH)], + "temp": [20 + ((batch_index * ROWS_PER_BATCH + i) % 15) for i in range(ROWS_PER_BATCH)], + "humidity": [50 + ((batch_index * ROWS_PER_BATCH + i) % 40) for i in range(ROWS_PER_BATCH)], + }, + schema=SCHEMA, + ) + + +async def main(): + print("Starting asynchronous ingestion example (Arrow Flight Mode)...") + print("=" * 60) + + # Check if credentials are configured + if CLIENT_ID == "your-oauth-client-id" or CLIENT_SECRET == "your-oauth-client-secret": + logger.error("Please set DATABRICKS_CLIENT_ID and DATABRICKS_CLIENT_SECRET environment variables") + return + + if SERVER_ENDPOINT == "https://your-shard-id.zerobus.region.cloud.databricks.com": + logger.error("Please set ZEROBUS_SERVER_ENDPOINT environment variable") + return + + if TABLE_NAME == "catalog.schema.table": + logger.error("Please set ZEROBUS_TABLE_NAME environment variable") + return + + try: + # Step 1: Initialize the SDK + sdk = ZerobusSdk(SERVER_ENDPOINT, UNITY_CATALOG_ENDPOINT) + logger.info("SDK initialized") + + # Step 2: Configure arrow stream options (all optional, shown with defaults) + options = ArrowStreamConfigurationOptions( + max_inflight_batches=10, + recovery=True, + recovery_timeout_ms=15000, + recovery_backoff_ms=2000, + recovery_retries=3, + ) + logger.info("Arrow stream configuration created") + + # Step 3: Create an Arrow Flight stream + # + # Pass a pyarrow.Schema - the SDK handles serialization internally. 
+ # The SDK automatically: + # - Includes authorization header with OAuth token + # - Includes x-databricks-zerobus-table-name header + stream = await sdk.create_arrow_stream( + TABLE_NAME, SCHEMA, CLIENT_ID, CLIENT_SECRET, options + ) + logger.info(f"Arrow stream created for table: {stream.table_name}") + + # Step 4: Ingest Arrow RecordBatches asynchronously + logger.info(f"\nIngesting {NUM_BATCHES} batches of {ROWS_PER_BATCH} rows each...") + start_time = time.time() + total_rows = 0 + + try: + # ======================================================================== + # Ingest RecordBatches - each call returns an offset + # ======================================================================== + offsets = [] + for i in range(NUM_BATCHES): + batch = create_sample_batch(i) + offset = await stream.ingest_batch(batch) + offsets.append(offset) + total_rows += batch.num_rows + logger.info(f" Batch {i + 1}: {batch.num_rows} rows, offset: {offset}") + + # ======================================================================== + # You can also ingest a pyarrow.Table directly + # The SDK converts it to a single RecordBatch internally + # ======================================================================== + table = pa.table( + { + "device_name": [f"sensor-table-{i}" for i in range(50)], + "temp": list(range(20, 70)), + "humidity": list(range(50, 100)), + }, + schema=SCHEMA, + ) + offset = await stream.ingest_batch(table) + offsets.append(offset) + total_rows += table.num_rows + logger.info(f" Table ingested: {table.num_rows} rows, offset: {offset}") + + submit_end_time = time.time() + submit_duration = submit_end_time - start_time + logger.info(f"\nAll batches submitted in {submit_duration:.2f} seconds") + + # ======================================================================== + # Wait for the last offset to be acknowledged + # ======================================================================== + logger.info(f"Waiting for offset {offsets[-1]} to be 
acknowledged...") + await stream.wait_for_offset(offsets[-1]) + logger.info(f" Offset {offsets[-1]} acknowledged") + + # Step 5: Flush and close the stream + logger.info("\nFlushing stream...") + await stream.flush() + logger.info("Stream flushed") + + end_time = time.time() + total_duration = end_time - start_time + rows_per_second = total_rows / total_duration + + await stream.close() + logger.info("Stream closed") + + # Print summary + print("\n" + "=" * 60) + print("Ingestion Summary:") + print(f" Total batches: {NUM_BATCHES + 1}") + print(f" Total rows: {total_rows}") + print(f" Submit time: {submit_duration:.2f} seconds") + print(f" Total time: {total_duration:.2f} seconds") + print(f" Throughput: {rows_per_second:.2f} rows/sec") + print(f" Record type: Arrow Flight (RecordBatch)") + print("=" * 60) + + except Exception as e: + logger.error(f"\nError during ingestion: {e}") + + # On failure, you can retrieve unacked batches for retry + if stream.is_closed: + unacked = await stream.get_unacked_batches() + if unacked: + logger.info(f" {len(unacked)} unacked batches available for retry") + for i, batch in enumerate(unacked): + logger.info(f" Batch {i}: {batch.num_rows} rows, schema: {batch.schema}") + + await stream.close() + raise + + except Exception as e: + logger.error(f"\nFailed to initialize stream: {e}") + raise + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/examples/sync_example_arrow.py b/python/examples/sync_example_arrow.py new file mode 100644 index 0000000..b066c8d --- /dev/null +++ b/python/examples/sync_example_arrow.py @@ -0,0 +1,191 @@ +""" +Synchronous Ingestion Example - Arrow Flight Mode + +This example demonstrates record ingestion using the synchronous API with Arrow Flight. 
+ +Record Type Mode: Arrow (RecordBatch) + - Records are sent as pyarrow RecordBatches + - Uses Arrow Flight protocol for columnar data transfer + - Best for structured/columnar data, DataFrames, Parquet workflows + +Requirements: + pip install databricks-zerobus-ingest-sdk[arrow] + +Note: Arrow Flight support is experimental and not yet supported for production use. +""" + +import logging +import os +import time + +import pyarrow as pa + +from zerobus.sdk.shared.arrow import ArrowStreamConfigurationOptions +from zerobus.sdk.sync import ZerobusSdk + +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +# Configuration - update these with your values +SERVER_ENDPOINT = os.getenv("ZEROBUS_SERVER_ENDPOINT", "https://your-shard-id.zerobus.region.cloud.databricks.com") +UNITY_CATALOG_ENDPOINT = os.getenv("DATABRICKS_WORKSPACE_URL", "https://your-workspace.cloud.databricks.com") +TABLE_NAME = os.getenv("ZEROBUS_TABLE_NAME", "catalog.schema.table") + +# For OAuth authentication +CLIENT_ID = os.getenv("DATABRICKS_CLIENT_ID", "your-oauth-client-id") +CLIENT_SECRET = os.getenv("DATABRICKS_CLIENT_SECRET", "your-oauth-client-secret") + +# Number of batches to ingest +NUM_BATCHES = 10 +ROWS_PER_BATCH = 100 + +# Define the Arrow schema +SCHEMA = pa.schema([ + ("device_name", pa.large_utf8()), + ("temp", pa.int32()), + ("humidity", pa.int64()), +]) + + +def create_sample_batch(batch_index): + """ + Creates a sample RecordBatch with air quality data. + + Returns a pyarrow.RecordBatch with ROWS_PER_BATCH rows. 
+ """ + return pa.record_batch( + { + "device_name": [f"sensor-{(batch_index * ROWS_PER_BATCH + i) % 10}" for i in range(ROWS_PER_BATCH)], + "temp": [20 + ((batch_index * ROWS_PER_BATCH + i) % 15) for i in range(ROWS_PER_BATCH)], + "humidity": [50 + ((batch_index * ROWS_PER_BATCH + i) % 40) for i in range(ROWS_PER_BATCH)], + }, + schema=SCHEMA, + ) + + +def main(): + print("Starting synchronous ingestion example (Arrow Flight Mode)...") + print("=" * 60) + + # Check if credentials are configured + if CLIENT_ID == "your-oauth-client-id" or CLIENT_SECRET == "your-oauth-client-secret": + logger.error("Please set DATABRICKS_CLIENT_ID and DATABRICKS_CLIENT_SECRET environment variables") + return + + if SERVER_ENDPOINT == "https://your-shard-id.zerobus.region.cloud.databricks.com": + logger.error("Please set ZEROBUS_SERVER_ENDPOINT environment variable") + return + + if TABLE_NAME == "catalog.schema.table": + logger.error("Please set ZEROBUS_TABLE_NAME environment variable") + return + + try: + # Step 1: Initialize the SDK + sdk = ZerobusSdk(SERVER_ENDPOINT, UNITY_CATALOG_ENDPOINT) + logger.info("SDK initialized") + + # Step 2: Configure arrow stream options (all optional, shown with defaults) + options = ArrowStreamConfigurationOptions( + max_inflight_batches=10, + recovery=True, + recovery_timeout_ms=15000, + recovery_backoff_ms=2000, + recovery_retries=3, + ) + logger.info("Arrow stream configuration created") + + # Step 3: Create an Arrow Flight stream + # + # Pass a pyarrow.Schema - the SDK handles serialization internally. 
+ # The SDK automatically: + # - Includes authorization header with OAuth token + # - Includes x-databricks-zerobus-table-name header + stream = sdk.create_arrow_stream( + TABLE_NAME, SCHEMA, CLIENT_ID, CLIENT_SECRET, options + ) + logger.info(f"Arrow stream created for table: {stream.table_name}") + + # Step 4: Ingest Arrow RecordBatches + logger.info(f"\nIngesting {NUM_BATCHES} batches of {ROWS_PER_BATCH} rows each...") + start_time = time.time() + total_rows = 0 + + try: + # ======================================================================== + # Ingest RecordBatches - each call returns an offset + # ======================================================================== + for i in range(NUM_BATCHES): + batch = create_sample_batch(i) + offset = stream.ingest_batch(batch) + total_rows += batch.num_rows + logger.info(f" Batch {i + 1}: {batch.num_rows} rows, offset: {offset}") + + # ======================================================================== + # You can also ingest a pyarrow.Table directly + # The SDK converts it to a single RecordBatch internally + # ======================================================================== + table = pa.table( + { + "device_name": [f"sensor-table-{i}" for i in range(50)], + "temp": list(range(20, 70)), + "humidity": list(range(50, 100)), + }, + schema=SCHEMA, + ) + offset = stream.ingest_batch(table) + total_rows += table.num_rows + logger.info(f" Table ingested: {table.num_rows} rows, offset: {offset}") + + # ======================================================================== + # Wait for a specific offset to be acknowledged + # ======================================================================== + logger.info(f"\nWaiting for offset {offset} to be acknowledged...") + stream.wait_for_offset(offset) + logger.info(f" Offset {offset} acknowledged") + + end_time = time.time() + duration_seconds = end_time - start_time + rows_per_second = total_rows / duration_seconds + + # Step 5: Flush and close the stream + 
logger.info("\nFlushing stream...") + stream.flush() + logger.info("Stream flushed") + + stream.close() + logger.info("Stream closed") + + # Print summary + print("\n" + "=" * 60) + print("Ingestion Summary:") + print(f" Total batches: {NUM_BATCHES + 1}") + print(f" Total rows: {total_rows}") + print(f" Duration: {duration_seconds:.2f} seconds") + print(f" Throughput: {rows_per_second:.2f} rows/sec") + print(f" Record type: Arrow Flight (RecordBatch)") + print("=" * 60) + + except Exception as e: + logger.error(f"\nError during ingestion: {e}") + + # On failure, you can retrieve unacked batches for retry + if stream.is_closed: + unacked = stream.get_unacked_batches() + if unacked: + logger.info(f" {len(unacked)} unacked batches available for retry") + for i, batch in enumerate(unacked): + logger.info(f" Batch {i}: {batch.num_rows} rows, schema: {batch.schema}") + + stream.close() + raise + + except Exception as e: + logger.error(f"\nFailed to initialize stream: {e}") + raise + + +if __name__ == "__main__": + main() diff --git a/python/pyproject.toml b/python/pyproject.toml index 4dac31d..cc547a1 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -28,6 +28,9 @@ dependencies = [ ] [project.optional-dependencies] +arrow = [ + "pyarrow>=14.0.0", +] dev = [ # Build tools "wheel", @@ -43,6 +46,7 @@ dev = [ "pytest-cov", "pytest-xdist>=3.6.1,<4.0", "pytest-asyncio", + "pyarrow>=14.0.0", ] [tool.maturin] diff --git a/python/rust/Cargo.toml b/python/rust/Cargo.toml index 9594fa6..871a79f 100644 --- a/python/rust/Cargo.toml +++ b/python/rust/Cargo.toml @@ -11,9 +11,12 @@ crate-type = ["cdylib"] pyo3 = { version = "0.20", features = ["extension-module", "abi3-py39"] } pyo3-asyncio = { version = "0.20", features = ["tokio-runtime"] } tokio = { version = "1", features = ["full"] } -databricks-zerobus-ingest-sdk = "1.0.0" +databricks-zerobus-ingest-sdk = { version = "1.0.1", features = ["arrow-flight"] } tracing-subscriber = { version = "0.3", features = 
["env-filter"] } async-trait = "0.1" prost = "0.13" prost-types = "0.13" tonic = "0.13" +arrow-ipc = { version = "56.2.0", default-features = false } +arrow-schema = { version = "56.2.0", default-features = false } +arrow-array = { version = "56.2.0", default-features = false } diff --git a/python/rust/src/arrow.rs b/python/rust/src/arrow.rs new file mode 100644 index 0000000..8fb1e48 --- /dev/null +++ b/python/rust/src/arrow.rs @@ -0,0 +1,687 @@ +//! PyO3 bindings for Arrow Flight stream support. +//! +//! These types are always compiled into the wheel, but the Python-side API +//! gates usage on `pyarrow` being installed at runtime. + +use std::sync::Arc; + +use pyo3::prelude::*; +use pyo3::types::PyBytes; +use tokio::sync::RwLock; + +use databricks_zerobus_ingest_sdk::{ + ArrowStreamConfigurationOptions as RustArrowStreamOptions, + ArrowTableProperties as RustArrowTableProperties, + ZerobusArrowStream as RustZerobusArrowStream, ZerobusError as RustError, + ZerobusSdk as RustSdk, +}; + +use crate::auth::HeadersProviderWrapper; +use crate::common::map_error; + +fn map_rust_error_to_pyerr(err: RustError) -> PyErr { + map_error(err) +} + +/// Deserialize Arrow IPC bytes into a RecordBatch. +fn ipc_bytes_to_record_batch( + ipc_bytes: &[u8], +) -> Result { + let reader = + arrow_ipc::reader::StreamReader::try_new(ipc_bytes, None).map_err(|e| { + RustError::InvalidArgument(format!("Failed to parse Arrow IPC data: {}", e)) + })?; + + let mut batches = Vec::new(); + for batch_result in reader { + let batch = batch_result.map_err(|e| { + RustError::InvalidArgument(format!("Failed to read Arrow batch: {}", e)) + })?; + batches.push(batch); + } + + if batches.is_empty() { + return Err(RustError::InvalidArgument( + "No batches found in Arrow IPC data".to_string(), + )); + } + + Ok(batches.into_iter().next().unwrap()) +} + +/// Serialize a RecordBatch to Arrow IPC bytes. 
+fn record_batch_to_ipc_bytes( + batch: &arrow_array::RecordBatch, +) -> Result, RustError> { + let mut buffer = Vec::new(); + { + let mut writer = + arrow_ipc::writer::StreamWriter::try_new(&mut buffer, &batch.schema()) + .map_err(|e| { + RustError::InvalidArgument(format!( + "Failed to create Arrow IPC writer: {}", + e + )) + })?; + writer.write(batch).map_err(|e| { + RustError::InvalidArgument(format!("Failed to write Arrow batch: {}", e)) + })?; + writer.finish().map_err(|e| { + RustError::InvalidArgument(format!( + "Failed to finish Arrow IPC stream: {}", + e + )) + })?; + } + Ok(buffer) +} + +/// Build an ArrowSchema from IPC-serialized schema bytes. +/// +/// Python side calls `schema.serialize().to_pybytes()` on a `pyarrow.Schema` +/// to produce the IPC stream bytes. The schema is the first message in the stream. +fn ipc_schema_bytes_to_arrow_schema( + schema_bytes: &[u8], +) -> Result { + let reader = arrow_ipc::reader::StreamReader::try_new(schema_bytes, None) + .map_err(|e| { + RustError::InvalidArgument(format!( + "Failed to parse Arrow schema bytes: {}. \ + Pass schema bytes obtained from pyarrow.Schema via \ + schema.serialize().to_pybytes()", + e + )) + })?; + Ok(reader.schema().as_ref().clone()) +} + +// ============================================================================= +// ARROW STREAM CONFIGURATION OPTIONS +// ============================================================================= + +/// Configuration options for Arrow Flight streams. 
+#[pyclass] +#[derive(Clone)] +pub struct ArrowStreamConfigurationOptions { + #[pyo3(get, set)] + pub max_inflight_batches: i32, + + #[pyo3(get, set)] + pub recovery: bool, + + #[pyo3(get, set)] + pub recovery_timeout_ms: i64, + + #[pyo3(get, set)] + pub recovery_backoff_ms: i64, + + #[pyo3(get, set)] + pub recovery_retries: i32, + + #[pyo3(get, set)] + pub server_lack_of_ack_timeout_ms: i64, + + #[pyo3(get, set)] + pub flush_timeout_ms: i64, + + #[pyo3(get, set)] + pub connection_timeout_ms: i64, +} + +impl Default for ArrowStreamConfigurationOptions { + fn default() -> Self { + let rust_default = RustArrowStreamOptions::default(); + Self { + max_inflight_batches: rust_default.max_inflight_batches as i32, + recovery: rust_default.recovery, + recovery_timeout_ms: rust_default.recovery_timeout_ms as i64, + recovery_backoff_ms: rust_default.recovery_backoff_ms as i64, + recovery_retries: rust_default.recovery_retries as i32, + server_lack_of_ack_timeout_ms: rust_default.server_lack_of_ack_timeout_ms + as i64, + flush_timeout_ms: rust_default.flush_timeout_ms as i64, + connection_timeout_ms: rust_default.connection_timeout_ms as i64, + } + } +} + +#[pymethods] +impl ArrowStreamConfigurationOptions { + #[new] + #[pyo3(signature = (**kwargs))] + fn new(kwargs: Option<&pyo3::types::PyDict>) -> PyResult { + let mut options = Self::default(); + + if let Some(kwargs) = kwargs { + for (key, value) in kwargs { + let key_str: &str = key.extract()?; + match key_str { + "max_inflight_batches" => { + options.max_inflight_batches = value.extract()? + } + "recovery" => options.recovery = value.extract()?, + "recovery_timeout_ms" => { + options.recovery_timeout_ms = value.extract()? + } + "recovery_backoff_ms" => { + options.recovery_backoff_ms = value.extract()? + } + "recovery_retries" => options.recovery_retries = value.extract()?, + "server_lack_of_ack_timeout_ms" => { + options.server_lack_of_ack_timeout_ms = value.extract()? 
+ } + "flush_timeout_ms" => options.flush_timeout_ms = value.extract()?, + "connection_timeout_ms" => { + options.connection_timeout_ms = value.extract()? + } + _ => { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "Unknown configuration option: {}", + key_str + ))); + } + } + } + } + + Ok(options) + } + + fn __repr__(&self) -> String { + format!( + "ArrowStreamConfigurationOptions(max_inflight_batches={}, recovery={}, \ + recovery_timeout_ms={}, recovery_backoff_ms={}, recovery_retries={}, \ + server_lack_of_ack_timeout_ms={}, flush_timeout_ms={}, connection_timeout_ms={})", + self.max_inflight_batches, + self.recovery, + self.recovery_timeout_ms, + self.recovery_backoff_ms, + self.recovery_retries, + self.server_lack_of_ack_timeout_ms, + self.flush_timeout_ms, + self.connection_timeout_ms, + ) + } +} + +impl ArrowStreamConfigurationOptions { + pub fn to_rust(&self) -> RustArrowStreamOptions { + RustArrowStreamOptions { + max_inflight_batches: self.max_inflight_batches as usize, + recovery: self.recovery, + recovery_timeout_ms: self.recovery_timeout_ms as u64, + recovery_backoff_ms: self.recovery_backoff_ms as u64, + recovery_retries: self.recovery_retries as u32, + server_lack_of_ack_timeout_ms: self.server_lack_of_ack_timeout_ms as u64, + flush_timeout_ms: self.flush_timeout_ms as u64, + connection_timeout_ms: self.connection_timeout_ms as u64, + ipc_compression: None, + } + } +} + +// ============================================================================= +// SYNC ARROW STREAM +// ============================================================================= + +/// Synchronous Arrow Flight stream for ingesting pyarrow RecordBatches. +#[pyclass] +pub struct ZerobusArrowStream { + inner: Arc>, + runtime: Arc, +} + +#[pymethods] +impl ZerobusArrowStream { + /// Ingest a single Arrow RecordBatch (as IPC bytes) and return the offset. 
+ /// + /// Args: + /// ipc_bytes: Arrow IPC serialized bytes from pyarrow.RecordBatch.serialize() + fn ingest_batch(&self, py: Python, ipc_bytes: &PyBytes) -> PyResult { + let batch = ipc_bytes_to_record_batch(ipc_bytes.as_bytes()) + .map_err(|e| map_rust_error_to_pyerr(e))?; + + let stream_clone = self.inner.clone(); + let runtime = self.runtime.clone(); + + py.allow_threads(|| { + runtime.block_on(async move { + let stream_guard = stream_clone.read().await; + stream_guard + .ingest_batch(batch) + .await + .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e))) + }) + }) + } + + /// Wait for a specific offset to be acknowledged. + fn wait_for_offset(&self, py: Python, offset: i64) -> PyResult<()> { + let stream_clone = self.inner.clone(); + let runtime = self.runtime.clone(); + + py.allow_threads(|| { + runtime.block_on(async move { + let stream_guard = stream_clone.read().await; + stream_guard + .wait_for_offset(offset) + .await + .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e))) + }) + }) + } + + /// Flush all pending batches, waiting for acknowledgment. + fn flush(&self, py: Python) -> PyResult<()> { + let stream_clone = self.inner.clone(); + let runtime = self.runtime.clone(); + + py.allow_threads(|| { + runtime.block_on(async move { + let stream_guard = stream_clone.read().await; + stream_guard + .flush() + .await + .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e))) + }) + }) + } + + /// Close the stream gracefully. + fn close(&self, py: Python) -> PyResult<()> { + let stream_clone = self.inner.clone(); + let runtime = self.runtime.clone(); + + py.allow_threads(|| { + runtime.block_on(async move { + let mut stream_guard = stream_clone.write().await; + stream_guard + .close() + .await + .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e))) + }) + }) + } + + /// Check if the stream has been closed. 
+ #[getter] + fn is_closed(&self) -> bool { + let stream_clone = self.inner.clone(); + self.runtime.block_on(async move { + let stream_guard = stream_clone.read().await; + stream_guard.is_closed() + }) + } + + /// Get the table name. + #[getter] + fn table_name(&self) -> String { + let stream_clone = self.inner.clone(); + self.runtime.block_on(async move { + let stream_guard = stream_clone.read().await; + stream_guard.table_name().to_string() + }) + } + + /// Get unacknowledged batches as a list of Arrow IPC byte buffers. + fn get_unacked_batches(&self, py: Python) -> PyResult> { + let stream_clone = self.inner.clone(); + let runtime = self.runtime.clone(); + + py.allow_threads(|| { + runtime.block_on(async move { + let stream_guard = stream_clone.read().await; + let batches = stream_guard + .get_unacked_batches() + .await + .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e)))?; + + Python::with_gil(|py| { + let mut py_batches: Vec = Vec::with_capacity(batches.len()); + for batch in &batches { + let ipc_bytes = record_batch_to_ipc_bytes(batch) + .map_err(|e| map_rust_error_to_pyerr(e))?; + py_batches + .push(PyBytes::new(py, &ipc_bytes).into()); + } + Ok(py_batches) + }) + }) + }) + } +} + +// ============================================================================= +// ASYNC ARROW STREAM +// ============================================================================= + +/// Asynchronous Arrow Flight stream for ingesting pyarrow RecordBatches. +#[pyclass(name = "AsyncZerobusArrowStream")] +pub struct AsyncZerobusArrowStream { + inner: Arc>, +} + +#[pymethods] +impl AsyncZerobusArrowStream { + /// Ingest a single Arrow RecordBatch (as IPC bytes) and return the offset. 
+ fn ingest_batch<'py>( + &self, + py: Python<'py>, + ipc_bytes: &PyBytes, + ) -> PyResult<&'py PyAny> { + let batch = ipc_bytes_to_record_batch(ipc_bytes.as_bytes()) + .map_err(|e| map_rust_error_to_pyerr(e))?; + + let stream_clone = self.inner.clone(); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let stream_guard = stream_clone.read().await; + stream_guard + .ingest_batch(batch) + .await + .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e))) + }) + } + + /// Wait for a specific offset to be acknowledged. + fn wait_for_offset<'py>( + &self, + py: Python<'py>, + offset: i64, + ) -> PyResult<&'py PyAny> { + let stream_clone = self.inner.clone(); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let stream_guard = stream_clone.read().await; + stream_guard + .wait_for_offset(offset) + .await + .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e)))?; + Ok(()) + }) + } + + /// Flush all pending batches. + fn flush<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> { + let stream_clone = self.inner.clone(); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let stream_guard = stream_clone.read().await; + stream_guard + .flush() + .await + .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e)))?; + Ok(()) + }) + } + + /// Close the stream gracefully. + fn close<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> { + let stream_clone = self.inner.clone(); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let mut stream_guard = stream_clone.write().await; + stream_guard + .close() + .await + .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e)))?; + Ok(()) + }) + } + + /// Check if the stream has been closed. + #[getter] + fn is_closed(&self) -> bool { + // This needs a runtime to lock, but is_closed is atomic so + // we can use try_read or spawn. 
+ let stream_clone = self.inner.clone(); + pyo3_asyncio::tokio::get_runtime().block_on(async move { + let stream_guard = stream_clone.read().await; + stream_guard.is_closed() + }) + } + + /// Get the table name. + #[getter] + fn table_name(&self) -> String { + let stream_clone = self.inner.clone(); + pyo3_asyncio::tokio::get_runtime().block_on(async move { + let stream_guard = stream_clone.read().await; + stream_guard.table_name().to_string() + }) + } + + /// Get unacknowledged batches as a list of Arrow IPC byte buffers. + fn get_unacked_batches<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> { + let stream_clone = self.inner.clone(); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let stream_guard = stream_clone.read().await; + let batches = stream_guard + .get_unacked_batches() + .await + .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e)))?; + + Python::with_gil(|py| { + let mut py_batches: Vec = Vec::with_capacity(batches.len()); + for batch in &batches { + let ipc_bytes = record_batch_to_ipc_bytes(batch) + .map_err(|e| map_rust_error_to_pyerr(e))?; + py_batches.push(PyBytes::new(py, &ipc_bytes).into()); + } + Ok(py_batches) + }) + }) + } +} + +// ============================================================================= +// SDK METHODS — called from sync_wrapper and async_wrapper +// ============================================================================= + +/// Create an Arrow stream (sync helper). 
+pub fn create_arrow_stream_sync( + sdk: &Arc>, + runtime: &Arc, + py: Python, + table_name: String, + schema_ipc_bytes: &[u8], + client_id: String, + client_secret: String, + options: Option<&ArrowStreamConfigurationOptions>, +) -> PyResult { + let schema = ipc_schema_bytes_to_arrow_schema(schema_ipc_bytes) + .map_err(|e| map_rust_error_to_pyerr(e))?; + + let table_props = RustArrowTableProperties { + table_name, + schema: Arc::new(schema), + }; + + let rust_options = options.map(|o| o.to_rust()); + + let sdk_clone = sdk.clone(); + let runtime_clone = runtime.clone(); + + let stream = py.allow_threads(|| { + runtime_clone.block_on(async move { + let sdk_guard = sdk_clone.read().await; + sdk_guard + .create_arrow_stream(table_props, client_id, client_secret, rust_options) + .await + .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e))) + }) + })?; + + Ok(ZerobusArrowStream { + inner: Arc::new(RwLock::new(stream)), + runtime: runtime.clone(), + }) +} + +/// Create an Arrow stream with headers provider (sync helper). 
+pub fn create_arrow_stream_with_headers_provider_sync( + sdk: &Arc>, + runtime: &Arc, + py: Python, + table_name: String, + schema_ipc_bytes: &[u8], + headers_provider: PyObject, + options: Option<&ArrowStreamConfigurationOptions>, +) -> PyResult { + let schema = ipc_schema_bytes_to_arrow_schema(schema_ipc_bytes) + .map_err(|e| map_rust_error_to_pyerr(e))?; + + let table_props = RustArrowTableProperties { + table_name, + schema: Arc::new(schema), + }; + + let rust_options = options.map(|o| o.to_rust()); + let provider = Arc::new(HeadersProviderWrapper::new(headers_provider)); + + let sdk_clone = sdk.clone(); + let runtime_clone = runtime.clone(); + + let stream = py.allow_threads(|| { + runtime_clone.block_on(async move { + let sdk_guard = sdk_clone.read().await; + sdk_guard + .create_arrow_stream_with_headers_provider( + table_props, + provider, + rust_options, + ) + .await + .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e))) + }) + })?; + + Ok(ZerobusArrowStream { + inner: Arc::new(RwLock::new(stream)), + runtime: runtime.clone(), + }) +} + +/// Recreate an Arrow stream from a closed stream (sync helper). +pub fn recreate_arrow_stream_sync( + sdk: &Arc>, + runtime: &Arc, + py: Python, + old_stream: &ZerobusArrowStream, +) -> PyResult { + let sdk_clone = sdk.clone(); + let old_inner = old_stream.inner.clone(); + let runtime_clone = runtime.clone(); + + let stream = py.allow_threads(|| { + runtime_clone.block_on(async move { + let old_guard = old_inner.read().await; + let sdk_guard = sdk_clone.read().await; + sdk_guard + .recreate_arrow_stream(&*old_guard) + .await + .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e))) + }) + })?; + + Ok(ZerobusArrowStream { + inner: Arc::new(RwLock::new(stream)), + runtime: runtime.clone(), + }) +} + +/// Create an Arrow stream (async helper). 
+pub fn create_arrow_stream_async<'py>( + sdk: &Arc>, + py: Python<'py>, + table_name: String, + schema_ipc_bytes: Vec, + client_id: String, + client_secret: String, + options: Option, +) -> PyResult<&'py PyAny> { + let schema = ipc_schema_bytes_to_arrow_schema(&schema_ipc_bytes) + .map_err(|e| map_rust_error_to_pyerr(e))?; + + let table_props = RustArrowTableProperties { + table_name, + schema: Arc::new(schema), + }; + + let rust_options = options.map(|o| o.to_rust()); + let sdk_clone = sdk.clone(); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let sdk_guard = sdk_clone.read().await; + let stream = sdk_guard + .create_arrow_stream(table_props, client_id, client_secret, rust_options) + .await + .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e)))?; + + Ok(AsyncZerobusArrowStream { + inner: Arc::new(RwLock::new(stream)), + }) + }) +} + +/// Create an Arrow stream with headers provider (async helper). +pub fn create_arrow_stream_with_headers_provider_async<'py>( + sdk: &Arc>, + py: Python<'py>, + table_name: String, + schema_ipc_bytes: Vec, + headers_provider: PyObject, + options: Option, +) -> PyResult<&'py PyAny> { + let schema = ipc_schema_bytes_to_arrow_schema(&schema_ipc_bytes) + .map_err(|e| map_rust_error_to_pyerr(e))?; + + let table_props = RustArrowTableProperties { + table_name, + schema: Arc::new(schema), + }; + + let rust_options = options.map(|o| o.to_rust()); + let provider = Arc::new(HeadersProviderWrapper::new(headers_provider)); + let sdk_clone = sdk.clone(); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let sdk_guard = sdk_clone.read().await; + let stream = sdk_guard + .create_arrow_stream_with_headers_provider( + table_props, + provider, + rust_options, + ) + .await + .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e)))?; + + Ok(AsyncZerobusArrowStream { + inner: Arc::new(RwLock::new(stream)), + }) + }) +} + +/// Recreate an Arrow stream from a closed stream (async helper). 
+pub fn recreate_arrow_stream_async<'py>( + sdk: &Arc>, + py: Python<'py>, + old_stream: &AsyncZerobusArrowStream, +) -> PyResult<&'py PyAny> { + let sdk_clone = sdk.clone(); + let old_inner = old_stream.inner.clone(); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let old_guard = old_inner.read().await; + let sdk_guard = sdk_clone.read().await; + let stream = sdk_guard + .recreate_arrow_stream(&*old_guard) + .await + .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e)))?; + + Ok(AsyncZerobusArrowStream { + inner: Arc::new(RwLock::new(stream)), + }) + }) +} diff --git a/python/rust/src/async_wrapper.rs b/python/rust/src/async_wrapper.rs index d4aa79f..543aa04 100644 --- a/python/rust/src/async_wrapper.rs +++ b/python/rust/src/async_wrapper.rs @@ -14,6 +14,7 @@ use databricks_zerobus_ingest_sdk::{ ZerobusStream as RustStream, }; +use crate::arrow::{self, ArrowStreamConfigurationOptions, AsyncZerobusArrowStream}; use crate::auth::HeadersProviderWrapper; use crate::common::{map_error, AckCallback, StreamConfigurationOptions, TableProperties}; @@ -469,6 +470,57 @@ impl ZerobusSdk { }) } + /// Create a new Arrow Flight stream with OAuth authentication (async). + #[pyo3(signature = (table_name, schema_ipc_bytes, client_id, client_secret, options = None))] + fn create_arrow_stream<'py>( + &self, + py: Python<'py>, + table_name: String, + schema_ipc_bytes: Vec, + client_id: String, + client_secret: String, + options: Option, + ) -> PyResult<&'py PyAny> { + arrow::create_arrow_stream_async( + &self.inner, + py, + table_name, + schema_ipc_bytes, + client_id, + client_secret, + options, + ) + } + + /// Create a new Arrow Flight stream with custom headers provider (async). 
+ #[pyo3(signature = (table_name, schema_ipc_bytes, headers_provider, options = None))] + fn create_arrow_stream_with_headers_provider<'py>( + &self, + py: Python<'py>, + table_name: String, + schema_ipc_bytes: Vec, + headers_provider: PyObject, + options: Option, + ) -> PyResult<&'py PyAny> { + arrow::create_arrow_stream_with_headers_provider_async( + &self.inner, + py, + table_name, + schema_ipc_bytes, + headers_provider, + options, + ) + } + + /// Recreate a closed Arrow stream (async). + fn recreate_arrow_stream<'py>( + &self, + py: Python<'py>, + old_stream: &AsyncZerobusArrowStream, + ) -> PyResult<&'py PyAny> { + arrow::recreate_arrow_stream_async(&self.inner, py, old_stream) + } + /// Recreate a stream from an old stream (async) fn recreate_stream<'py>( &self, diff --git a/python/rust/src/lib.rs b/python/rust/src/lib.rs index 287347e..46a1acc 100644 --- a/python/rust/src/lib.rs +++ b/python/rust/src/lib.rs @@ -2,6 +2,7 @@ use pyo3::prelude::*; +mod arrow; mod async_wrapper; mod auth; mod common; @@ -43,6 +44,13 @@ fn _zerobus_core(py: Python, m: &PyModule) -> PyResult<()> { // Add authentication classes m.add_class::()?; + // Add arrow submodule + let arrow_module = PyModule::new(py, "arrow")?; + arrow_module.add_class::()?; + arrow_module.add_class::()?; + arrow_module.add_class::()?; + m.add_submodule(arrow_module)?; + // Add sync submodule let sync_module = PyModule::new(py, "sync")?; sync_module.add_class::()?; diff --git a/python/rust/src/sync_wrapper.rs b/python/rust/src/sync_wrapper.rs index 2491aee..78c33cc 100644 --- a/python/rust/src/sync_wrapper.rs +++ b/python/rust/src/sync_wrapper.rs @@ -14,6 +14,7 @@ use databricks_zerobus_ingest_sdk::{ ZerobusSdk as RustSdk, ZerobusStream as RustStream, }; +use crate::arrow::{self, ArrowStreamConfigurationOptions, ZerobusArrowStream}; use crate::auth::HeadersProviderWrapper; use crate::common::{map_error, AckCallback, StreamConfigurationOptions, TableProperties}; @@ -567,6 +568,66 @@ impl ZerobusSdk { }) } + 
/// Create a new Arrow Flight stream with OAuth authentication. + /// + /// Args: + /// table_name: Fully qualified table name (catalog.schema.table) + /// schema_ipc_bytes: Arrow IPC serialized schema bytes + /// client_id: OAuth client ID + /// client_secret: OAuth client secret + /// options: Optional ArrowStreamConfigurationOptions + #[pyo3(signature = (table_name, schema_ipc_bytes, client_id, client_secret, options = None))] + fn create_arrow_stream( + &self, + py: Python, + table_name: String, + schema_ipc_bytes: &PyBytes, + client_id: String, + client_secret: String, + options: Option<&ArrowStreamConfigurationOptions>, + ) -> PyResult { + arrow::create_arrow_stream_sync( + &self.inner, + &self.runtime, + py, + table_name, + schema_ipc_bytes.as_bytes(), + client_id, + client_secret, + options, + ) + } + + /// Create a new Arrow Flight stream with custom headers provider. + #[pyo3(signature = (table_name, schema_ipc_bytes, headers_provider, options = None))] + fn create_arrow_stream_with_headers_provider( + &self, + py: Python, + table_name: String, + schema_ipc_bytes: &PyBytes, + headers_provider: PyObject, + options: Option<&ArrowStreamConfigurationOptions>, + ) -> PyResult { + arrow::create_arrow_stream_with_headers_provider_sync( + &self.inner, + &self.runtime, + py, + table_name, + schema_ipc_bytes.as_bytes(), + headers_provider, + options, + ) + } + + /// Recreate a closed Arrow stream with the same configuration. 
+    fn recreate_arrow_stream(
+        &self,
+        py: Python,
+        old_stream: &ZerobusArrowStream,
+    ) -> PyResult<ZerobusArrowStream> {
+        arrow::recreate_arrow_stream_sync(&self.inner, &self.runtime, py, old_stream)
+    }
+
     /// Recreate a closed stream with the same configuration
     fn recreate_stream(&self, py: Python, old_stream: &ZerobusStream) -> PyResult<ZerobusStream> {
         let sdk = self.inner.clone();
diff --git a/python/tests/test_arrow.py b/python/tests/test_arrow.py
new file mode 100644
index 0000000..13c40c7
--- /dev/null
+++ b/python/tests/test_arrow.py
@@ -0,0 +1,309 @@
+"""
+Tests for Arrow Flight support in the Python SDK.
+
+These tests verify the Python-side Arrow serialization/deserialization helpers,
+the ArrowStreamConfigurationOptions pyclass, and the API surface of Arrow stream
+classes, without making network connections.
+"""
+
+import unittest
+
+import pyarrow as pa
+
+from zerobus.sdk.shared.arrow import (
+    ArrowStreamConfigurationOptions,
+    _check_pyarrow,
+    _deserialize_batch,
+    _serialize_batch,
+    _serialize_schema,
+)
+
+
+class TestCheckPyarrow(unittest.TestCase):
+    """Test the pyarrow availability check."""
+
+    def test_check_pyarrow_returns_module(self):
+        result = _check_pyarrow()
+        self.assertIs(result, pa)
+
+
+class TestSerializeSchema(unittest.TestCase):
+    """Test schema serialization to IPC bytes."""
+
+    def test_simple_schema(self):
+        schema = pa.schema([("a", pa.int64()), ("b", pa.utf8())])
+        ipc_bytes = _serialize_schema(schema)
+        self.assertIsInstance(ipc_bytes, bytes)
+        self.assertGreater(len(ipc_bytes), 0)
+
+    def test_schema_with_various_types(self):
+        schema = pa.schema([
+            ("int_col", pa.int32()),
+            ("float_col", pa.float64()),
+            ("str_col", pa.large_utf8()),
+            ("bool_col", pa.bool_()),
+            ("list_col", pa.list_(pa.int64())),
+        ])
+        ipc_bytes = _serialize_schema(schema)
+        self.assertIsInstance(ipc_bytes, bytes)
+
+    def test_roundtrip_schema_via_ipc(self):
+        schema = pa.schema([("x", pa.int64()), ("y", pa.float32())])
+        ipc_bytes = _serialize_schema(schema)
+        reader = 
pa.ipc.open_stream(ipc_bytes) + recovered = reader.schema + self.assertEqual(schema, recovered) + + def test_rejects_non_schema(self): + with self.assertRaises(TypeError): + _serialize_schema("not a schema") + + def test_rejects_record_batch(self): + batch = pa.record_batch({"a": [1]}, schema=pa.schema([("a", pa.int64())])) + with self.assertRaises(TypeError): + _serialize_schema(batch) + + +class TestSerializeBatch(unittest.TestCase): + """Test RecordBatch/Table serialization to IPC bytes.""" + + def test_simple_batch(self): + schema = pa.schema([("a", pa.int64())]) + batch = pa.record_batch({"a": [1, 2, 3]}, schema=schema) + ipc_bytes = _serialize_batch(batch) + self.assertIsInstance(ipc_bytes, bytes) + self.assertGreater(len(ipc_bytes), 0) + + def test_batch_with_multiple_columns(self): + schema = pa.schema([("x", pa.int32()), ("y", pa.utf8()), ("z", pa.float64())]) + batch = pa.record_batch( + {"x": [1, 2], "y": ["a", "b"], "z": [1.0, 2.0]}, + schema=schema, + ) + ipc_bytes = _serialize_batch(batch) + self.assertIsInstance(ipc_bytes, bytes) + + def test_table_single_chunk(self): + schema = pa.schema([("a", pa.int64())]) + table = pa.table({"a": [1, 2, 3]}, schema=schema) + ipc_bytes = _serialize_batch(table) + self.assertIsInstance(ipc_bytes, bytes) + + def test_table_multiple_chunks(self): + schema = pa.schema([("a", pa.int64())]) + batch1 = pa.record_batch({"a": [1, 2]}, schema=schema) + batch2 = pa.record_batch({"a": [3, 4]}, schema=schema) + table = pa.Table.from_batches([batch1, batch2]) + ipc_bytes = _serialize_batch(table) + # Should combine chunks into a single batch + recovered = _deserialize_batch(ipc_bytes) + self.assertEqual(recovered.num_rows, 4) + + def test_empty_table_raises(self): + schema = pa.schema([("a", pa.int64())]) + table = pa.table({"a": pa.array([], type=pa.int64())}, schema=schema) + with self.assertRaises(ValueError): + _serialize_batch(table) + + def test_rejects_wrong_type(self): + with self.assertRaises(TypeError): + 
_serialize_batch("not a batch") + + def test_rejects_dict(self): + with self.assertRaises(TypeError): + _serialize_batch({"a": [1, 2, 3]}) + + def test_rejects_schema(self): + with self.assertRaises(TypeError): + _serialize_batch(pa.schema([("a", pa.int64())])) + + +class TestDeserializeBatch(unittest.TestCase): + """Test IPC bytes deserialization back to RecordBatch.""" + + def test_roundtrip(self): + schema = pa.schema([("a", pa.int64()), ("b", pa.utf8())]) + original = pa.record_batch( + {"a": [1, 2, 3], "b": ["x", "y", "z"]}, + schema=schema, + ) + ipc_bytes = _serialize_batch(original) + recovered = _deserialize_batch(ipc_bytes) + + self.assertIsInstance(recovered, pa.RecordBatch) + self.assertEqual(recovered.schema, original.schema) + self.assertEqual(recovered.num_rows, original.num_rows) + self.assertEqual(recovered.to_pydict(), original.to_pydict()) + + def test_roundtrip_with_nulls(self): + schema = pa.schema([("a", pa.int64()), ("b", pa.utf8())]) + original = pa.record_batch( + {"a": [1, None, 3], "b": [None, "y", None]}, + schema=schema, + ) + ipc_bytes = _serialize_batch(original) + recovered = _deserialize_batch(ipc_bytes) + + self.assertEqual(recovered.to_pydict(), original.to_pydict()) + + def test_roundtrip_table(self): + schema = pa.schema([("val", pa.float64())]) + table = pa.table({"val": [1.1, 2.2, 3.3]}, schema=schema) + ipc_bytes = _serialize_batch(table) + recovered = _deserialize_batch(ipc_bytes) + + self.assertIsInstance(recovered, pa.RecordBatch) + self.assertEqual(recovered.num_rows, 3) + + def test_invalid_bytes_raises(self): + with self.assertRaises(Exception): + _deserialize_batch(b"not valid ipc data") + + +class TestArrowStreamConfigurationOptions(unittest.TestCase): + """Test ArrowStreamConfigurationOptions pyclass.""" + + def test_default_construction(self): + options = ArrowStreamConfigurationOptions() + self.assertIsInstance(options.max_inflight_batches, int) + self.assertIsInstance(options.recovery, bool) + 
self.assertIsInstance(options.recovery_timeout_ms, int) + self.assertIsInstance(options.recovery_backoff_ms, int) + self.assertIsInstance(options.recovery_retries, int) + self.assertIsInstance(options.server_lack_of_ack_timeout_ms, int) + self.assertIsInstance(options.flush_timeout_ms, int) + self.assertIsInstance(options.connection_timeout_ms, int) + + def test_kwargs_construction(self): + options = ArrowStreamConfigurationOptions( + max_inflight_batches=5, + recovery=False, + flush_timeout_ms=3000, + ) + self.assertEqual(options.max_inflight_batches, 5) + self.assertFalse(options.recovery) + self.assertEqual(options.flush_timeout_ms, 3000) + + def test_all_kwargs(self): + options = ArrowStreamConfigurationOptions( + max_inflight_batches=20, + recovery=True, + recovery_timeout_ms=10000, + recovery_backoff_ms=500, + recovery_retries=10, + server_lack_of_ack_timeout_ms=30000, + flush_timeout_ms=5000, + connection_timeout_ms=8000, + ) + self.assertEqual(options.max_inflight_batches, 20) + self.assertTrue(options.recovery) + self.assertEqual(options.recovery_timeout_ms, 10000) + self.assertEqual(options.recovery_backoff_ms, 500) + self.assertEqual(options.recovery_retries, 10) + self.assertEqual(options.server_lack_of_ack_timeout_ms, 30000) + self.assertEqual(options.flush_timeout_ms, 5000) + self.assertEqual(options.connection_timeout_ms, 8000) + + def test_unknown_kwarg_raises(self): + with self.assertRaises(ValueError): + ArrowStreamConfigurationOptions(nonexistent_option=42) + + def test_setters(self): + options = ArrowStreamConfigurationOptions() + options.max_inflight_batches = 99 + options.recovery = False + self.assertEqual(options.max_inflight_batches, 99) + self.assertFalse(options.recovery) + + def test_repr(self): + options = ArrowStreamConfigurationOptions() + repr_str = repr(options) + self.assertIn("ArrowStreamConfigurationOptions", repr_str) + self.assertIn("max_inflight_batches", repr_str) + self.assertIn("recovery", repr_str) + + +class 
TestArrowImports(unittest.TestCase): + """Test that Arrow types can be imported from expected locations.""" + + def test_import_from_shared_arrow(self): + from zerobus.sdk.shared.arrow import ArrowStreamConfigurationOptions + self.assertIsNotNone(ArrowStreamConfigurationOptions) + + def test_import_from_top_level(self): + from zerobus import ArrowStreamConfigurationOptions + self.assertIsNotNone(ArrowStreamConfigurationOptions) + + def test_import_arrow_stream_from_sync(self): + from zerobus.sdk.sync import ZerobusArrowStream + self.assertIsNotNone(ZerobusArrowStream) + + def test_import_arrow_stream_from_aio(self): + from zerobus.sdk.aio import ZerobusArrowStream + self.assertIsNotNone(ZerobusArrowStream) + + def test_import_arrow_stream_from_top_level(self): + from zerobus import ZerobusArrowStream + self.assertIsNotNone(ZerobusArrowStream) + + +class TestArrowSDKAPISurface(unittest.TestCase): + """Test that Arrow SDK classes have the expected API surface.""" + + def test_sync_sdk_has_arrow_methods(self): + from zerobus.sdk.sync import ZerobusSdk + self.assertTrue(hasattr(ZerobusSdk, "create_arrow_stream")) + self.assertTrue(hasattr(ZerobusSdk, "recreate_arrow_stream")) + + def test_async_sdk_has_arrow_methods(self): + from zerobus.sdk.aio import ZerobusSdk + self.assertTrue(hasattr(ZerobusSdk, "create_arrow_stream")) + self.assertTrue(hasattr(ZerobusSdk, "recreate_arrow_stream")) + + def test_sync_arrow_stream_has_methods(self): + from zerobus.sdk.sync import ZerobusArrowStream + expected_methods = [ + "ingest_batch", + "wait_for_offset", + "flush", + "close", + "get_unacked_batches", + ] + for method in expected_methods: + self.assertTrue( + hasattr(ZerobusArrowStream, method), + f"Sync ZerobusArrowStream missing method: {method}", + ) + + def test_async_arrow_stream_has_methods(self): + from zerobus.sdk.aio import ZerobusArrowStream + expected_methods = [ + "ingest_batch", + "wait_for_offset", + "flush", + "close", + "get_unacked_batches", + ] + for method in 
expected_methods: + self.assertTrue( + hasattr(ZerobusArrowStream, method), + f"Async ZerobusArrowStream missing method: {method}", + ) + + def test_arrow_types_not_on_core_top_level(self): + """Arrow types should be in _core.arrow submodule, not on _core directly.""" + import zerobus._zerobus_core as _core + self.assertFalse(hasattr(_core, "ArrowStreamConfigurationOptions")) + self.assertFalse(hasattr(_core, "ZerobusArrowStream")) + self.assertFalse(hasattr(_core, "AsyncZerobusArrowStream")) + + def test_arrow_types_in_core_arrow_submodule(self): + """Arrow types should be accessible via _core.arrow.""" + import zerobus._zerobus_core as _core + self.assertTrue(hasattr(_core.arrow, "ArrowStreamConfigurationOptions")) + self.assertTrue(hasattr(_core.arrow, "ZerobusArrowStream")) + self.assertTrue(hasattr(_core.arrow, "AsyncZerobusArrowStream")) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/zerobus/__init__.py b/python/zerobus/__init__.py index e94c1fe..262c44f 100644 --- a/python/zerobus/__init__.py +++ b/python/zerobus/__init__.py @@ -45,7 +45,8 @@ # Import from Rust core import zerobus._zerobus_core as _core -from zerobus.sdk.sync import ZerobusSdk, ZerobusStream +from zerobus.sdk.sync import ZerobusSdk, ZerobusStream, ZerobusArrowStream +from zerobus.sdk.shared.arrow import ArrowStreamConfigurationOptions __version__ = "1.1.0" @@ -63,6 +64,9 @@ # Sync SDK (default) "ZerobusSdk", "ZerobusStream", + # Arrow (experimental) + "ZerobusArrowStream", + "ArrowStreamConfigurationOptions", "RecordAcknowledgment", # Common types "TableProperties", diff --git a/python/zerobus/sdk/aio/__init__.py b/python/zerobus/sdk/aio/__init__.py index 125bffb..aadcb21 100644 --- a/python/zerobus/sdk/aio/__init__.py +++ b/python/zerobus/sdk/aio/__init__.py @@ -11,6 +11,7 @@ RecordType, StreamConfigurationOptions, TableProperties, + ZerobusArrowStream, ZerobusException, ZerobusSdk, ZerobusStream, @@ -19,6 +20,7 @@ __all__ = [ "ZerobusSdk", "ZerobusStream", + 
"ZerobusArrowStream", "TableProperties", "StreamConfigurationOptions", "RecordType", diff --git a/python/zerobus/sdk/aio/zerobus_sdk.py b/python/zerobus/sdk/aio/zerobus_sdk.py index 3a95380..1ee8e98 100644 --- a/python/zerobus/sdk/aio/zerobus_sdk.py +++ b/python/zerobus/sdk/aio/zerobus_sdk.py @@ -39,13 +39,14 @@ >>> asyncio.run(main()) """ -from typing import Any +from typing import Any, List, Optional # Import Rust-backed implementations import zerobus._zerobus_core as _core # Import base Rust stream class _RustZerobusStream = _core.aio.ZerobusStream +_RustAsyncZerobusArrowStream = _core.arrow.AsyncZerobusArrowStream class ZerobusStream: @@ -155,12 +156,109 @@ def get_state(self): return 1 # OPENED state +class ZerobusArrowStream: + """ + Asynchronous Arrow Flight stream for ingesting pyarrow RecordBatches. + + **Experimental/Unsupported**: Arrow Flight support is experimental and not yet + supported for production use. The API may change in future releases. + """ + + def __init__(self, rust_stream: _RustAsyncZerobusArrowStream): + self._inner = rust_stream + + async def ingest_batch(self, batch) -> int: + """ + Ingest a pyarrow.RecordBatch or pyarrow.Table. + + Args: + batch: A pyarrow.RecordBatch or pyarrow.Table to ingest. + + Returns: + The offset ID assigned to this batch. 
+ """ + from zerobus.sdk.shared.arrow import _serialize_batch + + ipc_bytes = _serialize_batch(batch) + return await self._inner.ingest_batch(ipc_bytes) + + async def wait_for_offset(self, offset: int): + """Wait for a specific offset to be acknowledged.""" + return await self._inner.wait_for_offset(offset) + + async def flush(self): + """Flush all pending batches, waiting for acknowledgment.""" + return await self._inner.flush() + + async def close(self): + """Close the stream gracefully.""" + return await self._inner.close() + + @property + def is_closed(self) -> bool: + """Check if the stream has been closed.""" + return self._inner.is_closed + + @property + def table_name(self) -> str: + """Get the table name.""" + return self._inner.table_name + + async def get_unacked_batches(self) -> list: + """ + Get unacknowledged batches as a list of pyarrow.RecordBatch. + + Returns: + List of pyarrow.RecordBatch objects. + """ + from zerobus.sdk.shared.arrow import _deserialize_batch + + ipc_list = await self._inner.get_unacked_batches() + return [_deserialize_batch(ipc_bytes) for ipc_bytes in ipc_list] + + class ZerobusSdk: """Python wrapper around Rust ZerobusSdk that returns wrapped streams.""" def __init__(self, host: str, unity_catalog_url: str): self._inner = _core.aio.ZerobusSdk(host, unity_catalog_url) + async def create_arrow_stream(self, table_name: str, schema, client_id: str, client_secret: str, options=None, headers_provider=None) -> ZerobusArrowStream: + """ + Create an Arrow Flight stream for ingesting pyarrow RecordBatches (async). + + **Experimental/Unsupported**: Arrow Flight support is experimental. + + Args: + table_name: Fully qualified table name (catalog.schema.table). + schema: A pyarrow.Schema defining the table schema. + client_id: OAuth client ID. + client_secret: OAuth client secret. + options: Optional ArrowStreamConfigurationOptions. + headers_provider: Optional custom headers provider. 
+ + Returns: + A ZerobusArrowStream ready for ingesting RecordBatches. + """ + from zerobus.sdk.shared.arrow import _serialize_schema + + schema_bytes = _serialize_schema(schema) + + if headers_provider is not None: + rust_stream = await self._inner.create_arrow_stream_with_headers_provider( + table_name, schema_bytes, headers_provider, options + ) + else: + rust_stream = await self._inner.create_arrow_stream( + table_name, schema_bytes, client_id, client_secret, options + ) + return ZerobusArrowStream(rust_stream) + + async def recreate_arrow_stream(self, old_stream: ZerobusArrowStream) -> ZerobusArrowStream: + """Recreate a closed Arrow stream with the same configuration.""" + rust_stream = await self._inner.recreate_arrow_stream(old_stream._inner) + return ZerobusArrowStream(rust_stream) + async def create_stream( self, client_id: str, client_secret: str, table_properties, options=None, headers_provider=None ): @@ -202,6 +300,7 @@ async def recreate_stream(self, old_stream: ZerobusStream): __all__ = [ "ZerobusSdk", "ZerobusStream", + "ZerobusArrowStream", "TableProperties", "StreamConfigurationOptions", "RecordType", diff --git a/python/zerobus/sdk/shared/arrow.py b/python/zerobus/sdk/shared/arrow.py new file mode 100644 index 0000000..249ac29 --- /dev/null +++ b/python/zerobus/sdk/shared/arrow.py @@ -0,0 +1,90 @@ +""" +Arrow Flight support for the Zerobus SDK. + +**Experimental/Unsupported**: Arrow Flight support is experimental and not yet +supported for production use. The API may change in future releases. + +Requires pyarrow to be installed: + pip install databricks-zerobus-ingest-sdk[arrow] + +Example (Sync): + >>> import pyarrow as pa + >>> from zerobus.sdk.sync import ZerobusSdk + >>> from zerobus.sdk.shared.arrow import ArrowStreamConfigurationOptions + >>> + >>> sdk = ZerobusSdk(host, unity_catalog_url) + >>> schema = pa.schema([("device_name", pa.large_utf8()), ("temp", pa.int32())]) + >>> stream = sdk.create_arrow_stream( + ... 
"catalog.schema.table", schema, client_id, client_secret
+    ... )
+    >>> batch = pa.record_batch({"device_name": ["s1"], "temp": [22]}, schema=schema)
+    >>> offset = stream.ingest_batch(batch)
+    >>> stream.wait_for_offset(offset)
+    >>> stream.close()
+"""
+
+_PYARROW_IMPORT_ERROR = (
+    "pyarrow is required for Arrow Flight support. "
+    "Install with: pip install databricks-zerobus-ingest-sdk[arrow]"
+)
+
+
+def _check_pyarrow():
+    """Check that pyarrow is available, raise ImportError if not."""
+    try:
+        import pyarrow  # noqa: F401
+
+        return pyarrow
+    except ImportError:
+        raise ImportError(_PYARROW_IMPORT_ERROR)
+
+
+def _serialize_schema(schema):
+    """Serialize a pyarrow.Schema to IPC bytes."""
+    pa = _check_pyarrow()
+    if not isinstance(schema, pa.Schema):
+        raise TypeError(f"Expected pyarrow.Schema, got {type(schema).__name__}")
+
+    # Open and immediately close an IPC stream writer so the buffer contains
+    # only the schema message (no record batches). The Rust side parses this
+    # stream header to recover the schema.
+    sink = pa.BufferOutputStream()
+    writer = pa.ipc.new_stream(sink, schema)
+    writer.close()
+    return sink.getvalue().to_pybytes()
+
+
+def _serialize_batch(batch):
+    """Serialize a pyarrow.RecordBatch (or Table) to IPC bytes."""
+    pa = _check_pyarrow()
+    if isinstance(batch, pa.Table):
+        # Convert Table to a single RecordBatch
+        batches = batch.to_batches()
+        if len(batches) == 0:
+            raise ValueError("Cannot ingest an empty pyarrow.Table")
+        if len(batches) > 1:
+            batch = pa.concat_tables([pa.Table.from_batches([b]) for b in batches]).combine_chunks().to_batches()[0]
+        else:
+            batch = batches[0]
+    elif not isinstance(batch, pa.RecordBatch):
+        raise TypeError(
+            f"Expected pyarrow.RecordBatch or pyarrow.Table, got {type(batch).__name__}"
+        )
+
+    sink = pa.BufferOutputStream()
+    writer = pa.ipc.new_stream(sink, batch.schema)
+    writer.write_batch(batch)
+    writer.close()
+    return sink.getvalue().to_pybytes()
+
+
+def _deserialize_batch(ipc_bytes):
+    """Deserialize IPC bytes back to a 
pyarrow.RecordBatch.""" + pa = _check_pyarrow() + reader = pa.ipc.open_stream(ipc_bytes) + return reader.read_next_batch() + + +# Re-export configuration from Rust core +import zerobus._zerobus_core as _core + +ArrowStreamConfigurationOptions = _core.arrow.ArrowStreamConfigurationOptions diff --git a/python/zerobus/sdk/sync/__init__.py b/python/zerobus/sdk/sync/__init__.py index a75cb75..a2fe871 100644 --- a/python/zerobus/sdk/sync/__init__.py +++ b/python/zerobus/sdk/sync/__init__.py @@ -11,6 +11,7 @@ RecordType, StreamConfigurationOptions, TableProperties, + ZerobusArrowStream, ZerobusException, ZerobusSdk, ZerobusStream, @@ -19,6 +20,7 @@ __all__ = [ "ZerobusSdk", "ZerobusStream", + "ZerobusArrowStream", "RecordAcknowledgment", "TableProperties", "StreamConfigurationOptions", diff --git a/python/zerobus/sdk/sync/zerobus_sdk.py b/python/zerobus/sdk/sync/zerobus_sdk.py index fd09d33..902b567 100644 --- a/python/zerobus/sdk/sync/zerobus_sdk.py +++ b/python/zerobus/sdk/sync/zerobus_sdk.py @@ -38,7 +38,7 @@ >>> offset = ack.wait_for_ack(timeout_sec=30) """ -from typing import Iterator +from typing import Iterator, List, Optional # Import Rust-backed implementations import zerobus._zerobus_core as _core @@ -122,12 +122,129 @@ def get_state(self): return self._inner.get_state() if hasattr(self._inner, "get_state") else 1 +class ZerobusArrowStream: + """ + Synchronous Arrow Flight stream for ingesting pyarrow RecordBatches. + + **Experimental/Unsupported**: Arrow Flight support is experimental and not yet + supported for production use. The API may change in future releases. 
+ + Example: + >>> import pyarrow as pa + >>> schema = pa.schema([("temp", pa.int32())]) + >>> stream = sdk.create_arrow_stream("catalog.schema.table", schema, client_id, client_secret) + >>> batch = pa.record_batch({"temp": [22, 23]}, schema=schema) + >>> offset = stream.ingest_batch(batch) + >>> stream.wait_for_offset(offset) + >>> stream.close() + """ + + def __init__(self, rust_stream): + self._inner = rust_stream + + def ingest_batch(self, batch) -> int: + """ + Ingest a pyarrow.RecordBatch or pyarrow.Table. + + Args: + batch: A pyarrow.RecordBatch or pyarrow.Table to ingest. + + Returns: + The offset ID assigned to this batch. + """ + from zerobus.sdk.shared.arrow import _serialize_batch + + ipc_bytes = _serialize_batch(batch) + return self._inner.ingest_batch(ipc_bytes) + + def wait_for_offset(self, offset: int): + """Wait for a specific offset to be acknowledged.""" + return self._inner.wait_for_offset(offset) + + def flush(self): + """Flush all pending batches, waiting for acknowledgment.""" + return self._inner.flush() + + def close(self): + """Close the stream gracefully.""" + return self._inner.close() + + @property + def is_closed(self) -> bool: + """Check if the stream has been closed.""" + return self._inner.is_closed + + @property + def table_name(self) -> str: + """Get the table name.""" + return self._inner.table_name + + def get_unacked_batches(self) -> list: + """ + Get unacknowledged batches as a list of pyarrow.RecordBatch. + + The stream must be closed before calling this method. + + Returns: + List of pyarrow.RecordBatch objects. 
+ """ + from zerobus.sdk.shared.arrow import _deserialize_batch + + ipc_list = self._inner.get_unacked_batches() + return [_deserialize_batch(ipc_bytes) for ipc_bytes in ipc_list] + + class ZerobusSdk: """Python wrapper around Rust ZerobusSdk that provides unified create_stream API.""" def __init__(self, host: str, unity_catalog_url: str): self._inner = _RustZerobusSdk(host, unity_catalog_url) + def create_arrow_stream(self, table_name: str, schema, client_id: str, client_secret: str, options=None, headers_provider=None) -> ZerobusArrowStream: + """ + Create an Arrow Flight stream for ingesting pyarrow RecordBatches. + + **Experimental/Unsupported**: Arrow Flight support is experimental. + + Args: + table_name: Fully qualified table name (catalog.schema.table). + schema: A pyarrow.Schema defining the table schema. + client_id: OAuth client ID. + client_secret: OAuth client secret. + options: Optional ArrowStreamConfigurationOptions. + headers_provider: Optional custom headers provider (if set, overrides OAuth). + + Returns: + A ZerobusArrowStream ready for ingesting RecordBatches. + """ + from zerobus.sdk.shared.arrow import _serialize_schema + + schema_bytes = _serialize_schema(schema) + + if headers_provider is not None: + rust_stream = self._inner.create_arrow_stream_with_headers_provider( + table_name, schema_bytes, headers_provider, options + ) + else: + rust_stream = self._inner.create_arrow_stream( + table_name, schema_bytes, client_id, client_secret, options + ) + return ZerobusArrowStream(rust_stream) + + def recreate_arrow_stream(self, old_stream: ZerobusArrowStream) -> ZerobusArrowStream: + """ + Recreate a closed Arrow stream with the same configuration, + re-ingesting unacknowledged batches. + + Args: + old_stream: The closed Arrow stream to recreate. + + Returns: + A new ZerobusArrowStream. 
+ """ + rust_stream = self._inner.recreate_arrow_stream(old_stream._inner) + return ZerobusArrowStream(rust_stream) + def create_stream(self, client_id: str, client_secret: str, table_properties, options=None, headers_provider=None): """ Create a stream with OAuth authentication or custom headers provider. @@ -168,6 +285,7 @@ def recreate_stream(self, old_stream: ZerobusStream): __all__ = [ "ZerobusSdk", "ZerobusStream", + "ZerobusArrowStream", "RecordAcknowledgment", "TableProperties", "StreamConfigurationOptions", From b932a2cf3d01d69bf9a47e0e2c0bd69eabc6b653 Mon Sep 17 00:00:00 2001 From: teodor-delibasic_data Date: Wed, 11 Mar 2026 11:45:27 +0000 Subject: [PATCH 02/13] Fix fmt and lint issues Signed-off-by: teodor-delibasic_data --- python/examples/async_example_arrow.py | 16 +++++++-------- python/examples/sync_example_arrow.py | 16 +++++++-------- python/tests/test_arrow.py | 27 +++++++++++++++++++------- python/zerobus/__init__.py | 2 +- python/zerobus/sdk/aio/zerobus_sdk.py | 6 ++++-- python/zerobus/sdk/shared/arrow.py | 11 ++++------- python/zerobus/sdk/sync/zerobus_sdk.py | 10 +++++----- 7 files changed, 50 insertions(+), 38 deletions(-) diff --git a/python/examples/async_example_arrow.py b/python/examples/async_example_arrow.py index 23225ce..08d07cd 100644 --- a/python/examples/async_example_arrow.py +++ b/python/examples/async_example_arrow.py @@ -43,11 +43,13 @@ ROWS_PER_BATCH = 100 # Define the Arrow schema -SCHEMA = pa.schema([ - ("device_name", pa.large_utf8()), - ("temp", pa.int32()), - ("humidity", pa.int64()), -]) +SCHEMA = pa.schema( + [ + ("device_name", pa.large_utf8()), + ("temp", pa.int32()), + ("humidity", pa.int64()), + ] +) def create_sample_batch(batch_index): @@ -104,9 +106,7 @@ async def main(): # The SDK automatically: # - Includes authorization header with OAuth token # - Includes x-databricks-zerobus-table-name header - stream = await sdk.create_arrow_stream( - TABLE_NAME, SCHEMA, CLIENT_ID, CLIENT_SECRET, options - ) + stream = 
await sdk.create_arrow_stream(TABLE_NAME, SCHEMA, CLIENT_ID, CLIENT_SECRET, options) logger.info(f"Arrow stream created for table: {stream.table_name}") # Step 4: Ingest Arrow RecordBatches asynchronously diff --git a/python/examples/sync_example_arrow.py b/python/examples/sync_example_arrow.py index b066c8d..c3ea465 100644 --- a/python/examples/sync_example_arrow.py +++ b/python/examples/sync_example_arrow.py @@ -42,11 +42,13 @@ ROWS_PER_BATCH = 100 # Define the Arrow schema -SCHEMA = pa.schema([ - ("device_name", pa.large_utf8()), - ("temp", pa.int32()), - ("humidity", pa.int64()), -]) +SCHEMA = pa.schema( + [ + ("device_name", pa.large_utf8()), + ("temp", pa.int32()), + ("humidity", pa.int64()), + ] +) def create_sample_batch(batch_index): @@ -103,9 +105,7 @@ def main(): # The SDK automatically: # - Includes authorization header with OAuth token # - Includes x-databricks-zerobus-table-name header - stream = sdk.create_arrow_stream( - TABLE_NAME, SCHEMA, CLIENT_ID, CLIENT_SECRET, options - ) + stream = sdk.create_arrow_stream(TABLE_NAME, SCHEMA, CLIENT_ID, CLIENT_SECRET, options) logger.info(f"Arrow stream created for table: {stream.table_name}") # Step 4: Ingest Arrow RecordBatches diff --git a/python/tests/test_arrow.py b/python/tests/test_arrow.py index 13c40c7..cbfb85f 100644 --- a/python/tests/test_arrow.py +++ b/python/tests/test_arrow.py @@ -37,13 +37,15 @@ def test_simple_schema(self): self.assertGreater(len(ipc_bytes), 0) def test_schema_with_various_types(self): - schema = pa.schema([ - ("int_col", pa.int32()), - ("float_col", pa.float64()), - ("str_col", pa.large_utf8()), - ("bool_col", pa.bool_()), - ("list_col", pa.list_(pa.int64())), - ]) + schema = pa.schema( + [ + ("int_col", pa.int32()), + ("float_col", pa.float64()), + ("str_col", pa.large_utf8()), + ("bool_col", pa.bool_()), + ("list_col", pa.list_(pa.int64())), + ] + ) ipc_bytes = _serialize_schema(schema) self.assertIsInstance(ipc_bytes, bytes) @@ -228,22 +230,27 @@ class 
TestArrowImports(unittest.TestCase): def test_import_from_shared_arrow(self): from zerobus.sdk.shared.arrow import ArrowStreamConfigurationOptions + self.assertIsNotNone(ArrowStreamConfigurationOptions) def test_import_from_top_level(self): from zerobus import ArrowStreamConfigurationOptions + self.assertIsNotNone(ArrowStreamConfigurationOptions) def test_import_arrow_stream_from_sync(self): from zerobus.sdk.sync import ZerobusArrowStream + self.assertIsNotNone(ZerobusArrowStream) def test_import_arrow_stream_from_aio(self): from zerobus.sdk.aio import ZerobusArrowStream + self.assertIsNotNone(ZerobusArrowStream) def test_import_arrow_stream_from_top_level(self): from zerobus import ZerobusArrowStream + self.assertIsNotNone(ZerobusArrowStream) @@ -252,16 +259,19 @@ class TestArrowSDKAPISurface(unittest.TestCase): def test_sync_sdk_has_arrow_methods(self): from zerobus.sdk.sync import ZerobusSdk + self.assertTrue(hasattr(ZerobusSdk, "create_arrow_stream")) self.assertTrue(hasattr(ZerobusSdk, "recreate_arrow_stream")) def test_async_sdk_has_arrow_methods(self): from zerobus.sdk.aio import ZerobusSdk + self.assertTrue(hasattr(ZerobusSdk, "create_arrow_stream")) self.assertTrue(hasattr(ZerobusSdk, "recreate_arrow_stream")) def test_sync_arrow_stream_has_methods(self): from zerobus.sdk.sync import ZerobusArrowStream + expected_methods = [ "ingest_batch", "wait_for_offset", @@ -277,6 +287,7 @@ def test_sync_arrow_stream_has_methods(self): def test_async_arrow_stream_has_methods(self): from zerobus.sdk.aio import ZerobusArrowStream + expected_methods = [ "ingest_batch", "wait_for_offset", @@ -293,6 +304,7 @@ def test_async_arrow_stream_has_methods(self): def test_arrow_types_not_on_core_top_level(self): """Arrow types should be in _core.arrow submodule, not on _core directly.""" import zerobus._zerobus_core as _core + self.assertFalse(hasattr(_core, "ArrowStreamConfigurationOptions")) self.assertFalse(hasattr(_core, "ZerobusArrowStream")) self.assertFalse(hasattr(_core, 
"AsyncZerobusArrowStream")) @@ -300,6 +312,7 @@ def test_arrow_types_not_on_core_top_level(self): def test_arrow_types_in_core_arrow_submodule(self): """Arrow types should be accessible via _core.arrow.""" import zerobus._zerobus_core as _core + self.assertTrue(hasattr(_core.arrow, "ArrowStreamConfigurationOptions")) self.assertTrue(hasattr(_core.arrow, "ZerobusArrowStream")) self.assertTrue(hasattr(_core.arrow, "AsyncZerobusArrowStream")) diff --git a/python/zerobus/__init__.py b/python/zerobus/__init__.py index 262c44f..e3499e6 100644 --- a/python/zerobus/__init__.py +++ b/python/zerobus/__init__.py @@ -45,8 +45,8 @@ # Import from Rust core import zerobus._zerobus_core as _core -from zerobus.sdk.sync import ZerobusSdk, ZerobusStream, ZerobusArrowStream from zerobus.sdk.shared.arrow import ArrowStreamConfigurationOptions +from zerobus.sdk.sync import ZerobusArrowStream, ZerobusSdk, ZerobusStream __version__ = "1.1.0" diff --git a/python/zerobus/sdk/aio/zerobus_sdk.py b/python/zerobus/sdk/aio/zerobus_sdk.py index 1ee8e98..523ef9e 100644 --- a/python/zerobus/sdk/aio/zerobus_sdk.py +++ b/python/zerobus/sdk/aio/zerobus_sdk.py @@ -39,7 +39,7 @@ >>> asyncio.run(main()) """ -from typing import Any, List, Optional +from typing import Any # Import Rust-backed implementations import zerobus._zerobus_core as _core @@ -223,7 +223,9 @@ class ZerobusSdk: def __init__(self, host: str, unity_catalog_url: str): self._inner = _core.aio.ZerobusSdk(host, unity_catalog_url) - async def create_arrow_stream(self, table_name: str, schema, client_id: str, client_secret: str, options=None, headers_provider=None) -> ZerobusArrowStream: + async def create_arrow_stream( + self, table_name: str, schema, client_id: str, client_secret: str, options=None, headers_provider=None + ) -> ZerobusArrowStream: """ Create an Arrow Flight stream for ingesting pyarrow RecordBatches (async). 
diff --git a/python/zerobus/sdk/shared/arrow.py b/python/zerobus/sdk/shared/arrow.py index 249ac29..ba8a43d 100644 --- a/python/zerobus/sdk/shared/arrow.py +++ b/python/zerobus/sdk/shared/arrow.py @@ -23,9 +23,10 @@ >>> stream.close() """ +import zerobus._zerobus_core as _core + _PYARROW_IMPORT_ERROR = ( - "pyarrow is required for Arrow Flight support. " - "Install with: pip install databricks-zerobus-ingest-sdk[arrow]" + "pyarrow is required for Arrow Flight support. " "Install with: pip install databricks-zerobus-ingest-sdk[arrow]" ) @@ -66,9 +67,7 @@ def _serialize_batch(batch): else: batch = batches[0] elif not isinstance(batch, pa.RecordBatch): - raise TypeError( - f"Expected pyarrow.RecordBatch or pyarrow.Table, got {type(batch).__name__}" - ) + raise TypeError(f"Expected pyarrow.RecordBatch or pyarrow.Table, got {type(batch).__name__}") sink = pa.BufferOutputStream() writer = pa.ipc.new_stream(sink, batch.schema) @@ -85,6 +84,4 @@ def _deserialize_batch(ipc_bytes): # Re-export configuration from Rust core -import zerobus._zerobus_core as _core - ArrowStreamConfigurationOptions = _core.arrow.ArrowStreamConfigurationOptions diff --git a/python/zerobus/sdk/sync/zerobus_sdk.py b/python/zerobus/sdk/sync/zerobus_sdk.py index 902b567..ca7fedc 100644 --- a/python/zerobus/sdk/sync/zerobus_sdk.py +++ b/python/zerobus/sdk/sync/zerobus_sdk.py @@ -38,7 +38,7 @@ >>> offset = ack.wait_for_ack(timeout_sec=30) """ -from typing import Iterator, List, Optional +from typing import Iterator # Import Rust-backed implementations import zerobus._zerobus_core as _core @@ -200,7 +200,9 @@ class ZerobusSdk: def __init__(self, host: str, unity_catalog_url: str): self._inner = _RustZerobusSdk(host, unity_catalog_url) - def create_arrow_stream(self, table_name: str, schema, client_id: str, client_secret: str, options=None, headers_provider=None) -> ZerobusArrowStream: + def create_arrow_stream( + self, table_name: str, schema, client_id: str, client_secret: str, options=None, 
headers_provider=None + ) -> ZerobusArrowStream: """ Create an Arrow Flight stream for ingesting pyarrow RecordBatches. @@ -226,9 +228,7 @@ def create_arrow_stream(self, table_name: str, schema, client_id: str, client_se table_name, schema_bytes, headers_provider, options ) else: - rust_stream = self._inner.create_arrow_stream( - table_name, schema_bytes, client_id, client_secret, options - ) + rust_stream = self._inner.create_arrow_stream(table_name, schema_bytes, client_id, client_secret, options) return ZerobusArrowStream(rust_stream) def recreate_arrow_stream(self, old_stream: ZerobusArrowStream) -> ZerobusArrowStream: From ed953291ef31b21b7ebb3a7e1b79c118d2204a96 Mon Sep 17 00:00:00 2001 From: teodor-delibasic_data Date: Fri, 13 Mar 2026 14:04:18 +0000 Subject: [PATCH 03/13] Fix empty table check in _serialize_batch and simplify multi-chunk path Signed-off-by: teodor-delibasic_data --- python/tests/test_arrow.py | 3 ++- python/zerobus/sdk/shared/arrow.py | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/python/tests/test_arrow.py b/python/tests/test_arrow.py index cbfb85f..149827a 100644 --- a/python/tests/test_arrow.py +++ b/python/tests/test_arrow.py @@ -104,8 +104,9 @@ def test_table_multiple_chunks(self): def test_empty_table_raises(self): schema = pa.schema([("a", pa.int64())]) table = pa.table({"a": pa.array([], type=pa.int64())}, schema=schema) - with self.assertRaises(ValueError): + with self.assertRaises(ValueError) as cm: _serialize_batch(table) + self.assertIn("empty", str(cm.exception).lower()) def test_rejects_wrong_type(self): with self.assertRaises(TypeError): diff --git a/python/zerobus/sdk/shared/arrow.py b/python/zerobus/sdk/shared/arrow.py index ba8a43d..b5b833e 100644 --- a/python/zerobus/sdk/shared/arrow.py +++ b/python/zerobus/sdk/shared/arrow.py @@ -58,12 +58,12 @@ def _serialize_batch(batch): """Serialize a pyarrow.RecordBatch (or Table) to IPC bytes.""" pa = _check_pyarrow() if isinstance(batch, pa.Table): - # 
Convert Table to a single RecordBatch - batches = batch.to_batches() - if len(batches) == 0: + if batch.num_rows == 0: raise ValueError("Cannot ingest an empty pyarrow.Table") + # Convert Table to a single RecordBatch, combining chunks if needed + batches = batch.to_batches() if len(batches) > 1: - batch = pa.concat_tables([pa.Table.from_batches([b]) for b in batches]).combine_chunks().to_batches()[0] + batch = pa.concat_batches(batches) else: batch = batches[0] elif not isinstance(batch, pa.RecordBatch): From 64738fc4c732c51317fcc091b6266c161827eca9 Mon Sep 17 00:00:00 2001 From: teodor-delibasic_data Date: Fri, 13 Mar 2026 14:14:24 +0000 Subject: [PATCH 04/13] Add pyrefly.toml to .gitignore Signed-off-by: teodor-delibasic_data --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 60199ff..e31a805 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,9 @@ .bsp/ .bazelbsp/ +# Tools +pyrefly.toml + # OS .DS_Store Thumbs.db From 4d134380baa86ebfec64339d16e5d98133f09e43 Mon Sep 17 00:00:00 2001 From: teodor-delibasic_data Date: Fri, 13 Mar 2026 14:28:27 +0000 Subject: [PATCH 05/13] Validate non-negative config values before casting signed to unsigned types Signed-off-by: teodor-delibasic_data --- python/rust/src/arrow.rs | 49 +++++++++++++++++++++++---- python/rust/src/async_wrapper.rs | 58 +++++++++++++++++--------------- python/rust/src/common.rs | 51 ++++++++++++++++++++++++++++ python/rust/src/sync_wrapper.rs | 2 ++ 4 files changed, 126 insertions(+), 34 deletions(-) diff --git a/python/rust/src/arrow.rs b/python/rust/src/arrow.rs index 8fb1e48..c0f5a76 100644 --- a/python/rust/src/arrow.rs +++ b/python/rust/src/arrow.rs @@ -205,8 +205,43 @@ impl ArrowStreamConfigurationOptions { } impl ArrowStreamConfigurationOptions { - pub fn to_rust(&self) -> RustArrowStreamOptions { - RustArrowStreamOptions { + pub fn to_rust(&self) -> Result { + if self.max_inflight_batches < 0 { + return 
Err(pyo3::exceptions::PyValueError::new_err( + "max_inflight_batches must be non-negative", + )); + } + if self.recovery_timeout_ms < 0 { + return Err(pyo3::exceptions::PyValueError::new_err( + "recovery_timeout_ms must be non-negative", + )); + } + if self.recovery_backoff_ms < 0 { + return Err(pyo3::exceptions::PyValueError::new_err( + "recovery_backoff_ms must be non-negative", + )); + } + if self.recovery_retries < 0 { + return Err(pyo3::exceptions::PyValueError::new_err( + "recovery_retries must be non-negative", + )); + } + if self.server_lack_of_ack_timeout_ms < 0 { + return Err(pyo3::exceptions::PyValueError::new_err( + "server_lack_of_ack_timeout_ms must be non-negative", + )); + } + if self.flush_timeout_ms < 0 { + return Err(pyo3::exceptions::PyValueError::new_err( + "flush_timeout_ms must be non-negative", + )); + } + if self.connection_timeout_ms < 0 { + return Err(pyo3::exceptions::PyValueError::new_err( + "connection_timeout_ms must be non-negative", + )); + } + Ok(RustArrowStreamOptions { max_inflight_batches: self.max_inflight_batches as usize, recovery: self.recovery, recovery_timeout_ms: self.recovery_timeout_ms as u64, @@ -216,7 +251,7 @@ impl ArrowStreamConfigurationOptions { flush_timeout_ms: self.flush_timeout_ms as u64, connection_timeout_ms: self.connection_timeout_ms as u64, ipc_compression: None, - } + }) } } @@ -498,7 +533,7 @@ pub fn create_arrow_stream_sync( schema: Arc::new(schema), }; - let rust_options = options.map(|o| o.to_rust()); + let rust_options = options.map(|o| o.to_rust()).transpose()?; let sdk_clone = sdk.clone(); let runtime_clone = runtime.clone(); @@ -537,7 +572,7 @@ pub fn create_arrow_stream_with_headers_provider_sync( schema: Arc::new(schema), }; - let rust_options = options.map(|o| o.to_rust()); + let rust_options = options.map(|o| o.to_rust()).transpose()?; let provider = Arc::new(HeadersProviderWrapper::new(headers_provider)); let sdk_clone = sdk.clone(); @@ -609,7 +644,7 @@ pub fn create_arrow_stream_async<'py>( 
schema: Arc::new(schema), }; - let rust_options = options.map(|o| o.to_rust()); + let rust_options = options.map(|o| o.to_rust()).transpose()?; let sdk_clone = sdk.clone(); pyo3_asyncio::tokio::future_into_py(py, async move { @@ -642,7 +677,7 @@ pub fn create_arrow_stream_with_headers_provider_async<'py>( schema: Arc::new(schema), }; - let rust_options = options.map(|o| o.to_rust()); + let rust_options = options.map(|o| o.to_rust()).transpose()?; let provider = Arc::new(HeadersProviderWrapper::new(headers_provider)); let sdk_clone = sdk.clone(); diff --git a/python/rust/src/async_wrapper.rs b/python/rust/src/async_wrapper.rs index 543aa04..7a4a6e4 100644 --- a/python/rust/src/async_wrapper.rs +++ b/python/rust/src/async_wrapper.rs @@ -425,7 +425,7 @@ impl ZerobusSdk { descriptor_proto: table_properties.descriptor_proto.clone(), }; - let opts = convert_stream_options(options); + let opts = convert_stream_options(options)?; future_into_py(py, async move { let sdk_guard = sdk.read().await; @@ -455,7 +455,7 @@ impl ZerobusSdk { descriptor_proto: table_properties.descriptor_proto.clone(), }; - let opts = convert_stream_options(options); + let opts = convert_stream_options(options)?; let wrapper = Arc::new(HeadersProviderWrapper::new(headers_provider)); future_into_py(py, async move { @@ -548,30 +548,34 @@ impl ZerobusSdk { // Helper to convert Python StreamConfigurationOptions to Rust options fn convert_stream_options( options: Option<&StreamConfigurationOptions>, -) -> Option { - options.map(|opts| { - let ack_callback = opts - .ack_callback - .clone() - .map(|cb| Arc::new(AckCallbackWrapper::new(cb)) as Arc); - - RustStreamOptions { - max_inflight_requests: opts.max_inflight_records as usize, - recovery: opts.recovery, - recovery_timeout_ms: opts.recovery_timeout_ms as u64, - recovery_backoff_ms: opts.recovery_backoff_ms as u64, - recovery_retries: opts.recovery_retries as u32, - server_lack_of_ack_timeout_ms: opts.server_lack_of_ack_timeout_ms as u64, - 
flush_timeout_ms: opts.flush_timeout_ms as u64, - record_type: match opts.record_type.value { - 1 => RustRecordType::Proto, - 2 => RustRecordType::Json, - _ => RustRecordType::Proto, - }, - stream_paused_max_wait_time_ms: opts.stream_paused_max_wait_time_ms.map(|v| v as u64), - callback_max_wait_time_ms: opts.callback_max_wait_time_ms.map(|v| v as u64), - ack_callback, - ..Default::default() +) -> PyResult> { + match options { + Some(opts) => { + opts.validate()?; + let ack_callback = opts + .ack_callback + .clone() + .map(|cb| Arc::new(AckCallbackWrapper::new(cb)) as Arc); + + Ok(Some(RustStreamOptions { + max_inflight_requests: opts.max_inflight_records as usize, + recovery: opts.recovery, + recovery_timeout_ms: opts.recovery_timeout_ms as u64, + recovery_backoff_ms: opts.recovery_backoff_ms as u64, + recovery_retries: opts.recovery_retries as u32, + server_lack_of_ack_timeout_ms: opts.server_lack_of_ack_timeout_ms as u64, + flush_timeout_ms: opts.flush_timeout_ms as u64, + record_type: match opts.record_type.value { + 1 => RustRecordType::Proto, + 2 => RustRecordType::Json, + _ => RustRecordType::Proto, + }, + stream_paused_max_wait_time_ms: opts.stream_paused_max_wait_time_ms.map(|v| v as u64), + callback_max_wait_time_ms: opts.callback_max_wait_time_ms.map(|v| v as u64), + ack_callback, + ..Default::default() + })) } - }) + None => Ok(None), + } } diff --git a/python/rust/src/common.rs b/python/rust/src/common.rs index f62b578..84bc065 100644 --- a/python/rust/src/common.rs +++ b/python/rust/src/common.rs @@ -207,6 +207,57 @@ pub struct StreamConfigurationOptions { pub ack_callback: Option>, } +impl StreamConfigurationOptions { + /// Validate that all numeric fields are non-negative before casting to unsigned types. 
+ pub fn validate(&self) -> PyResult<()> { + if self.max_inflight_records < 0 { + return Err(PyValueError::new_err( + "max_inflight_records must be non-negative", + )); + } + if self.recovery_timeout_ms < 0 { + return Err(PyValueError::new_err( + "recovery_timeout_ms must be non-negative", + )); + } + if self.recovery_backoff_ms < 0 { + return Err(PyValueError::new_err( + "recovery_backoff_ms must be non-negative", + )); + } + if self.recovery_retries < 0 { + return Err(PyValueError::new_err( + "recovery_retries must be non-negative", + )); + } + if self.server_lack_of_ack_timeout_ms < 0 { + return Err(PyValueError::new_err( + "server_lack_of_ack_timeout_ms must be non-negative", + )); + } + if self.flush_timeout_ms < 0 { + return Err(PyValueError::new_err( + "flush_timeout_ms must be non-negative", + )); + } + if let Some(v) = self.stream_paused_max_wait_time_ms { + if v < 0 { + return Err(PyValueError::new_err( + "stream_paused_max_wait_time_ms must be non-negative", + )); + } + } + if let Some(v) = self.callback_max_wait_time_ms { + if v < 0 { + return Err(PyValueError::new_err( + "callback_max_wait_time_ms must be non-negative", + )); + } + } + Ok(()) + } +} + impl Default for StreamConfigurationOptions { fn default() -> Self { Self { diff --git a/python/rust/src/sync_wrapper.rs b/python/rust/src/sync_wrapper.rs index 78c33cc..4df2fd7 100644 --- a/python/rust/src/sync_wrapper.rs +++ b/python/rust/src/sync_wrapper.rs @@ -449,6 +449,7 @@ impl ZerobusSdk { }; let rust_options = if let Some(opts) = options.clone() { + opts.validate()?; let ack_callback = opts .ack_callback .map(|cb| Arc::new(AckCallbackWrapper::new(cb)) as Arc); @@ -516,6 +517,7 @@ impl ZerobusSdk { }; let rust_options = if let Some(opts) = options.clone() { + opts.validate()?; let ack_callback = opts .ack_callback .map(|cb| Arc::new(AckCallbackWrapper::new(cb)) as Arc); From 43100f4fb7d8c3aae87046ff732fff798f5d2c3c Mon Sep 17 00:00:00 2001 From: teodor-delibasic_data Date: Fri, 13 Mar 2026 
14:32:21 +0000 Subject: [PATCH 06/13] Fix is_closed/table_name getters to avoid GIL deadlock and tokio panic Signed-off-by: teodor-delibasic_data --- python/rust/src/arrow.rs | 50 ++++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/python/rust/src/arrow.rs b/python/rust/src/arrow.rs index c0f5a76..1416bba 100644 --- a/python/rust/src/arrow.rs +++ b/python/rust/src/arrow.rs @@ -340,21 +340,27 @@ impl ZerobusArrowStream { /// Check if the stream has been closed. #[getter] - fn is_closed(&self) -> bool { + fn is_closed(&self, py: Python) -> bool { let stream_clone = self.inner.clone(); - self.runtime.block_on(async move { - let stream_guard = stream_clone.read().await; - stream_guard.is_closed() + let runtime = self.runtime.clone(); + py.allow_threads(|| { + runtime.block_on(async move { + let stream_guard = stream_clone.read().await; + stream_guard.is_closed() + }) }) } /// Get the table name. #[getter] - fn table_name(&self) -> String { + fn table_name(&self, py: Python) -> String { let stream_clone = self.inner.clone(); - self.runtime.block_on(async move { - let stream_guard = stream_clone.read().await; - stream_guard.table_name().to_string() + let runtime = self.runtime.clone(); + py.allow_threads(|| { + runtime.block_on(async move { + let stream_guard = stream_clone.read().await; + stream_guard.table_name().to_string() + }) }) } @@ -466,24 +472,24 @@ impl AsyncZerobusArrowStream { /// Check if the stream has been closed. #[getter] - fn is_closed(&self) -> bool { - // This needs a runtime to lock, but is_closed is atomic so - // we can use try_read or spawn. 
- let stream_clone = self.inner.clone(); - pyo3_asyncio::tokio::get_runtime().block_on(async move { - let stream_guard = stream_clone.read().await; - stream_guard.is_closed() - }) + fn is_closed(&self) -> PyResult { + match self.inner.try_read() { + Ok(stream_guard) => Ok(stream_guard.is_closed()), + Err(_) => Err(pyo3::exceptions::PyRuntimeError::new_err( + "Cannot read stream state: lock is held by another operation", + )), + } } /// Get the table name. #[getter] - fn table_name(&self) -> String { - let stream_clone = self.inner.clone(); - pyo3_asyncio::tokio::get_runtime().block_on(async move { - let stream_guard = stream_clone.read().await; - stream_guard.table_name().to_string() - }) + fn table_name(&self) -> PyResult { + match self.inner.try_read() { + Ok(stream_guard) => Ok(stream_guard.table_name().to_string()), + Err(_) => Err(pyo3::exceptions::PyRuntimeError::new_err( + "Cannot read stream state: lock is held by another operation", + )), + } } /// Get unacknowledged batches as a list of Arrow IPC byte buffers. From 1dc512272ccb49c96426549c53804cfc1ec1509c Mon Sep 17 00:00:00 2001 From: teodor-delibasic_data Date: Fri, 13 Mar 2026 14:33:36 +0000 Subject: [PATCH 07/13] Error on multiple batches in ipc_bytes_to_record_batch instead of silently dropping Signed-off-by: teodor-delibasic_data --- python/rust/src/arrow.rs | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/python/rust/src/arrow.rs b/python/rust/src/arrow.rs index 1416bba..8e78bba 100644 --- a/python/rust/src/arrow.rs +++ b/python/rust/src/arrow.rs @@ -23,30 +23,31 @@ fn map_rust_error_to_pyerr(err: RustError) -> PyErr { map_error(err) } -/// Deserialize Arrow IPC bytes into a RecordBatch. +/// Deserialize Arrow IPC bytes into exactly one RecordBatch. 
fn ipc_bytes_to_record_batch( ipc_bytes: &[u8], ) -> Result { - let reader = + let mut reader = arrow_ipc::reader::StreamReader::try_new(ipc_bytes, None).map_err(|e| { RustError::InvalidArgument(format!("Failed to parse Arrow IPC data: {}", e)) })?; - let mut batches = Vec::new(); - for batch_result in reader { - let batch = batch_result.map_err(|e| { + let batch = reader + .next() + .ok_or_else(|| { + RustError::InvalidArgument("No batches found in Arrow IPC data".to_string()) + })? + .map_err(|e| { RustError::InvalidArgument(format!("Failed to read Arrow batch: {}", e)) })?; - batches.push(batch); - } - if batches.is_empty() { + if reader.next().is_some() { return Err(RustError::InvalidArgument( - "No batches found in Arrow IPC data".to_string(), + "Expected exactly one RecordBatch in Arrow IPC data, found multiple".to_string(), )); } - Ok(batches.into_iter().next().unwrap()) + Ok(batch) } /// Serialize a RecordBatch to Arrow IPC bytes. From faab1d3bcc22f47bdcabdf82dba0edf80ba18d4e Mon Sep 17 00:00:00 2001 From: teodor-delibasic_data Date: Fri, 13 Mar 2026 14:35:47 +0000 Subject: [PATCH 08/13] Add .pyi type stubs for Arrow stream classes Signed-off-by: teodor-delibasic_data --- python/zerobus/_zerobus_core.pyi | 106 +++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/python/zerobus/_zerobus_core.pyi b/python/zerobus/_zerobus_core.pyi index 40950cf..0700f73 100644 --- a/python/zerobus/_zerobus_core.pyi +++ b/python/zerobus/_zerobus_core.pyi @@ -508,3 +508,109 @@ class aio: A new ZerobusStream """ ... + +# ============================================================================= +# ARROW (EXPERIMENTAL) +# ============================================================================= + +class arrow: + """Arrow Flight support submodule.""" + + class ArrowStreamConfigurationOptions: + """Configuration options for Arrow Flight streams.""" + + max_inflight_batches: int + """Maximum number of batches in-flight (pending acknowledgment). 
Default: 1000""" + + recovery: bool + """Enable automatic stream recovery on retryable failures. Default: True""" + + recovery_timeout_ms: int + """Timeout per recovery attempt in milliseconds. Default: 15000""" + + recovery_backoff_ms: int + """Backoff between recovery attempts in milliseconds. Default: 2000""" + + recovery_retries: int + """Maximum recovery retry attempts. Default: 4""" + + server_lack_of_ack_timeout_ms: int + """Server acknowledgment timeout in milliseconds. Default: 60000""" + + flush_timeout_ms: int + """Flush timeout in milliseconds. Default: 300000""" + + connection_timeout_ms: int + """Connection establishment timeout in milliseconds. Default: 30000""" + + def __init__( + self, + *, + max_inflight_batches: int = 1000, + recovery: bool = True, + recovery_timeout_ms: int = 15000, + recovery_backoff_ms: int = 2000, + recovery_retries: int = 4, + server_lack_of_ack_timeout_ms: int = 60000, + flush_timeout_ms: int = 300000, + connection_timeout_ms: int = 30000, + ) -> None: ... + def __repr__(self) -> str: ... + + class ZerobusArrowStream: + """Synchronous Arrow Flight stream for ingesting pyarrow RecordBatches.""" + + @property + def is_closed(self) -> bool: ... + + @property + def table_name(self) -> str: ... + + def ingest_batch(self, ipc_bytes: bytes) -> int: + """Ingest a RecordBatch (as IPC bytes) and return the logical offset.""" + ... + + def wait_for_offset(self, offset: int) -> None: + """Wait for the server to acknowledge the batch at the given offset.""" + ... + + def flush(self) -> None: + """Flush all pending batches, waiting for acknowledgment.""" + ... + + def close(self) -> None: + """Close the stream gracefully.""" + ... + + def get_unacked_batches(self) -> List[bytes]: + """Return unacknowledged batches as Arrow IPC bytes.""" + ... + + class AsyncZerobusArrowStream: + """Asynchronous Arrow Flight stream for ingesting pyarrow RecordBatches.""" + + @property + def is_closed(self) -> bool: ... 
+ + @property + def table_name(self) -> str: ... + + async def ingest_batch(self, ipc_bytes: bytes) -> int: + """Ingest a RecordBatch (as IPC bytes) and return the logical offset.""" + ... + + async def wait_for_offset(self, offset: int) -> None: + """Wait for the server to acknowledge the batch at the given offset.""" + ... + + async def flush(self) -> None: + """Flush all pending batches, waiting for acknowledgment.""" + ... + + async def close(self) -> None: + """Close the stream gracefully.""" + ... + + async def get_unacked_batches(self) -> List[bytes]: + """Return unacknowledged batches as Arrow IPC bytes.""" + ... From 296344f764364aab9b563a9f5c36e1a5a187f274 Mon Sep 17 00:00:00 2001 From: teodor-delibasic_data Date: Fri, 13 Mar 2026 14:43:10 +0000 Subject: [PATCH 09/13] Register PyO3 submodules in sys.modules to enable direct imports Signed-off-by: teodor-delibasic_data --- python/rust/src/lib.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/rust/src/lib.rs b/python/rust/src/lib.rs index 46a1acc..b8f2eb2 100644 --- a/python/rust/src/lib.rs +++ b/python/rust/src/lib.rs @@ -44,12 +44,15 @@ fn _zerobus_core(py: Python, m: &PyModule) -> PyResult<()> { // Add authentication classes m.add_class::()?; + let sys_modules = py.import("sys")?.getattr("modules")?; + // Add arrow submodule let arrow_module = PyModule::new(py, "arrow")?; arrow_module.add_class::()?; arrow_module.add_class::()?; arrow_module.add_class::()?; m.add_submodule(arrow_module)?; + sys_modules.set_item("zerobus._zerobus_core.arrow", arrow_module)?; // Add sync submodule let sync_module = PyModule::new(py, "sync")?; @@ -57,6 +60,7 @@ fn _zerobus_core(py: Python, m: &PyModule) -> PyResult<()> { sync_module.add_class::()?; sync_module.add_class::()?; m.add_submodule(sync_module)?; + sys_modules.set_item("zerobus._zerobus_core.sync", sync_module)?; // Add aio (async) submodule let aio_module = PyModule::new(py, "aio")?; @@ -64,6 +68,7 @@ fn _zerobus_core(py: Python, m: &PyModule) -> 
PyResult<()> { aio_module.add_class::()?; aio_module.add_class::()?; m.add_submodule(aio_module)?; + sys_modules.set_item("zerobus._zerobus_core.aio", aio_module)?; Ok(()) } From e6e531aee1642b680545d507fe1dc302f88e402f Mon Sep 17 00:00:00 2001 From: teodor-delibasic_data Date: Fri, 13 Mar 2026 14:46:13 +0000 Subject: [PATCH 10/13] Add tests for empty RecordBatch, negative config values and async API Signed-off-by: teodor-delibasic_data --- python/tests/test_arrow.py | 66 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/python/tests/test_arrow.py b/python/tests/test_arrow.py index 149827a..a34453d 100644 --- a/python/tests/test_arrow.py +++ b/python/tests/test_arrow.py @@ -226,6 +226,72 @@ def test_repr(self): self.assertIn("recovery", repr_str) +class TestSerializeBatchEmptyRecordBatch(unittest.TestCase): + """Test that empty RecordBatch (not Table) is handled correctly.""" + + def test_empty_record_batch_serializes(self): + """An empty RecordBatch (0 rows) should still serialize — only empty Tables are rejected.""" + schema = pa.schema([("a", pa.int64())]) + batch = pa.record_batch({"a": pa.array([], type=pa.int64())}, schema=schema) + # RecordBatch with 0 rows is valid — the check is only on Table + ipc_bytes = _serialize_batch(batch) + self.assertIsInstance(ipc_bytes, bytes) + recovered = _deserialize_batch(ipc_bytes) + self.assertEqual(recovered.num_rows, 0) + + +class TestArrowConfigNegativeValues(unittest.TestCase): + """Test that negative config values are rejected at stream creation time.""" + + def test_negative_max_inflight_batches(self): + options = ArrowStreamConfigurationOptions(max_inflight_batches=-1) + self.assertEqual(options.max_inflight_batches, -1) + # Validation happens in to_rust() at stream creation, not at construction. + # We can't test the Rust-side validation without a server, but we verify + # the value is accepted by the Python constructor and stored. 
+ + def test_negative_recovery_timeout_ms(self): + options = ArrowStreamConfigurationOptions(recovery_timeout_ms=-100) + self.assertEqual(options.recovery_timeout_ms, -100) + + def test_negative_flush_timeout_ms(self): + options = ArrowStreamConfigurationOptions(flush_timeout_ms=-1) + self.assertEqual(options.flush_timeout_ms, -1) + + def test_negative_connection_timeout_ms(self): + options = ArrowStreamConfigurationOptions(connection_timeout_ms=-500) + self.assertEqual(options.connection_timeout_ms, -500) + + +class TestAsyncArrowStreamAPISurface(unittest.TestCase): + """Test that AsyncZerobusArrowStream has the expected async API.""" + + def test_async_arrow_stream_methods_are_coroutines(self): + """Verify async stream methods are actual coroutine functions.""" + from zerobus.sdk.aio import ZerobusArrowStream + + # These methods should be present on the wrapper class + for method_name in ["ingest_batch", "wait_for_offset", "flush", "close", "get_unacked_batches"]: + self.assertTrue( + hasattr(ZerobusArrowStream, method_name), + f"AsyncZerobusArrowStream missing method: {method_name}", + ) + + def test_async_sdk_has_create_and_recreate(self): + """Verify async SDK has arrow stream creation methods.""" + import inspect + + from zerobus.sdk.aio import ZerobusSdk + + self.assertTrue(hasattr(ZerobusSdk, "create_arrow_stream")) + self.assertTrue(hasattr(ZerobusSdk, "recreate_arrow_stream")) + # create_arrow_stream should be a coroutine function + self.assertTrue( + inspect.iscoroutinefunction(ZerobusSdk.create_arrow_stream), + "create_arrow_stream should be async", + ) + + class TestArrowImports(unittest.TestCase): """Test that Arrow types can be imported from expected locations.""" From 6683e65a8d0f945f7da4e7e661cf10d36add9ee7 Mon Sep 17 00:00:00 2001 From: teodor-delibasic_data Date: Fri, 13 Mar 2026 14:47:18 +0000 Subject: [PATCH 11/13] Add perf TODO comment Signed-off-by: teodor-delibasic_data --- python/rust/src/arrow.rs | 3 +++ 1 file changed, 3 insertions(+) 
diff --git a/python/rust/src/arrow.rs b/python/rust/src/arrow.rs index 8e78bba..6308e35 100644 --- a/python/rust/src/arrow.rs +++ b/python/rust/src/arrow.rs @@ -274,6 +274,9 @@ impl ZerobusArrowStream { /// Args: /// ipc_bytes: Arrow IPC serialized bytes from pyarrow.RecordBatch.serialize() fn ingest_batch(&self, py: Python, ipc_bytes: &PyBytes) -> PyResult { + // TODO(perf): eliminate double IPC serialization - Python-to-IPC-to-RecordBatch here, + // then RecordBatch-to-IPC again inside the Rust SDK for Flight. Pass IPC bytes + // directly to the SDK instead. let batch = ipc_bytes_to_record_batch(ipc_bytes.as_bytes()) .map_err(|e| map_rust_error_to_pyerr(e))?; From 0f6ccc1b9bc254d38276b09c2a78a60165b094c5 Mon Sep 17 00:00:00 2001 From: teodor-delibasic_data Date: Fri, 13 Mar 2026 14:48:23 +0000 Subject: [PATCH 12/13] Fix incorrect comment on schema IPC bytes format Signed-off-by: teodor-delibasic_data --- python/rust/src/arrow.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/python/rust/src/arrow.rs b/python/rust/src/arrow.rs index 6308e35..36f6ee7 100644 --- a/python/rust/src/arrow.rs +++ b/python/rust/src/arrow.rs @@ -77,19 +77,18 @@ fn record_batch_to_ipc_bytes( Ok(buffer) } -/// Build an ArrowSchema from IPC-serialized schema bytes. +/// Build an ArrowSchema from Arrow IPC stream bytes (schema-only, no batches). /// -/// Python side calls `schema.serialize().to_pybytes()` on a `pyarrow.Schema` -/// to produce the IPC stream bytes. The schema is the first message in the stream. +/// Python side uses `pa.ipc.new_stream(sink, schema)` followed by `writer.close()` +/// to produce a full IPC stream containing only the schema message. fn ipc_schema_bytes_to_arrow_schema( schema_bytes: &[u8], ) -> Result { let reader = arrow_ipc::reader::StreamReader::try_new(schema_bytes, None) .map_err(|e| { RustError::InvalidArgument(format!( - "Failed to parse Arrow schema bytes: {}. 
\ - Pass schema bytes obtained from pyarrow.Schema via \ - schema.serialize().to_pybytes()", + "Failed to parse Arrow IPC schema bytes: {}. \ + Pass bytes from pa.ipc.new_stream(sink, schema) with no batches written.", e )) })?; From f1b1eb9702ed7ffb1821c04d77f3caf67278a691 Mon Sep 17 00:00:00 2001 From: teodor-delibasic_data Date: Fri, 13 Mar 2026 14:49:53 +0000 Subject: [PATCH 13/13] Remove map_rust_error_to_pyerr wrapper Signed-off-by: teodor-delibasic_data --- python/rust/src/arrow.rs | 52 +++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 28 deletions(-) diff --git a/python/rust/src/arrow.rs b/python/rust/src/arrow.rs index 36f6ee7..99fadf7 100644 --- a/python/rust/src/arrow.rs +++ b/python/rust/src/arrow.rs @@ -19,10 +19,6 @@ use databricks_zerobus_ingest_sdk::{ use crate::auth::HeadersProviderWrapper; use crate::common::map_error; -fn map_rust_error_to_pyerr(err: RustError) -> PyErr { - map_error(err) -} - /// Deserialize Arrow IPC bytes into exactly one RecordBatch. fn ipc_bytes_to_record_batch( ipc_bytes: &[u8], @@ -277,7 +273,7 @@ impl ZerobusArrowStream { // then RecordBatch-to-IPC again inside the Rust SDK for Flight. Pass IPC bytes // directly to the SDK instead. 
let batch = ipc_bytes_to_record_batch(ipc_bytes.as_bytes()) - .map_err(|e| map_rust_error_to_pyerr(e))?; + .map_err(|e| map_error(e))?; let stream_clone = self.inner.clone(); let runtime = self.runtime.clone(); @@ -288,7 +284,7 @@ impl ZerobusArrowStream { stream_guard .ingest_batch(batch) .await - .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e))) + .map_err(|e| Python::with_gil(|_py| map_error(e))) }) }) } @@ -304,7 +300,7 @@ impl ZerobusArrowStream { stream_guard .wait_for_offset(offset) .await - .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e))) + .map_err(|e| Python::with_gil(|_py| map_error(e))) }) }) } @@ -320,7 +316,7 @@ impl ZerobusArrowStream { stream_guard .flush() .await - .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e))) + .map_err(|e| Python::with_gil(|_py| map_error(e))) }) }) } @@ -336,7 +332,7 @@ impl ZerobusArrowStream { stream_guard .close() .await - .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e))) + .map_err(|e| Python::with_gil(|_py| map_error(e))) }) }) } @@ -378,13 +374,13 @@ impl ZerobusArrowStream { let batches = stream_guard .get_unacked_batches() .await - .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e)))?; + .map_err(|e| Python::with_gil(|_py| map_error(e)))?; Python::with_gil(|py| { let mut py_batches: Vec = Vec::with_capacity(batches.len()); for batch in &batches { let ipc_bytes = record_batch_to_ipc_bytes(batch) - .map_err(|e| map_rust_error_to_pyerr(e))?; + .map_err(|e| map_error(e))?; py_batches .push(PyBytes::new(py, &ipc_bytes).into()); } @@ -414,7 +410,7 @@ impl AsyncZerobusArrowStream { ipc_bytes: &PyBytes, ) -> PyResult<&'py PyAny> { let batch = ipc_bytes_to_record_batch(ipc_bytes.as_bytes()) - .map_err(|e| map_rust_error_to_pyerr(e))?; + .map_err(|e| map_error(e))?; let stream_clone = self.inner.clone(); @@ -423,7 +419,7 @@ impl AsyncZerobusArrowStream { stream_guard .ingest_batch(batch) .await - .map_err(|e| Python::with_gil(|_py| 
map_rust_error_to_pyerr(e))) + .map_err(|e| Python::with_gil(|_py| map_error(e))) }) } @@ -440,7 +436,7 @@ impl AsyncZerobusArrowStream { stream_guard .wait_for_offset(offset) .await - .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e)))?; + .map_err(|e| Python::with_gil(|_py| map_error(e)))?; Ok(()) }) } @@ -454,7 +450,7 @@ impl AsyncZerobusArrowStream { stream_guard .flush() .await - .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e)))?; + .map_err(|e| Python::with_gil(|_py| map_error(e)))?; Ok(()) }) } @@ -468,7 +464,7 @@ impl AsyncZerobusArrowStream { stream_guard .close() .await - .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e)))?; + .map_err(|e| Python::with_gil(|_py| map_error(e)))?; Ok(()) }) } @@ -504,13 +500,13 @@ impl AsyncZerobusArrowStream { let batches = stream_guard .get_unacked_batches() .await - .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e)))?; + .map_err(|e| Python::with_gil(|_py| map_error(e)))?; Python::with_gil(|py| { let mut py_batches: Vec = Vec::with_capacity(batches.len()); for batch in &batches { let ipc_bytes = record_batch_to_ipc_bytes(batch) - .map_err(|e| map_rust_error_to_pyerr(e))?; + .map_err(|e| map_error(e))?; py_batches.push(PyBytes::new(py, &ipc_bytes).into()); } Ok(py_batches) @@ -535,7 +531,7 @@ pub fn create_arrow_stream_sync( options: Option<&ArrowStreamConfigurationOptions>, ) -> PyResult { let schema = ipc_schema_bytes_to_arrow_schema(schema_ipc_bytes) - .map_err(|e| map_rust_error_to_pyerr(e))?; + .map_err(|e| map_error(e))?; let table_props = RustArrowTableProperties { table_name, @@ -553,7 +549,7 @@ pub fn create_arrow_stream_sync( sdk_guard .create_arrow_stream(table_props, client_id, client_secret, rust_options) .await - .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e))) + .map_err(|e| Python::with_gil(|_py| map_error(e))) }) })?; @@ -574,7 +570,7 @@ pub fn create_arrow_stream_with_headers_provider_sync( options: 
Option<&ArrowStreamConfigurationOptions>, ) -> PyResult { let schema = ipc_schema_bytes_to_arrow_schema(schema_ipc_bytes) - .map_err(|e| map_rust_error_to_pyerr(e))?; + .map_err(|e| map_error(e))?; let table_props = RustArrowTableProperties { table_name, @@ -597,7 +593,7 @@ pub fn create_arrow_stream_with_headers_provider_sync( rust_options, ) .await - .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e))) + .map_err(|e| Python::with_gil(|_py| map_error(e))) }) })?; @@ -625,7 +621,7 @@ pub fn recreate_arrow_stream_sync( sdk_guard .recreate_arrow_stream(&*old_guard) .await - .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e))) + .map_err(|e| Python::with_gil(|_py| map_error(e))) }) })?; @@ -646,7 +642,7 @@ pub fn create_arrow_stream_async<'py>( options: Option, ) -> PyResult<&'py PyAny> { let schema = ipc_schema_bytes_to_arrow_schema(&schema_ipc_bytes) - .map_err(|e| map_rust_error_to_pyerr(e))?; + .map_err(|e| map_error(e))?; let table_props = RustArrowTableProperties { table_name, @@ -661,7 +657,7 @@ pub fn create_arrow_stream_async<'py>( let stream = sdk_guard .create_arrow_stream(table_props, client_id, client_secret, rust_options) .await - .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e)))?; + .map_err(|e| Python::with_gil(|_py| map_error(e)))?; Ok(AsyncZerobusArrowStream { inner: Arc::new(RwLock::new(stream)), @@ -679,7 +675,7 @@ pub fn create_arrow_stream_with_headers_provider_async<'py>( options: Option, ) -> PyResult<&'py PyAny> { let schema = ipc_schema_bytes_to_arrow_schema(&schema_ipc_bytes) - .map_err(|e| map_rust_error_to_pyerr(e))?; + .map_err(|e| map_error(e))?; let table_props = RustArrowTableProperties { table_name, @@ -699,7 +695,7 @@ pub fn create_arrow_stream_with_headers_provider_async<'py>( rust_options, ) .await - .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e)))?; + .map_err(|e| Python::with_gil(|_py| map_error(e)))?; Ok(AsyncZerobusArrowStream { inner: Arc::new(RwLock::new(stream)), @@ 
-722,7 +718,7 @@ pub fn recreate_arrow_stream_async<'py>( let stream = sdk_guard .recreate_arrow_stream(&*old_guard) .await - .map_err(|e| Python::with_gil(|_py| map_rust_error_to_pyerr(e)))?; + .map_err(|e| Python::with_gil(|_py| map_error(e)))?; Ok(AsyncZerobusArrowStream { inner: Arc::new(RwLock::new(stream)),