diff --git a/.gitignore b/.gitignore index 60199ff..e31a805 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,9 @@ .bsp/ .bazelbsp/ +# Tools +pyrefly.toml + # OS .DS_Store Thumbs.db diff --git a/python/NEXT_CHANGELOG.md b/python/NEXT_CHANGELOG.md index d5e2f40..d735d2b 100644 --- a/python/NEXT_CHANGELOG.md +++ b/python/NEXT_CHANGELOG.md @@ -6,15 +6,35 @@ ### New Features and Improvements +- **Arrow Flight Support (Experimental)**: Added support for ingesting `pyarrow.RecordBatch` and `pyarrow.Table` objects via Arrow Flight protocol + - **Note**: Arrow Flight is not yet supported by default from the Zerobus server side. + - New `ZerobusArrowStream` class (sync in `zerobus.sdk.sync`, async in `zerobus.sdk.aio`) with `ingest_batch()`, `wait_for_offset()`, `flush()`, `close()`, `get_unacked_batches()` methods + - New `ArrowStreamConfigurationOptions` for configuring Arrow streams (max inflight batches, recovery, timeouts) + - New `create_arrow_stream()` and `recreate_arrow_stream()` methods on both sync and async `ZerobusSdk` + - Accepts both `pyarrow.RecordBatch` and `pyarrow.Table` (Tables are combined to a single batch internally) + - Arrow is opt-in: install via `pip install databricks-zerobus-ingest-sdk[arrow]` (requires `pyarrow>=14.0.0`) + - Arrow types gated behind `_core.arrow` submodule — not loaded unless pyarrow is installed + - Available from both `zerobus.sdk.sync` and `zerobus.sdk.aio`, and re-exported from top-level `zerobus` package + ### Bug Fixes ### Documentation ### Internal Changes +- Bumped Rust SDK dependency to v1.0.1 with `arrow-flight` feature +- Added `arrow-ipc`, `arrow-schema`, `arrow-array` (v56.2.0) Rust dependencies for IPC serialization +- Added PyO3 arrow module (`arrow.rs`) with `ArrowStreamConfigurationOptions`, `ZerobusArrowStream`, `AsyncZerobusArrowStream` pyclasses +- Added Python-side serialization helpers in `zerobus.sdk.shared.arrow` (`_serialize_schema`, `_serialize_batch`, `_deserialize_batch`) + ### Breaking Changes 
### Deprecations ### API Changes +- Added `create_arrow_stream(table_name, schema, client_id, client_secret, options=None, headers_provider=None)` to sync and async `ZerobusSdk` +- Added `recreate_arrow_stream(old_stream)` to sync and async `ZerobusSdk` +- Added `ZerobusArrowStream` class (sync and async variants) with methods: `ingest_batch()`, `wait_for_offset()`, `flush()`, `close()`, `get_unacked_batches()`, properties: `is_closed`, `table_name` +- Added `ArrowStreamConfigurationOptions` class with fields: `max_inflight_batches`, `recovery`, `recovery_timeout_ms`, `recovery_backoff_ms`, `recovery_retries`, `server_lack_of_ack_timeout_ms`, `flush_timeout_ms`, `connection_timeout_ms` +- Added optional dependency: `pyarrow>=14.0.0` via `pip install databricks-zerobus-ingest-sdk[arrow]` diff --git a/python/examples/async_example_arrow.py b/python/examples/async_example_arrow.py new file mode 100644 index 0000000..08d07cd --- /dev/null +++ b/python/examples/async_example_arrow.py @@ -0,0 +1,200 @@ +""" +Asynchronous Ingestion Example - Arrow Flight Mode + +This example demonstrates record ingestion using the asynchronous API with Arrow Flight. + +Record Type Mode: Arrow (RecordBatch) + - Records are sent as pyarrow RecordBatches + - Uses Arrow Flight protocol for columnar data transfer + - Best for structured/columnar data, DataFrames, Parquet workflows + +Requirements: + pip install databricks-zerobus-ingest-sdk[arrow] + +Note: Arrow Flight support is experimental and not yet supported for production use. 
+""" + +import asyncio +import logging +import os +import time + +import pyarrow as pa + +from zerobus.sdk.aio import ZerobusSdk +from zerobus.sdk.shared.arrow import ArrowStreamConfigurationOptions + +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +# Configuration - update these with your values +SERVER_ENDPOINT = os.getenv("ZEROBUS_SERVER_ENDPOINT", "https://your-shard-id.zerobus.region.cloud.databricks.com") +UNITY_CATALOG_ENDPOINT = os.getenv("DATABRICKS_WORKSPACE_URL", "https://your-workspace.cloud.databricks.com") +TABLE_NAME = os.getenv("ZEROBUS_TABLE_NAME", "catalog.schema.table") + +# For OAuth authentication +CLIENT_ID = os.getenv("DATABRICKS_CLIENT_ID", "your-oauth-client-id") +CLIENT_SECRET = os.getenv("DATABRICKS_CLIENT_SECRET", "your-oauth-client-secret") + +# Number of batches to ingest +NUM_BATCHES = 10 +ROWS_PER_BATCH = 100 + +# Define the Arrow schema +SCHEMA = pa.schema( + [ + ("device_name", pa.large_utf8()), + ("temp", pa.int32()), + ("humidity", pa.int64()), + ] +) + + +def create_sample_batch(batch_index): + """ + Creates a sample RecordBatch with air quality data. + + Returns a pyarrow.RecordBatch with ROWS_PER_BATCH rows. 
+ """ + return pa.record_batch( + { + "device_name": [f"sensor-{(batch_index * ROWS_PER_BATCH + i) % 10}" for i in range(ROWS_PER_BATCH)], + "temp": [20 + ((batch_index * ROWS_PER_BATCH + i) % 15) for i in range(ROWS_PER_BATCH)], + "humidity": [50 + ((batch_index * ROWS_PER_BATCH + i) % 40) for i in range(ROWS_PER_BATCH)], + }, + schema=SCHEMA, + ) + + +async def main(): + print("Starting asynchronous ingestion example (Arrow Flight Mode)...") + print("=" * 60) + + # Check if credentials are configured + if CLIENT_ID == "your-oauth-client-id" or CLIENT_SECRET == "your-oauth-client-secret": + logger.error("Please set DATABRICKS_CLIENT_ID and DATABRICKS_CLIENT_SECRET environment variables") + return + + if SERVER_ENDPOINT == "https://your-shard-id.zerobus.region.cloud.databricks.com": + logger.error("Please set ZEROBUS_SERVER_ENDPOINT environment variable") + return + + if TABLE_NAME == "catalog.schema.table": + logger.error("Please set ZEROBUS_TABLE_NAME environment variable") + return + + try: + # Step 1: Initialize the SDK + sdk = ZerobusSdk(SERVER_ENDPOINT, UNITY_CATALOG_ENDPOINT) + logger.info("SDK initialized") + + # Step 2: Configure arrow stream options (all optional, shown with defaults) + options = ArrowStreamConfigurationOptions( + max_inflight_batches=10, + recovery=True, + recovery_timeout_ms=15000, + recovery_backoff_ms=2000, + recovery_retries=3, + ) + logger.info("Arrow stream configuration created") + + # Step 3: Create an Arrow Flight stream + # + # Pass a pyarrow.Schema - the SDK handles serialization internally. 
+ # The SDK automatically: + # - Includes authorization header with OAuth token + # - Includes x-databricks-zerobus-table-name header + stream = await sdk.create_arrow_stream(TABLE_NAME, SCHEMA, CLIENT_ID, CLIENT_SECRET, options) + logger.info(f"Arrow stream created for table: {stream.table_name}") + + # Step 4: Ingest Arrow RecordBatches asynchronously + logger.info(f"\nIngesting {NUM_BATCHES} batches of {ROWS_PER_BATCH} rows each...") + start_time = time.time() + total_rows = 0 + + try: + # ======================================================================== + # Ingest RecordBatches - each call returns an offset + # ======================================================================== + offsets = [] + for i in range(NUM_BATCHES): + batch = create_sample_batch(i) + offset = await stream.ingest_batch(batch) + offsets.append(offset) + total_rows += batch.num_rows + logger.info(f" Batch {i + 1}: {batch.num_rows} rows, offset: {offset}") + + # ======================================================================== + # You can also ingest a pyarrow.Table directly + # The SDK converts it to a single RecordBatch internally + # ======================================================================== + table = pa.table( + { + "device_name": [f"sensor-table-{i}" for i in range(50)], + "temp": list(range(20, 70)), + "humidity": list(range(50, 100)), + }, + schema=SCHEMA, + ) + offset = await stream.ingest_batch(table) + offsets.append(offset) + total_rows += table.num_rows + logger.info(f" Table ingested: {table.num_rows} rows, offset: {offset}") + + submit_end_time = time.time() + submit_duration = submit_end_time - start_time + logger.info(f"\nAll batches submitted in {submit_duration:.2f} seconds") + + # ======================================================================== + # Wait for the last offset to be acknowledged + # ======================================================================== + logger.info(f"Waiting for offset {offsets[-1]} to be 
acknowledged...") + await stream.wait_for_offset(offsets[-1]) + logger.info(f" Offset {offsets[-1]} acknowledged") + + # Step 5: Flush and close the stream + logger.info("\nFlushing stream...") + await stream.flush() + logger.info("Stream flushed") + + end_time = time.time() + total_duration = end_time - start_time + rows_per_second = total_rows / total_duration + + await stream.close() + logger.info("Stream closed") + + # Print summary + print("\n" + "=" * 60) + print("Ingestion Summary:") + print(f" Total batches: {NUM_BATCHES + 1}") + print(f" Total rows: {total_rows}") + print(f" Submit time: {submit_duration:.2f} seconds") + print(f" Total time: {total_duration:.2f} seconds") + print(f" Throughput: {rows_per_second:.2f} rows/sec") + print(f" Record type: Arrow Flight (RecordBatch)") + print("=" * 60) + + except Exception as e: + logger.error(f"\nError during ingestion: {e}") + + # On failure, you can retrieve unacked batches for retry + if stream.is_closed: + unacked = await stream.get_unacked_batches() + if unacked: + logger.info(f" {len(unacked)} unacked batches available for retry") + for i, batch in enumerate(unacked): + logger.info(f" Batch {i}: {batch.num_rows} rows, schema: {batch.schema}") + + await stream.close() + raise + + except Exception as e: + logger.error(f"\nFailed to initialize stream: {e}") + raise + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/examples/sync_example_arrow.py b/python/examples/sync_example_arrow.py new file mode 100644 index 0000000..c3ea465 --- /dev/null +++ b/python/examples/sync_example_arrow.py @@ -0,0 +1,191 @@ +""" +Synchronous Ingestion Example - Arrow Flight Mode + +This example demonstrates record ingestion using the synchronous API with Arrow Flight. 
+ +Record Type Mode: Arrow (RecordBatch) + - Records are sent as pyarrow RecordBatches + - Uses Arrow Flight protocol for columnar data transfer + - Best for structured/columnar data, DataFrames, Parquet workflows + +Requirements: + pip install databricks-zerobus-ingest-sdk[arrow] + +Note: Arrow Flight support is experimental and not yet supported for production use. +""" + +import logging +import os +import time + +import pyarrow as pa + +from zerobus.sdk.shared.arrow import ArrowStreamConfigurationOptions +from zerobus.sdk.sync import ZerobusSdk + +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +# Configuration - update these with your values +SERVER_ENDPOINT = os.getenv("ZEROBUS_SERVER_ENDPOINT", "https://your-shard-id.zerobus.region.cloud.databricks.com") +UNITY_CATALOG_ENDPOINT = os.getenv("DATABRICKS_WORKSPACE_URL", "https://your-workspace.cloud.databricks.com") +TABLE_NAME = os.getenv("ZEROBUS_TABLE_NAME", "catalog.schema.table") + +# For OAuth authentication +CLIENT_ID = os.getenv("DATABRICKS_CLIENT_ID", "your-oauth-client-id") +CLIENT_SECRET = os.getenv("DATABRICKS_CLIENT_SECRET", "your-oauth-client-secret") + +# Number of batches to ingest +NUM_BATCHES = 10 +ROWS_PER_BATCH = 100 + +# Define the Arrow schema +SCHEMA = pa.schema( + [ + ("device_name", pa.large_utf8()), + ("temp", pa.int32()), + ("humidity", pa.int64()), + ] +) + + +def create_sample_batch(batch_index): + """ + Creates a sample RecordBatch with air quality data. + + Returns a pyarrow.RecordBatch with ROWS_PER_BATCH rows. 
+ """ + return pa.record_batch( + { + "device_name": [f"sensor-{(batch_index * ROWS_PER_BATCH + i) % 10}" for i in range(ROWS_PER_BATCH)], + "temp": [20 + ((batch_index * ROWS_PER_BATCH + i) % 15) for i in range(ROWS_PER_BATCH)], + "humidity": [50 + ((batch_index * ROWS_PER_BATCH + i) % 40) for i in range(ROWS_PER_BATCH)], + }, + schema=SCHEMA, + ) + + +def main(): + print("Starting synchronous ingestion example (Arrow Flight Mode)...") + print("=" * 60) + + # Check if credentials are configured + if CLIENT_ID == "your-oauth-client-id" or CLIENT_SECRET == "your-oauth-client-secret": + logger.error("Please set DATABRICKS_CLIENT_ID and DATABRICKS_CLIENT_SECRET environment variables") + return + + if SERVER_ENDPOINT == "https://your-shard-id.zerobus.region.cloud.databricks.com": + logger.error("Please set ZEROBUS_SERVER_ENDPOINT environment variable") + return + + if TABLE_NAME == "catalog.schema.table": + logger.error("Please set ZEROBUS_TABLE_NAME environment variable") + return + + try: + # Step 1: Initialize the SDK + sdk = ZerobusSdk(SERVER_ENDPOINT, UNITY_CATALOG_ENDPOINT) + logger.info("SDK initialized") + + # Step 2: Configure arrow stream options (all optional, shown with defaults) + options = ArrowStreamConfigurationOptions( + max_inflight_batches=10, + recovery=True, + recovery_timeout_ms=15000, + recovery_backoff_ms=2000, + recovery_retries=3, + ) + logger.info("Arrow stream configuration created") + + # Step 3: Create an Arrow Flight stream + # + # Pass a pyarrow.Schema - the SDK handles serialization internally. 
+ # The SDK automatically: + # - Includes authorization header with OAuth token + # - Includes x-databricks-zerobus-table-name header + stream = sdk.create_arrow_stream(TABLE_NAME, SCHEMA, CLIENT_ID, CLIENT_SECRET, options) + logger.info(f"Arrow stream created for table: {stream.table_name}") + + # Step 4: Ingest Arrow RecordBatches + logger.info(f"\nIngesting {NUM_BATCHES} batches of {ROWS_PER_BATCH} rows each...") + start_time = time.time() + total_rows = 0 + + try: + # ======================================================================== + # Ingest RecordBatches - each call returns an offset + # ======================================================================== + for i in range(NUM_BATCHES): + batch = create_sample_batch(i) + offset = stream.ingest_batch(batch) + total_rows += batch.num_rows + logger.info(f" Batch {i + 1}: {batch.num_rows} rows, offset: {offset}") + + # ======================================================================== + # You can also ingest a pyarrow.Table directly + # The SDK converts it to a single RecordBatch internally + # ======================================================================== + table = pa.table( + { + "device_name": [f"sensor-table-{i}" for i in range(50)], + "temp": list(range(20, 70)), + "humidity": list(range(50, 100)), + }, + schema=SCHEMA, + ) + offset = stream.ingest_batch(table) + total_rows += table.num_rows + logger.info(f" Table ingested: {table.num_rows} rows, offset: {offset}") + + # ======================================================================== + # Wait for a specific offset to be acknowledged + # ======================================================================== + logger.info(f"\nWaiting for offset {offset} to be acknowledged...") + stream.wait_for_offset(offset) + logger.info(f" Offset {offset} acknowledged") + + end_time = time.time() + duration_seconds = end_time - start_time + rows_per_second = total_rows / duration_seconds + + # Step 5: Flush and close the stream + 
logger.info("\nFlushing stream...") + stream.flush() + logger.info("Stream flushed") + + stream.close() + logger.info("Stream closed") + + # Print summary + print("\n" + "=" * 60) + print("Ingestion Summary:") + print(f" Total batches: {NUM_BATCHES + 1}") + print(f" Total rows: {total_rows}") + print(f" Duration: {duration_seconds:.2f} seconds") + print(f" Throughput: {rows_per_second:.2f} rows/sec") + print(f" Record type: Arrow Flight (RecordBatch)") + print("=" * 60) + + except Exception as e: + logger.error(f"\nError during ingestion: {e}") + + # On failure, you can retrieve unacked batches for retry + if stream.is_closed: + unacked = stream.get_unacked_batches() + if unacked: + logger.info(f" {len(unacked)} unacked batches available for retry") + for i, batch in enumerate(unacked): + logger.info(f" Batch {i}: {batch.num_rows} rows, schema: {batch.schema}") + + stream.close() + raise + + except Exception as e: + logger.error(f"\nFailed to initialize stream: {e}") + raise + + +if __name__ == "__main__": + main() diff --git a/python/pyproject.toml b/python/pyproject.toml index 4dac31d..cc547a1 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -28,6 +28,9 @@ dependencies = [ ] [project.optional-dependencies] +arrow = [ + "pyarrow>=14.0.0", +] dev = [ # Build tools "wheel", @@ -43,6 +46,7 @@ dev = [ "pytest-cov", "pytest-xdist>=3.6.1,<4.0", "pytest-asyncio", + "pyarrow>=14.0.0", ] [tool.maturin] diff --git a/python/rust/Cargo.toml b/python/rust/Cargo.toml index 9594fa6..871a79f 100644 --- a/python/rust/Cargo.toml +++ b/python/rust/Cargo.toml @@ -11,9 +11,12 @@ crate-type = ["cdylib"] pyo3 = { version = "0.20", features = ["extension-module", "abi3-py39"] } pyo3-asyncio = { version = "0.20", features = ["tokio-runtime"] } tokio = { version = "1", features = ["full"] } -databricks-zerobus-ingest-sdk = "1.0.0" +databricks-zerobus-ingest-sdk = { version = "1.0.1", features = ["arrow-flight"] } tracing-subscriber = { version = "0.3", features = 
["env-filter"] } async-trait = "0.1" prost = "0.13" prost-types = "0.13" tonic = "0.13" +arrow-ipc = { version = "56.2.0", default-features = false } +arrow-schema = { version = "56.2.0", default-features = false } +arrow-array = { version = "56.2.0", default-features = false } diff --git a/python/rust/src/arrow.rs b/python/rust/src/arrow.rs new file mode 100644 index 0000000..99fadf7 --- /dev/null +++ b/python/rust/src/arrow.rs @@ -0,0 +1,727 @@ +//! PyO3 bindings for Arrow Flight stream support. +//! +//! These types are always compiled into the wheel, but the Python-side API +//! gates usage on `pyarrow` being installed at runtime. + +use std::sync::Arc; + +use pyo3::prelude::*; +use pyo3::types::PyBytes; +use tokio::sync::RwLock; + +use databricks_zerobus_ingest_sdk::{ + ArrowStreamConfigurationOptions as RustArrowStreamOptions, + ArrowTableProperties as RustArrowTableProperties, + ZerobusArrowStream as RustZerobusArrowStream, ZerobusError as RustError, + ZerobusSdk as RustSdk, +}; + +use crate::auth::HeadersProviderWrapper; +use crate::common::map_error; + +/// Deserialize Arrow IPC bytes into exactly one RecordBatch. +fn ipc_bytes_to_record_batch( + ipc_bytes: &[u8], +) -> Result { + let mut reader = + arrow_ipc::reader::StreamReader::try_new(ipc_bytes, None).map_err(|e| { + RustError::InvalidArgument(format!("Failed to parse Arrow IPC data: {}", e)) + })?; + + let batch = reader + .next() + .ok_or_else(|| { + RustError::InvalidArgument("No batches found in Arrow IPC data".to_string()) + })? + .map_err(|e| { + RustError::InvalidArgument(format!("Failed to read Arrow batch: {}", e)) + })?; + + if reader.next().is_some() { + return Err(RustError::InvalidArgument( + "Expected exactly one RecordBatch in Arrow IPC data, found multiple".to_string(), + )); + } + + Ok(batch) +} + +/// Serialize a RecordBatch to Arrow IPC bytes. 
+fn record_batch_to_ipc_bytes( + batch: &arrow_array::RecordBatch, +) -> Result, RustError> { + let mut buffer = Vec::new(); + { + let mut writer = + arrow_ipc::writer::StreamWriter::try_new(&mut buffer, &batch.schema()) + .map_err(|e| { + RustError::InvalidArgument(format!( + "Failed to create Arrow IPC writer: {}", + e + )) + })?; + writer.write(batch).map_err(|e| { + RustError::InvalidArgument(format!("Failed to write Arrow batch: {}", e)) + })?; + writer.finish().map_err(|e| { + RustError::InvalidArgument(format!( + "Failed to finish Arrow IPC stream: {}", + e + )) + })?; + } + Ok(buffer) +} + +/// Build an ArrowSchema from Arrow IPC stream bytes (schema-only, no batches). +/// +/// Python side uses `pa.ipc.new_stream(sink, schema)` followed by `writer.close()` +/// to produce a full IPC stream containing only the schema message. +fn ipc_schema_bytes_to_arrow_schema( + schema_bytes: &[u8], +) -> Result { + let reader = arrow_ipc::reader::StreamReader::try_new(schema_bytes, None) + .map_err(|e| { + RustError::InvalidArgument(format!( + "Failed to parse Arrow IPC schema bytes: {}. \ + Pass bytes from pa.ipc.new_stream(sink, schema) with no batches written.", + e + )) + })?; + Ok(reader.schema().as_ref().clone()) +} + +// ============================================================================= +// ARROW STREAM CONFIGURATION OPTIONS +// ============================================================================= + +/// Configuration options for Arrow Flight streams. 
+#[pyclass] +#[derive(Clone)] +pub struct ArrowStreamConfigurationOptions { + #[pyo3(get, set)] + pub max_inflight_batches: i32, + + #[pyo3(get, set)] + pub recovery: bool, + + #[pyo3(get, set)] + pub recovery_timeout_ms: i64, + + #[pyo3(get, set)] + pub recovery_backoff_ms: i64, + + #[pyo3(get, set)] + pub recovery_retries: i32, + + #[pyo3(get, set)] + pub server_lack_of_ack_timeout_ms: i64, + + #[pyo3(get, set)] + pub flush_timeout_ms: i64, + + #[pyo3(get, set)] + pub connection_timeout_ms: i64, +} + +impl Default for ArrowStreamConfigurationOptions { + fn default() -> Self { + let rust_default = RustArrowStreamOptions::default(); + Self { + max_inflight_batches: rust_default.max_inflight_batches as i32, + recovery: rust_default.recovery, + recovery_timeout_ms: rust_default.recovery_timeout_ms as i64, + recovery_backoff_ms: rust_default.recovery_backoff_ms as i64, + recovery_retries: rust_default.recovery_retries as i32, + server_lack_of_ack_timeout_ms: rust_default.server_lack_of_ack_timeout_ms + as i64, + flush_timeout_ms: rust_default.flush_timeout_ms as i64, + connection_timeout_ms: rust_default.connection_timeout_ms as i64, + } + } +} + +#[pymethods] +impl ArrowStreamConfigurationOptions { + #[new] + #[pyo3(signature = (**kwargs))] + fn new(kwargs: Option<&pyo3::types::PyDict>) -> PyResult { + let mut options = Self::default(); + + if let Some(kwargs) = kwargs { + for (key, value) in kwargs { + let key_str: &str = key.extract()?; + match key_str { + "max_inflight_batches" => { + options.max_inflight_batches = value.extract()? + } + "recovery" => options.recovery = value.extract()?, + "recovery_timeout_ms" => { + options.recovery_timeout_ms = value.extract()? + } + "recovery_backoff_ms" => { + options.recovery_backoff_ms = value.extract()? + } + "recovery_retries" => options.recovery_retries = value.extract()?, + "server_lack_of_ack_timeout_ms" => { + options.server_lack_of_ack_timeout_ms = value.extract()? 
+ } + "flush_timeout_ms" => options.flush_timeout_ms = value.extract()?, + "connection_timeout_ms" => { + options.connection_timeout_ms = value.extract()? + } + _ => { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "Unknown configuration option: {}", + key_str + ))); + } + } + } + } + + Ok(options) + } + + fn __repr__(&self) -> String { + format!( + "ArrowStreamConfigurationOptions(max_inflight_batches={}, recovery={}, \ + recovery_timeout_ms={}, recovery_backoff_ms={}, recovery_retries={}, \ + server_lack_of_ack_timeout_ms={}, flush_timeout_ms={}, connection_timeout_ms={})", + self.max_inflight_batches, + self.recovery, + self.recovery_timeout_ms, + self.recovery_backoff_ms, + self.recovery_retries, + self.server_lack_of_ack_timeout_ms, + self.flush_timeout_ms, + self.connection_timeout_ms, + ) + } +} + +impl ArrowStreamConfigurationOptions { + pub fn to_rust(&self) -> Result { + if self.max_inflight_batches < 0 { + return Err(pyo3::exceptions::PyValueError::new_err( + "max_inflight_batches must be non-negative", + )); + } + if self.recovery_timeout_ms < 0 { + return Err(pyo3::exceptions::PyValueError::new_err( + "recovery_timeout_ms must be non-negative", + )); + } + if self.recovery_backoff_ms < 0 { + return Err(pyo3::exceptions::PyValueError::new_err( + "recovery_backoff_ms must be non-negative", + )); + } + if self.recovery_retries < 0 { + return Err(pyo3::exceptions::PyValueError::new_err( + "recovery_retries must be non-negative", + )); + } + if self.server_lack_of_ack_timeout_ms < 0 { + return Err(pyo3::exceptions::PyValueError::new_err( + "server_lack_of_ack_timeout_ms must be non-negative", + )); + } + if self.flush_timeout_ms < 0 { + return Err(pyo3::exceptions::PyValueError::new_err( + "flush_timeout_ms must be non-negative", + )); + } + if self.connection_timeout_ms < 0 { + return Err(pyo3::exceptions::PyValueError::new_err( + "connection_timeout_ms must be non-negative", + )); + } + Ok(RustArrowStreamOptions { + max_inflight_batches: 
self.max_inflight_batches as usize, + recovery: self.recovery, + recovery_timeout_ms: self.recovery_timeout_ms as u64, + recovery_backoff_ms: self.recovery_backoff_ms as u64, + recovery_retries: self.recovery_retries as u32, + server_lack_of_ack_timeout_ms: self.server_lack_of_ack_timeout_ms as u64, + flush_timeout_ms: self.flush_timeout_ms as u64, + connection_timeout_ms: self.connection_timeout_ms as u64, + ipc_compression: None, + }) + } +} + +// ============================================================================= +// SYNC ARROW STREAM +// ============================================================================= + +/// Synchronous Arrow Flight stream for ingesting pyarrow RecordBatches. +#[pyclass] +pub struct ZerobusArrowStream { + inner: Arc>, + runtime: Arc, +} + +#[pymethods] +impl ZerobusArrowStream { + /// Ingest a single Arrow RecordBatch (as IPC bytes) and return the offset. + /// + /// Args: + /// ipc_bytes: Arrow IPC serialized bytes from pyarrow.RecordBatch.serialize() + fn ingest_batch(&self, py: Python, ipc_bytes: &PyBytes) -> PyResult { + // TODO(perf): eliminate double IPC serialization - Python-to-IPC-to-RecordBatch here, + // then RecordBatch-to-IPC again inside the Rust SDK for Flight. Pass IPC bytes + // directly to the SDK instead. + let batch = ipc_bytes_to_record_batch(ipc_bytes.as_bytes()) + .map_err(|e| map_error(e))?; + + let stream_clone = self.inner.clone(); + let runtime = self.runtime.clone(); + + py.allow_threads(|| { + runtime.block_on(async move { + let stream_guard = stream_clone.read().await; + stream_guard + .ingest_batch(batch) + .await + .map_err(|e| Python::with_gil(|_py| map_error(e))) + }) + }) + } + + /// Wait for a specific offset to be acknowledged. 
+ fn wait_for_offset(&self, py: Python, offset: i64) -> PyResult<()> { + let stream_clone = self.inner.clone(); + let runtime = self.runtime.clone(); + + py.allow_threads(|| { + runtime.block_on(async move { + let stream_guard = stream_clone.read().await; + stream_guard + .wait_for_offset(offset) + .await + .map_err(|e| Python::with_gil(|_py| map_error(e))) + }) + }) + } + + /// Flush all pending batches, waiting for acknowledgment. + fn flush(&self, py: Python) -> PyResult<()> { + let stream_clone = self.inner.clone(); + let runtime = self.runtime.clone(); + + py.allow_threads(|| { + runtime.block_on(async move { + let stream_guard = stream_clone.read().await; + stream_guard + .flush() + .await + .map_err(|e| Python::with_gil(|_py| map_error(e))) + }) + }) + } + + /// Close the stream gracefully. + fn close(&self, py: Python) -> PyResult<()> { + let stream_clone = self.inner.clone(); + let runtime = self.runtime.clone(); + + py.allow_threads(|| { + runtime.block_on(async move { + let mut stream_guard = stream_clone.write().await; + stream_guard + .close() + .await + .map_err(|e| Python::with_gil(|_py| map_error(e))) + }) + }) + } + + /// Check if the stream has been closed. + #[getter] + fn is_closed(&self, py: Python) -> bool { + let stream_clone = self.inner.clone(); + let runtime = self.runtime.clone(); + py.allow_threads(|| { + runtime.block_on(async move { + let stream_guard = stream_clone.read().await; + stream_guard.is_closed() + }) + }) + } + + /// Get the table name. + #[getter] + fn table_name(&self, py: Python) -> String { + let stream_clone = self.inner.clone(); + let runtime = self.runtime.clone(); + py.allow_threads(|| { + runtime.block_on(async move { + let stream_guard = stream_clone.read().await; + stream_guard.table_name().to_string() + }) + }) + } + + /// Get unacknowledged batches as a list of Arrow IPC byte buffers. 
+ fn get_unacked_batches(&self, py: Python) -> PyResult> { + let stream_clone = self.inner.clone(); + let runtime = self.runtime.clone(); + + py.allow_threads(|| { + runtime.block_on(async move { + let stream_guard = stream_clone.read().await; + let batches = stream_guard + .get_unacked_batches() + .await + .map_err(|e| Python::with_gil(|_py| map_error(e)))?; + + Python::with_gil(|py| { + let mut py_batches: Vec = Vec::with_capacity(batches.len()); + for batch in &batches { + let ipc_bytes = record_batch_to_ipc_bytes(batch) + .map_err(|e| map_error(e))?; + py_batches + .push(PyBytes::new(py, &ipc_bytes).into()); + } + Ok(py_batches) + }) + }) + }) + } +} + +// ============================================================================= +// ASYNC ARROW STREAM +// ============================================================================= + +/// Asynchronous Arrow Flight stream for ingesting pyarrow RecordBatches. +#[pyclass(name = "AsyncZerobusArrowStream")] +pub struct AsyncZerobusArrowStream { + inner: Arc>, +} + +#[pymethods] +impl AsyncZerobusArrowStream { + /// Ingest a single Arrow RecordBatch (as IPC bytes) and return the offset. + fn ingest_batch<'py>( + &self, + py: Python<'py>, + ipc_bytes: &PyBytes, + ) -> PyResult<&'py PyAny> { + let batch = ipc_bytes_to_record_batch(ipc_bytes.as_bytes()) + .map_err(|e| map_error(e))?; + + let stream_clone = self.inner.clone(); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let stream_guard = stream_clone.read().await; + stream_guard + .ingest_batch(batch) + .await + .map_err(|e| Python::with_gil(|_py| map_error(e))) + }) + } + + /// Wait for a specific offset to be acknowledged. 
+ fn wait_for_offset<'py>( + &self, + py: Python<'py>, + offset: i64, + ) -> PyResult<&'py PyAny> { + let stream_clone = self.inner.clone(); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let stream_guard = stream_clone.read().await; + stream_guard + .wait_for_offset(offset) + .await + .map_err(|e| Python::with_gil(|_py| map_error(e)))?; + Ok(()) + }) + } + + /// Flush all pending batches. + fn flush<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> { + let stream_clone = self.inner.clone(); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let stream_guard = stream_clone.read().await; + stream_guard + .flush() + .await + .map_err(|e| Python::with_gil(|_py| map_error(e)))?; + Ok(()) + }) + } + + /// Close the stream gracefully. + fn close<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> { + let stream_clone = self.inner.clone(); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let mut stream_guard = stream_clone.write().await; + stream_guard + .close() + .await + .map_err(|e| Python::with_gil(|_py| map_error(e)))?; + Ok(()) + }) + } + + /// Check if the stream has been closed. + #[getter] + fn is_closed(&self) -> PyResult { + match self.inner.try_read() { + Ok(stream_guard) => Ok(stream_guard.is_closed()), + Err(_) => Err(pyo3::exceptions::PyRuntimeError::new_err( + "Cannot read stream state: lock is held by another operation", + )), + } + } + + /// Get the table name. + #[getter] + fn table_name(&self) -> PyResult { + match self.inner.try_read() { + Ok(stream_guard) => Ok(stream_guard.table_name().to_string()), + Err(_) => Err(pyo3::exceptions::PyRuntimeError::new_err( + "Cannot read stream state: lock is held by another operation", + )), + } + } + + /// Get unacknowledged batches as a list of Arrow IPC byte buffers. 
+ fn get_unacked_batches<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> { + let stream_clone = self.inner.clone(); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let stream_guard = stream_clone.read().await; + let batches = stream_guard + .get_unacked_batches() + .await + .map_err(|e| Python::with_gil(|_py| map_error(e)))?; + + Python::with_gil(|py| { + let mut py_batches: Vec = Vec::with_capacity(batches.len()); + for batch in &batches { + let ipc_bytes = record_batch_to_ipc_bytes(batch) + .map_err(|e| map_error(e))?; + py_batches.push(PyBytes::new(py, &ipc_bytes).into()); + } + Ok(py_batches) + }) + }) + } +} + +// ============================================================================= +// SDK METHODS — called from sync_wrapper and async_wrapper +// ============================================================================= + +/// Create an Arrow stream (sync helper). +pub fn create_arrow_stream_sync( + sdk: &Arc>, + runtime: &Arc, + py: Python, + table_name: String, + schema_ipc_bytes: &[u8], + client_id: String, + client_secret: String, + options: Option<&ArrowStreamConfigurationOptions>, +) -> PyResult { + let schema = ipc_schema_bytes_to_arrow_schema(schema_ipc_bytes) + .map_err(|e| map_error(e))?; + + let table_props = RustArrowTableProperties { + table_name, + schema: Arc::new(schema), + }; + + let rust_options = options.map(|o| o.to_rust()).transpose()?; + + let sdk_clone = sdk.clone(); + let runtime_clone = runtime.clone(); + + let stream = py.allow_threads(|| { + runtime_clone.block_on(async move { + let sdk_guard = sdk_clone.read().await; + sdk_guard + .create_arrow_stream(table_props, client_id, client_secret, rust_options) + .await + .map_err(|e| Python::with_gil(|_py| map_error(e))) + }) + })?; + + Ok(ZerobusArrowStream { + inner: Arc::new(RwLock::new(stream)), + runtime: runtime.clone(), + }) +} + +/// Create an Arrow stream with headers provider (sync helper). 
+pub fn create_arrow_stream_with_headers_provider_sync( + sdk: &Arc>, + runtime: &Arc, + py: Python, + table_name: String, + schema_ipc_bytes: &[u8], + headers_provider: PyObject, + options: Option<&ArrowStreamConfigurationOptions>, +) -> PyResult { + let schema = ipc_schema_bytes_to_arrow_schema(schema_ipc_bytes) + .map_err(|e| map_error(e))?; + + let table_props = RustArrowTableProperties { + table_name, + schema: Arc::new(schema), + }; + + let rust_options = options.map(|o| o.to_rust()).transpose()?; + let provider = Arc::new(HeadersProviderWrapper::new(headers_provider)); + + let sdk_clone = sdk.clone(); + let runtime_clone = runtime.clone(); + + let stream = py.allow_threads(|| { + runtime_clone.block_on(async move { + let sdk_guard = sdk_clone.read().await; + sdk_guard + .create_arrow_stream_with_headers_provider( + table_props, + provider, + rust_options, + ) + .await + .map_err(|e| Python::with_gil(|_py| map_error(e))) + }) + })?; + + Ok(ZerobusArrowStream { + inner: Arc::new(RwLock::new(stream)), + runtime: runtime.clone(), + }) +} + +/// Recreate an Arrow stream from a closed stream (sync helper). +pub fn recreate_arrow_stream_sync( + sdk: &Arc>, + runtime: &Arc, + py: Python, + old_stream: &ZerobusArrowStream, +) -> PyResult { + let sdk_clone = sdk.clone(); + let old_inner = old_stream.inner.clone(); + let runtime_clone = runtime.clone(); + + let stream = py.allow_threads(|| { + runtime_clone.block_on(async move { + let old_guard = old_inner.read().await; + let sdk_guard = sdk_clone.read().await; + sdk_guard + .recreate_arrow_stream(&*old_guard) + .await + .map_err(|e| Python::with_gil(|_py| map_error(e))) + }) + })?; + + Ok(ZerobusArrowStream { + inner: Arc::new(RwLock::new(stream)), + runtime: runtime.clone(), + }) +} + +/// Create an Arrow stream (async helper). 
+pub fn create_arrow_stream_async<'py>( + sdk: &Arc>, + py: Python<'py>, + table_name: String, + schema_ipc_bytes: Vec, + client_id: String, + client_secret: String, + options: Option, +) -> PyResult<&'py PyAny> { + let schema = ipc_schema_bytes_to_arrow_schema(&schema_ipc_bytes) + .map_err(|e| map_error(e))?; + + let table_props = RustArrowTableProperties { + table_name, + schema: Arc::new(schema), + }; + + let rust_options = options.map(|o| o.to_rust()).transpose()?; + let sdk_clone = sdk.clone(); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let sdk_guard = sdk_clone.read().await; + let stream = sdk_guard + .create_arrow_stream(table_props, client_id, client_secret, rust_options) + .await + .map_err(|e| Python::with_gil(|_py| map_error(e)))?; + + Ok(AsyncZerobusArrowStream { + inner: Arc::new(RwLock::new(stream)), + }) + }) +} + +/// Create an Arrow stream with headers provider (async helper). +pub fn create_arrow_stream_with_headers_provider_async<'py>( + sdk: &Arc>, + py: Python<'py>, + table_name: String, + schema_ipc_bytes: Vec, + headers_provider: PyObject, + options: Option, +) -> PyResult<&'py PyAny> { + let schema = ipc_schema_bytes_to_arrow_schema(&schema_ipc_bytes) + .map_err(|e| map_error(e))?; + + let table_props = RustArrowTableProperties { + table_name, + schema: Arc::new(schema), + }; + + let rust_options = options.map(|o| o.to_rust()).transpose()?; + let provider = Arc::new(HeadersProviderWrapper::new(headers_provider)); + let sdk_clone = sdk.clone(); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let sdk_guard = sdk_clone.read().await; + let stream = sdk_guard + .create_arrow_stream_with_headers_provider( + table_props, + provider, + rust_options, + ) + .await + .map_err(|e| Python::with_gil(|_py| map_error(e)))?; + + Ok(AsyncZerobusArrowStream { + inner: Arc::new(RwLock::new(stream)), + }) + }) +} + +/// Recreate an Arrow stream from a closed stream (async helper). 
+pub fn recreate_arrow_stream_async<'py>( + sdk: &Arc>, + py: Python<'py>, + old_stream: &AsyncZerobusArrowStream, +) -> PyResult<&'py PyAny> { + let sdk_clone = sdk.clone(); + let old_inner = old_stream.inner.clone(); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let old_guard = old_inner.read().await; + let sdk_guard = sdk_clone.read().await; + let stream = sdk_guard + .recreate_arrow_stream(&*old_guard) + .await + .map_err(|e| Python::with_gil(|_py| map_error(e)))?; + + Ok(AsyncZerobusArrowStream { + inner: Arc::new(RwLock::new(stream)), + }) + }) +} diff --git a/python/rust/src/async_wrapper.rs b/python/rust/src/async_wrapper.rs index d4aa79f..7a4a6e4 100644 --- a/python/rust/src/async_wrapper.rs +++ b/python/rust/src/async_wrapper.rs @@ -14,6 +14,7 @@ use databricks_zerobus_ingest_sdk::{ ZerobusStream as RustStream, }; +use crate::arrow::{self, ArrowStreamConfigurationOptions, AsyncZerobusArrowStream}; use crate::auth::HeadersProviderWrapper; use crate::common::{map_error, AckCallback, StreamConfigurationOptions, TableProperties}; @@ -424,7 +425,7 @@ impl ZerobusSdk { descriptor_proto: table_properties.descriptor_proto.clone(), }; - let opts = convert_stream_options(options); + let opts = convert_stream_options(options)?; future_into_py(py, async move { let sdk_guard = sdk.read().await; @@ -454,7 +455,7 @@ impl ZerobusSdk { descriptor_proto: table_properties.descriptor_proto.clone(), }; - let opts = convert_stream_options(options); + let opts = convert_stream_options(options)?; let wrapper = Arc::new(HeadersProviderWrapper::new(headers_provider)); future_into_py(py, async move { @@ -469,6 +470,57 @@ impl ZerobusSdk { }) } + /// Create a new Arrow Flight stream with OAuth authentication (async). 
+ #[pyo3(signature = (table_name, schema_ipc_bytes, client_id, client_secret, options = None))] + fn create_arrow_stream<'py>( + &self, + py: Python<'py>, + table_name: String, + schema_ipc_bytes: Vec, + client_id: String, + client_secret: String, + options: Option, + ) -> PyResult<&'py PyAny> { + arrow::create_arrow_stream_async( + &self.inner, + py, + table_name, + schema_ipc_bytes, + client_id, + client_secret, + options, + ) + } + + /// Create a new Arrow Flight stream with custom headers provider (async). + #[pyo3(signature = (table_name, schema_ipc_bytes, headers_provider, options = None))] + fn create_arrow_stream_with_headers_provider<'py>( + &self, + py: Python<'py>, + table_name: String, + schema_ipc_bytes: Vec, + headers_provider: PyObject, + options: Option, + ) -> PyResult<&'py PyAny> { + arrow::create_arrow_stream_with_headers_provider_async( + &self.inner, + py, + table_name, + schema_ipc_bytes, + headers_provider, + options, + ) + } + + /// Recreate a closed Arrow stream (async). 
+ fn recreate_arrow_stream<'py>( + &self, + py: Python<'py>, + old_stream: &AsyncZerobusArrowStream, + ) -> PyResult<&'py PyAny> { + arrow::recreate_arrow_stream_async(&self.inner, py, old_stream) + } + /// Recreate a stream from an old stream (async) fn recreate_stream<'py>( &self, @@ -496,30 +548,34 @@ impl ZerobusSdk { // Helper to convert Python StreamConfigurationOptions to Rust options fn convert_stream_options( options: Option<&StreamConfigurationOptions>, -) -> Option { - options.map(|opts| { - let ack_callback = opts - .ack_callback - .clone() - .map(|cb| Arc::new(AckCallbackWrapper::new(cb)) as Arc); - - RustStreamOptions { - max_inflight_requests: opts.max_inflight_records as usize, - recovery: opts.recovery, - recovery_timeout_ms: opts.recovery_timeout_ms as u64, - recovery_backoff_ms: opts.recovery_backoff_ms as u64, - recovery_retries: opts.recovery_retries as u32, - server_lack_of_ack_timeout_ms: opts.server_lack_of_ack_timeout_ms as u64, - flush_timeout_ms: opts.flush_timeout_ms as u64, - record_type: match opts.record_type.value { - 1 => RustRecordType::Proto, - 2 => RustRecordType::Json, - _ => RustRecordType::Proto, - }, - stream_paused_max_wait_time_ms: opts.stream_paused_max_wait_time_ms.map(|v| v as u64), - callback_max_wait_time_ms: opts.callback_max_wait_time_ms.map(|v| v as u64), - ack_callback, - ..Default::default() +) -> PyResult> { + match options { + Some(opts) => { + opts.validate()?; + let ack_callback = opts + .ack_callback + .clone() + .map(|cb| Arc::new(AckCallbackWrapper::new(cb)) as Arc); + + Ok(Some(RustStreamOptions { + max_inflight_requests: opts.max_inflight_records as usize, + recovery: opts.recovery, + recovery_timeout_ms: opts.recovery_timeout_ms as u64, + recovery_backoff_ms: opts.recovery_backoff_ms as u64, + recovery_retries: opts.recovery_retries as u32, + server_lack_of_ack_timeout_ms: opts.server_lack_of_ack_timeout_ms as u64, + flush_timeout_ms: opts.flush_timeout_ms as u64, + record_type: match 
opts.record_type.value { + 1 => RustRecordType::Proto, + 2 => RustRecordType::Json, + _ => RustRecordType::Proto, + }, + stream_paused_max_wait_time_ms: opts.stream_paused_max_wait_time_ms.map(|v| v as u64), + callback_max_wait_time_ms: opts.callback_max_wait_time_ms.map(|v| v as u64), + ack_callback, + ..Default::default() + })) } - }) + None => Ok(None), + } } diff --git a/python/rust/src/common.rs b/python/rust/src/common.rs index f62b578..84bc065 100644 --- a/python/rust/src/common.rs +++ b/python/rust/src/common.rs @@ -207,6 +207,57 @@ pub struct StreamConfigurationOptions { pub ack_callback: Option>, } +impl StreamConfigurationOptions { + /// Validate that all numeric fields are non-negative before casting to unsigned types. + pub fn validate(&self) -> PyResult<()> { + if self.max_inflight_records < 0 { + return Err(PyValueError::new_err( + "max_inflight_records must be non-negative", + )); + } + if self.recovery_timeout_ms < 0 { + return Err(PyValueError::new_err( + "recovery_timeout_ms must be non-negative", + )); + } + if self.recovery_backoff_ms < 0 { + return Err(PyValueError::new_err( + "recovery_backoff_ms must be non-negative", + )); + } + if self.recovery_retries < 0 { + return Err(PyValueError::new_err( + "recovery_retries must be non-negative", + )); + } + if self.server_lack_of_ack_timeout_ms < 0 { + return Err(PyValueError::new_err( + "server_lack_of_ack_timeout_ms must be non-negative", + )); + } + if self.flush_timeout_ms < 0 { + return Err(PyValueError::new_err( + "flush_timeout_ms must be non-negative", + )); + } + if let Some(v) = self.stream_paused_max_wait_time_ms { + if v < 0 { + return Err(PyValueError::new_err( + "stream_paused_max_wait_time_ms must be non-negative", + )); + } + } + if let Some(v) = self.callback_max_wait_time_ms { + if v < 0 { + return Err(PyValueError::new_err( + "callback_max_wait_time_ms must be non-negative", + )); + } + } + Ok(()) + } +} + impl Default for StreamConfigurationOptions { fn default() -> Self { Self { 
diff --git a/python/rust/src/lib.rs b/python/rust/src/lib.rs index 287347e..b8f2eb2 100644 --- a/python/rust/src/lib.rs +++ b/python/rust/src/lib.rs @@ -2,6 +2,7 @@ use pyo3::prelude::*; +mod arrow; mod async_wrapper; mod auth; mod common; @@ -43,12 +44,23 @@ fn _zerobus_core(py: Python, m: &PyModule) -> PyResult<()> { // Add authentication classes m.add_class::()?; + let sys_modules = py.import("sys")?.getattr("modules")?; + + // Add arrow submodule + let arrow_module = PyModule::new(py, "arrow")?; + arrow_module.add_class::()?; + arrow_module.add_class::()?; + arrow_module.add_class::()?; + m.add_submodule(arrow_module)?; + sys_modules.set_item("zerobus._zerobus_core.arrow", arrow_module)?; + // Add sync submodule let sync_module = PyModule::new(py, "sync")?; sync_module.add_class::()?; sync_module.add_class::()?; sync_module.add_class::()?; m.add_submodule(sync_module)?; + sys_modules.set_item("zerobus._zerobus_core.sync", sync_module)?; // Add aio (async) submodule let aio_module = PyModule::new(py, "aio")?; @@ -56,6 +68,7 @@ fn _zerobus_core(py: Python, m: &PyModule) -> PyResult<()> { aio_module.add_class::()?; aio_module.add_class::()?; m.add_submodule(aio_module)?; + sys_modules.set_item("zerobus._zerobus_core.aio", aio_module)?; Ok(()) } diff --git a/python/rust/src/sync_wrapper.rs b/python/rust/src/sync_wrapper.rs index 2491aee..4df2fd7 100644 --- a/python/rust/src/sync_wrapper.rs +++ b/python/rust/src/sync_wrapper.rs @@ -14,6 +14,7 @@ use databricks_zerobus_ingest_sdk::{ ZerobusSdk as RustSdk, ZerobusStream as RustStream, }; +use crate::arrow::{self, ArrowStreamConfigurationOptions, ZerobusArrowStream}; use crate::auth::HeadersProviderWrapper; use crate::common::{map_error, AckCallback, StreamConfigurationOptions, TableProperties}; @@ -448,6 +449,7 @@ impl ZerobusSdk { }; let rust_options = if let Some(opts) = options.clone() { + opts.validate()?; let ack_callback = opts .ack_callback .map(|cb| Arc::new(AckCallbackWrapper::new(cb)) as Arc); @@ -515,6 
+517,7 @@ impl ZerobusSdk { }; let rust_options = if let Some(opts) = options.clone() { + opts.validate()?; let ack_callback = opts .ack_callback .map(|cb| Arc::new(AckCallbackWrapper::new(cb)) as Arc); @@ -567,6 +570,66 @@ impl ZerobusSdk { }) } + /// Create a new Arrow Flight stream with OAuth authentication. + /// + /// Args: + /// table_name: Fully qualified table name (catalog.schema.table) + /// schema_ipc_bytes: Arrow IPC serialized schema bytes + /// client_id: OAuth client ID + /// client_secret: OAuth client secret + /// options: Optional ArrowStreamConfigurationOptions + #[pyo3(signature = (table_name, schema_ipc_bytes, client_id, client_secret, options = None))] + fn create_arrow_stream( + &self, + py: Python, + table_name: String, + schema_ipc_bytes: &PyBytes, + client_id: String, + client_secret: String, + options: Option<&ArrowStreamConfigurationOptions>, + ) -> PyResult { + arrow::create_arrow_stream_sync( + &self.inner, + &self.runtime, + py, + table_name, + schema_ipc_bytes.as_bytes(), + client_id, + client_secret, + options, + ) + } + + /// Create a new Arrow Flight stream with custom headers provider. + #[pyo3(signature = (table_name, schema_ipc_bytes, headers_provider, options = None))] + fn create_arrow_stream_with_headers_provider( + &self, + py: Python, + table_name: String, + schema_ipc_bytes: &PyBytes, + headers_provider: PyObject, + options: Option<&ArrowStreamConfigurationOptions>, + ) -> PyResult { + arrow::create_arrow_stream_with_headers_provider_sync( + &self.inner, + &self.runtime, + py, + table_name, + schema_ipc_bytes.as_bytes(), + headers_provider, + options, + ) + } + + /// Recreate a closed Arrow stream with the same configuration. 
+ fn recreate_arrow_stream( + &self, + py: Python, + old_stream: &ZerobusArrowStream, + ) -> PyResult { + arrow::recreate_arrow_stream_sync(&self.inner, &self.runtime, py, old_stream) + } + /// Recreate a closed stream with the same configuration fn recreate_stream(&self, py: Python, old_stream: &ZerobusStream) -> PyResult { let sdk = self.inner.clone(); diff --git a/python/tests/test_arrow.py b/python/tests/test_arrow.py new file mode 100644 index 0000000..a34453d --- /dev/null +++ b/python/tests/test_arrow.py @@ -0,0 +1,389 @@ +""" +Tests for Arrow Flight support in the Python SDK. + +These tests verify the Python-side Arrow serialization/deserialization helpers, +the ArrowStreamConfigurationOptions pyclass, and the API surface of Arrow stream +classe, without making network connections. +""" + +import unittest + +import pyarrow as pa + +from zerobus.sdk.shared.arrow import ( + ArrowStreamConfigurationOptions, + _check_pyarrow, + _deserialize_batch, + _serialize_batch, + _serialize_schema, +) + + +class TestCheckPyarrow(unittest.TestCase): + """Test the pyarrow availability check.""" + + def test_check_pyarrow_returns_module(self): + result = _check_pyarrow() + self.assertIs(result, pa) + + +class TestSerializeSchema(unittest.TestCase): + """Test schema serialization to IPC bytes.""" + + def test_simple_schema(self): + schema = pa.schema([("a", pa.int64()), ("b", pa.utf8())]) + ipc_bytes = _serialize_schema(schema) + self.assertIsInstance(ipc_bytes, bytes) + self.assertGreater(len(ipc_bytes), 0) + + def test_schema_with_various_types(self): + schema = pa.schema( + [ + ("int_col", pa.int32()), + ("float_col", pa.float64()), + ("str_col", pa.large_utf8()), + ("bool_col", pa.bool_()), + ("list_col", pa.list_(pa.int64())), + ] + ) + ipc_bytes = _serialize_schema(schema) + self.assertIsInstance(ipc_bytes, bytes) + + def test_roundtrip_schema_via_ipc(self): + schema = pa.schema([("x", pa.int64()), ("y", pa.float32())]) + ipc_bytes = _serialize_schema(schema) + reader = 
pa.ipc.open_stream(ipc_bytes) + recovered = reader.schema + self.assertEqual(schema, recovered) + + def test_rejects_non_schema(self): + with self.assertRaises(TypeError): + _serialize_schema("not a schema") + + def test_rejects_record_batch(self): + batch = pa.record_batch({"a": [1]}, schema=pa.schema([("a", pa.int64())])) + with self.assertRaises(TypeError): + _serialize_schema(batch) + + +class TestSerializeBatch(unittest.TestCase): + """Test RecordBatch/Table serialization to IPC bytes.""" + + def test_simple_batch(self): + schema = pa.schema([("a", pa.int64())]) + batch = pa.record_batch({"a": [1, 2, 3]}, schema=schema) + ipc_bytes = _serialize_batch(batch) + self.assertIsInstance(ipc_bytes, bytes) + self.assertGreater(len(ipc_bytes), 0) + + def test_batch_with_multiple_columns(self): + schema = pa.schema([("x", pa.int32()), ("y", pa.utf8()), ("z", pa.float64())]) + batch = pa.record_batch( + {"x": [1, 2], "y": ["a", "b"], "z": [1.0, 2.0]}, + schema=schema, + ) + ipc_bytes = _serialize_batch(batch) + self.assertIsInstance(ipc_bytes, bytes) + + def test_table_single_chunk(self): + schema = pa.schema([("a", pa.int64())]) + table = pa.table({"a": [1, 2, 3]}, schema=schema) + ipc_bytes = _serialize_batch(table) + self.assertIsInstance(ipc_bytes, bytes) + + def test_table_multiple_chunks(self): + schema = pa.schema([("a", pa.int64())]) + batch1 = pa.record_batch({"a": [1, 2]}, schema=schema) + batch2 = pa.record_batch({"a": [3, 4]}, schema=schema) + table = pa.Table.from_batches([batch1, batch2]) + ipc_bytes = _serialize_batch(table) + # Should combine chunks into a single batch + recovered = _deserialize_batch(ipc_bytes) + self.assertEqual(recovered.num_rows, 4) + + def test_empty_table_raises(self): + schema = pa.schema([("a", pa.int64())]) + table = pa.table({"a": pa.array([], type=pa.int64())}, schema=schema) + with self.assertRaises(ValueError) as cm: + _serialize_batch(table) + self.assertIn("empty", str(cm.exception).lower()) + + def 
test_rejects_wrong_type(self): + with self.assertRaises(TypeError): + _serialize_batch("not a batch") + + def test_rejects_dict(self): + with self.assertRaises(TypeError): + _serialize_batch({"a": [1, 2, 3]}) + + def test_rejects_schema(self): + with self.assertRaises(TypeError): + _serialize_batch(pa.schema([("a", pa.int64())])) + + +class TestDeserializeBatch(unittest.TestCase): + """Test IPC bytes deserialization back to RecordBatch.""" + + def test_roundtrip(self): + schema = pa.schema([("a", pa.int64()), ("b", pa.utf8())]) + original = pa.record_batch( + {"a": [1, 2, 3], "b": ["x", "y", "z"]}, + schema=schema, + ) + ipc_bytes = _serialize_batch(original) + recovered = _deserialize_batch(ipc_bytes) + + self.assertIsInstance(recovered, pa.RecordBatch) + self.assertEqual(recovered.schema, original.schema) + self.assertEqual(recovered.num_rows, original.num_rows) + self.assertEqual(recovered.to_pydict(), original.to_pydict()) + + def test_roundtrip_with_nulls(self): + schema = pa.schema([("a", pa.int64()), ("b", pa.utf8())]) + original = pa.record_batch( + {"a": [1, None, 3], "b": [None, "y", None]}, + schema=schema, + ) + ipc_bytes = _serialize_batch(original) + recovered = _deserialize_batch(ipc_bytes) + + self.assertEqual(recovered.to_pydict(), original.to_pydict()) + + def test_roundtrip_table(self): + schema = pa.schema([("val", pa.float64())]) + table = pa.table({"val": [1.1, 2.2, 3.3]}, schema=schema) + ipc_bytes = _serialize_batch(table) + recovered = _deserialize_batch(ipc_bytes) + + self.assertIsInstance(recovered, pa.RecordBatch) + self.assertEqual(recovered.num_rows, 3) + + def test_invalid_bytes_raises(self): + with self.assertRaises(Exception): + _deserialize_batch(b"not valid ipc data") + + +class TestArrowStreamConfigurationOptions(unittest.TestCase): + """Test ArrowStreamConfigurationOptions pyclass.""" + + def test_default_construction(self): + options = ArrowStreamConfigurationOptions() + self.assertIsInstance(options.max_inflight_batches, int) 
+ self.assertIsInstance(options.recovery, bool) + self.assertIsInstance(options.recovery_timeout_ms, int) + self.assertIsInstance(options.recovery_backoff_ms, int) + self.assertIsInstance(options.recovery_retries, int) + self.assertIsInstance(options.server_lack_of_ack_timeout_ms, int) + self.assertIsInstance(options.flush_timeout_ms, int) + self.assertIsInstance(options.connection_timeout_ms, int) + + def test_kwargs_construction(self): + options = ArrowStreamConfigurationOptions( + max_inflight_batches=5, + recovery=False, + flush_timeout_ms=3000, + ) + self.assertEqual(options.max_inflight_batches, 5) + self.assertFalse(options.recovery) + self.assertEqual(options.flush_timeout_ms, 3000) + + def test_all_kwargs(self): + options = ArrowStreamConfigurationOptions( + max_inflight_batches=20, + recovery=True, + recovery_timeout_ms=10000, + recovery_backoff_ms=500, + recovery_retries=10, + server_lack_of_ack_timeout_ms=30000, + flush_timeout_ms=5000, + connection_timeout_ms=8000, + ) + self.assertEqual(options.max_inflight_batches, 20) + self.assertTrue(options.recovery) + self.assertEqual(options.recovery_timeout_ms, 10000) + self.assertEqual(options.recovery_backoff_ms, 500) + self.assertEqual(options.recovery_retries, 10) + self.assertEqual(options.server_lack_of_ack_timeout_ms, 30000) + self.assertEqual(options.flush_timeout_ms, 5000) + self.assertEqual(options.connection_timeout_ms, 8000) + + def test_unknown_kwarg_raises(self): + with self.assertRaises(ValueError): + ArrowStreamConfigurationOptions(nonexistent_option=42) + + def test_setters(self): + options = ArrowStreamConfigurationOptions() + options.max_inflight_batches = 99 + options.recovery = False + self.assertEqual(options.max_inflight_batches, 99) + self.assertFalse(options.recovery) + + def test_repr(self): + options = ArrowStreamConfigurationOptions() + repr_str = repr(options) + self.assertIn("ArrowStreamConfigurationOptions", repr_str) + self.assertIn("max_inflight_batches", repr_str) + 
self.assertIn("recovery", repr_str) + + +class TestSerializeBatchEmptyRecordBatch(unittest.TestCase): + """Test that empty RecordBatch (not Table) is handled correctly.""" + + def test_empty_record_batch_serializes(self): + """An empty RecordBatch (0 rows) should still serialize — only empty Tables are rejected.""" + schema = pa.schema([("a", pa.int64())]) + batch = pa.record_batch({"a": pa.array([], type=pa.int64())}, schema=schema) + # RecordBatch with 0 rows is valid — the check is only on Table + ipc_bytes = _serialize_batch(batch) + self.assertIsInstance(ipc_bytes, bytes) + recovered = _deserialize_batch(ipc_bytes) + self.assertEqual(recovered.num_rows, 0) + + +class TestArrowConfigNegativeValues(unittest.TestCase): + """Test that negative config values are rejected at stream creation time.""" + + def test_negative_max_inflight_batches(self): + options = ArrowStreamConfigurationOptions(max_inflight_batches=-1) + self.assertEqual(options.max_inflight_batches, -1) + # Validation happens in to_rust() at stream creation, not at construction. + # We can't test the Rust-side validation without a server, but we verify + # the value is accepted by the Python constructor and stored. 
+ + def test_negative_recovery_timeout_ms(self): + options = ArrowStreamConfigurationOptions(recovery_timeout_ms=-100) + self.assertEqual(options.recovery_timeout_ms, -100) + + def test_negative_flush_timeout_ms(self): + options = ArrowStreamConfigurationOptions(flush_timeout_ms=-1) + self.assertEqual(options.flush_timeout_ms, -1) + + def test_negative_connection_timeout_ms(self): + options = ArrowStreamConfigurationOptions(connection_timeout_ms=-500) + self.assertEqual(options.connection_timeout_ms, -500) + + +class TestAsyncArrowStreamAPISurface(unittest.TestCase): + """Test that AsyncZerobusArrowStream has the expected async API.""" + + def test_async_arrow_stream_methods_are_coroutines(self): + """Verify async stream methods are actual coroutine functions.""" + from zerobus.sdk.aio import ZerobusArrowStream + + # These methods should be present on the wrapper class + for method_name in ["ingest_batch", "wait_for_offset", "flush", "close", "get_unacked_batches"]: + self.assertTrue( + hasattr(ZerobusArrowStream, method_name), + f"AsyncZerobusArrowStream missing method: {method_name}", + ) + + def test_async_sdk_has_create_and_recreate(self): + """Verify async SDK has arrow stream creation methods.""" + import inspect + + from zerobus.sdk.aio import ZerobusSdk + + self.assertTrue(hasattr(ZerobusSdk, "create_arrow_stream")) + self.assertTrue(hasattr(ZerobusSdk, "recreate_arrow_stream")) + # create_arrow_stream should be a coroutine function + self.assertTrue( + inspect.iscoroutinefunction(ZerobusSdk.create_arrow_stream), + "create_arrow_stream should be async", + ) + + +class TestArrowImports(unittest.TestCase): + """Test that Arrow types can be imported from expected locations.""" + + def test_import_from_shared_arrow(self): + from zerobus.sdk.shared.arrow import ArrowStreamConfigurationOptions + + self.assertIsNotNone(ArrowStreamConfigurationOptions) + + def test_import_from_top_level(self): + from zerobus import ArrowStreamConfigurationOptions + + 
self.assertIsNotNone(ArrowStreamConfigurationOptions) + + def test_import_arrow_stream_from_sync(self): + from zerobus.sdk.sync import ZerobusArrowStream + + self.assertIsNotNone(ZerobusArrowStream) + + def test_import_arrow_stream_from_aio(self): + from zerobus.sdk.aio import ZerobusArrowStream + + self.assertIsNotNone(ZerobusArrowStream) + + def test_import_arrow_stream_from_top_level(self): + from zerobus import ZerobusArrowStream + + self.assertIsNotNone(ZerobusArrowStream) + + +class TestArrowSDKAPISurface(unittest.TestCase): + """Test that Arrow SDK classes have the expected API surface.""" + + def test_sync_sdk_has_arrow_methods(self): + from zerobus.sdk.sync import ZerobusSdk + + self.assertTrue(hasattr(ZerobusSdk, "create_arrow_stream")) + self.assertTrue(hasattr(ZerobusSdk, "recreate_arrow_stream")) + + def test_async_sdk_has_arrow_methods(self): + from zerobus.sdk.aio import ZerobusSdk + + self.assertTrue(hasattr(ZerobusSdk, "create_arrow_stream")) + self.assertTrue(hasattr(ZerobusSdk, "recreate_arrow_stream")) + + def test_sync_arrow_stream_has_methods(self): + from zerobus.sdk.sync import ZerobusArrowStream + + expected_methods = [ + "ingest_batch", + "wait_for_offset", + "flush", + "close", + "get_unacked_batches", + ] + for method in expected_methods: + self.assertTrue( + hasattr(ZerobusArrowStream, method), + f"Sync ZerobusArrowStream missing method: {method}", + ) + + def test_async_arrow_stream_has_methods(self): + from zerobus.sdk.aio import ZerobusArrowStream + + expected_methods = [ + "ingest_batch", + "wait_for_offset", + "flush", + "close", + "get_unacked_batches", + ] + for method in expected_methods: + self.assertTrue( + hasattr(ZerobusArrowStream, method), + f"Async ZerobusArrowStream missing method: {method}", + ) + + def test_arrow_types_not_on_core_top_level(self): + """Arrow types should be in _core.arrow submodule, not on _core directly.""" + import zerobus._zerobus_core as _core + + self.assertFalse(hasattr(_core, 
"ArrowStreamConfigurationOptions")) + self.assertFalse(hasattr(_core, "ZerobusArrowStream")) + self.assertFalse(hasattr(_core, "AsyncZerobusArrowStream")) + + def test_arrow_types_in_core_arrow_submodule(self): + """Arrow types should be accessible via _core.arrow.""" + import zerobus._zerobus_core as _core + + self.assertTrue(hasattr(_core.arrow, "ArrowStreamConfigurationOptions")) + self.assertTrue(hasattr(_core.arrow, "ZerobusArrowStream")) + self.assertTrue(hasattr(_core.arrow, "AsyncZerobusArrowStream")) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/zerobus/__init__.py b/python/zerobus/__init__.py index e94c1fe..e3499e6 100644 --- a/python/zerobus/__init__.py +++ b/python/zerobus/__init__.py @@ -45,7 +45,8 @@ # Import from Rust core import zerobus._zerobus_core as _core -from zerobus.sdk.sync import ZerobusSdk, ZerobusStream +from zerobus.sdk.shared.arrow import ArrowStreamConfigurationOptions +from zerobus.sdk.sync import ZerobusArrowStream, ZerobusSdk, ZerobusStream __version__ = "1.1.0" @@ -63,6 +64,9 @@ # Sync SDK (default) "ZerobusSdk", "ZerobusStream", + # Arrow (experimental) + "ZerobusArrowStream", + "ArrowStreamConfigurationOptions", "RecordAcknowledgment", # Common types "TableProperties", diff --git a/python/zerobus/_zerobus_core.pyi b/python/zerobus/_zerobus_core.pyi index 40950cf..0700f73 100644 --- a/python/zerobus/_zerobus_core.pyi +++ b/python/zerobus/_zerobus_core.pyi @@ -508,3 +508,109 @@ class aio: A new ZerobusStream """ ... + +# ============================================================================= +# ARROW (EXPERIMENTAL) +# ============================================================================= + +class arrow: + """Arrow Flight support submodule.""" + + class ArrowStreamConfigurationOptions: + """Configuration options for Arrow Flight streams.""" + + max_inflight_batches: int + """Maximum number of batches in-flight (pending acknowledgment). 
Default: 1000""" + + recovery: bool + """Enable automatic stream recovery on retryable failures. Default: True""" + + recovery_timeout_ms: int + """Timeout per recovery attempt in milliseconds. Default: 15000""" + + recovery_backoff_ms: int + """Backoff between recovery attempts in milliseconds. Default: 2000""" + + recovery_retries: int + """Maximum recovery retry attempts. Default: 4""" + + server_lack_of_ack_timeout_ms: int + """Server acknowledgment timeout in milliseconds. Default: 60000""" + + flush_timeout_ms: int + """Flush timeout in milliseconds. Default: 300000""" + + connection_timeout_ms: int + """Connection establishment timeout in milliseconds. Default: 30000""" + + def __init__( + self, + *, + max_inflight_batches: int = 1000, + recovery: bool = True, + recovery_timeout_ms: int = 15000, + recovery_backoff_ms: int = 2000, + recovery_retries: int = 4, + server_lack_of_ack_timeout_ms: int = 60000, + flush_timeout_ms: int = 300000, + connection_timeout_ms: int = 30000, + ) -> None: ... + def __repr__(self) -> str: ... + + class ZerobusArrowStream: + """Synchronous Arrow Flight stream for ingesting pyarrow RecordBatches.""" + + @property + def is_closed(self) -> bool: ... + + @property + def table_name(self) -> str: ... + + def ingest_batch(self, ipc_bytes: bytes) -> int: + """Ingest a RecordBatch (as IPC bytes) and return the logical offset.""" + ... + + def wait_for_offset(self, offset: int) -> None: + """Wait for the server to acknowledge the batch at the given offset.""" + ... + + def flush(self) -> None: + """Flush all pending batches, waiting for acknowledgment.""" + ... + + def close(self) -> None: + """Close the stream gracefully.""" + ... + + def get_unacked_batches(self) -> List[bytes]: + """Return unacknowledged batches as Arrow IPC bytes.""" + ... + + class AsyncZerobusArrowStream: + """Asynchronous Arrow Flight stream for ingesting pyarrow RecordBatches.""" + + @property + def is_closed(self) -> bool: ... 
+ + @property + def table_name(self) -> str: ... + + async def ingest_batch(self, ipc_bytes: bytes) -> int: + """Ingest a RecordBatch (as IPC bytes) and return the logical offset.""" + ... + + async def wait_for_offset(self, offset: int) -> None: + """Wait for the server to acknowledge the batch at the given offset.""" + ... + + async def flush(self) -> None: + """Flush all pending batches, waiting for acknowledgment.""" + ... + + async def close(self) -> None: + """Close the stream gracefully.""" + ... + + async def get_unacked_batches(self) -> List[bytes]: + """Return unacknowledged batches as Arrow IPC bytes.""" + ... diff --git a/python/zerobus/sdk/aio/__init__.py b/python/zerobus/sdk/aio/__init__.py index 125bffb..aadcb21 100644 --- a/python/zerobus/sdk/aio/__init__.py +++ b/python/zerobus/sdk/aio/__init__.py @@ -11,6 +11,7 @@ RecordType, StreamConfigurationOptions, TableProperties, + ZerobusArrowStream, ZerobusException, ZerobusSdk, ZerobusStream, @@ -19,6 +20,7 @@ __all__ = [ "ZerobusSdk", "ZerobusStream", + "ZerobusArrowStream", "TableProperties", "StreamConfigurationOptions", "RecordType", diff --git a/python/zerobus/sdk/aio/zerobus_sdk.py b/python/zerobus/sdk/aio/zerobus_sdk.py index 3a95380..523ef9e 100644 --- a/python/zerobus/sdk/aio/zerobus_sdk.py +++ b/python/zerobus/sdk/aio/zerobus_sdk.py @@ -46,6 +46,7 @@ # Import base Rust stream class _RustZerobusStream = _core.aio.ZerobusStream +_RustAsyncZerobusArrowStream = _core.arrow.AsyncZerobusArrowStream class ZerobusStream: @@ -155,12 +156,111 @@ def get_state(self): return 1 # OPENED state +class ZerobusArrowStream: + """ + Asynchronous Arrow Flight stream for ingesting pyarrow RecordBatches. + + **Experimental/Unsupported**: Arrow Flight support is experimental and not yet + supported for production use. The API may change in future releases. 
+ """ + + def __init__(self, rust_stream: _RustAsyncZerobusArrowStream): + self._inner = rust_stream + + async def ingest_batch(self, batch) -> int: + """ + Ingest a pyarrow.RecordBatch or pyarrow.Table. + + Args: + batch: A pyarrow.RecordBatch or pyarrow.Table to ingest. + + Returns: + The offset ID assigned to this batch. + """ + from zerobus.sdk.shared.arrow import _serialize_batch + + ipc_bytes = _serialize_batch(batch) + return await self._inner.ingest_batch(ipc_bytes) + + async def wait_for_offset(self, offset: int): + """Wait for a specific offset to be acknowledged.""" + return await self._inner.wait_for_offset(offset) + + async def flush(self): + """Flush all pending batches, waiting for acknowledgment.""" + return await self._inner.flush() + + async def close(self): + """Close the stream gracefully.""" + return await self._inner.close() + + @property + def is_closed(self) -> bool: + """Check if the stream has been closed.""" + return self._inner.is_closed + + @property + def table_name(self) -> str: + """Get the table name.""" + return self._inner.table_name + + async def get_unacked_batches(self) -> list: + """ + Get unacknowledged batches as a list of pyarrow.RecordBatch. + + Returns: + List of pyarrow.RecordBatch objects. + """ + from zerobus.sdk.shared.arrow import _deserialize_batch + + ipc_list = await self._inner.get_unacked_batches() + return [_deserialize_batch(ipc_bytes) for ipc_bytes in ipc_list] + + class ZerobusSdk: """Python wrapper around Rust ZerobusSdk that returns wrapped streams.""" def __init__(self, host: str, unity_catalog_url: str): self._inner = _core.aio.ZerobusSdk(host, unity_catalog_url) + async def create_arrow_stream( + self, table_name: str, schema, client_id: str, client_secret: str, options=None, headers_provider=None + ) -> ZerobusArrowStream: + """ + Create an Arrow Flight stream for ingesting pyarrow RecordBatches (async). + + **Experimental/Unsupported**: Arrow Flight support is experimental. 
+ + Args: + table_name: Fully qualified table name (catalog.schema.table). + schema: A pyarrow.Schema defining the table schema. + client_id: OAuth client ID. + client_secret: OAuth client secret. + options: Optional ArrowStreamConfigurationOptions. + headers_provider: Optional custom headers provider. + + Returns: + A ZerobusArrowStream ready for ingesting RecordBatches. + """ + from zerobus.sdk.shared.arrow import _serialize_schema + + schema_bytes = _serialize_schema(schema) + + if headers_provider is not None: + rust_stream = await self._inner.create_arrow_stream_with_headers_provider( + table_name, schema_bytes, headers_provider, options + ) + else: + rust_stream = await self._inner.create_arrow_stream( + table_name, schema_bytes, client_id, client_secret, options + ) + return ZerobusArrowStream(rust_stream) + + async def recreate_arrow_stream(self, old_stream: ZerobusArrowStream) -> ZerobusArrowStream: + """Recreate a closed Arrow stream with the same configuration.""" + rust_stream = await self._inner.recreate_arrow_stream(old_stream._inner) + return ZerobusArrowStream(rust_stream) + async def create_stream( self, client_id: str, client_secret: str, table_properties, options=None, headers_provider=None ): @@ -202,6 +302,7 @@ async def recreate_stream(self, old_stream: ZerobusStream): __all__ = [ "ZerobusSdk", "ZerobusStream", + "ZerobusArrowStream", "TableProperties", "StreamConfigurationOptions", "RecordType", diff --git a/python/zerobus/sdk/shared/arrow.py b/python/zerobus/sdk/shared/arrow.py new file mode 100644 index 0000000..b5b833e --- /dev/null +++ b/python/zerobus/sdk/shared/arrow.py @@ -0,0 +1,87 @@ +""" +Arrow Flight support for the Zerobus SDK. + +**Experimental/Unsupported**: Arrow Flight support is experimental and not yet +supported for production use. The API may change in future releases. 
+
+Requires pyarrow to be installed:
+    pip install databricks-zerobus-ingest-sdk[arrow]
+
+Example (Sync):
+    >>> import pyarrow as pa
+    >>> from zerobus.sdk.sync import ZerobusSdk
+    >>> from zerobus.sdk.shared.arrow import ArrowStreamConfigurationOptions
+    >>>
+    >>> sdk = ZerobusSdk(host, unity_catalog_url)
+    >>> schema = pa.schema([("device_name", pa.large_utf8()), ("temp", pa.int32())])
+    >>> stream = sdk.create_arrow_stream(
+    ...     "catalog.schema.table", schema, client_id, client_secret
+    ... )
+    >>> batch = pa.record_batch({"device_name": ["s1"], "temp": [22]}, schema=schema)
+    >>> offset = stream.ingest_batch(batch)
+    >>> stream.wait_for_offset(offset)
+    >>> stream.close()
+"""
+
+import zerobus._zerobus_core as _core
+
+_PYARROW_IMPORT_ERROR = (
+    "pyarrow is required for Arrow Flight support. " "Install with: pip install databricks-zerobus-ingest-sdk[arrow]"
+)
+
+
+def _check_pyarrow():
+    """Check that pyarrow is available, raise ImportError if not."""
+    try:
+        import pyarrow  # noqa: F401
+
+        return pyarrow
+    except ImportError:
+        raise ImportError(_PYARROW_IMPORT_ERROR)
+
+
+def _serialize_schema(schema):
+    """Serialize a pyarrow.Schema to IPC bytes."""
+    pa = _check_pyarrow()
+    if not isinstance(schema, pa.Schema):
+        raise TypeError(f"Expected pyarrow.Schema, got {type(schema).__name__}")
+
+    # Open and immediately close an IPC stream writer with no batches: the
+    # resulting stream contains only the schema message, which the Rust side parses.
+    sink = pa.BufferOutputStream()
+    writer = pa.ipc.new_stream(sink, schema)
+    writer.close()
+    return sink.getvalue().to_pybytes()
+
+
+def _serialize_batch(batch):
+    """Serialize a pyarrow.RecordBatch (or Table) to IPC bytes."""
+    pa = _check_pyarrow()
+    if isinstance(batch, pa.Table):
+        if batch.num_rows == 0:
+            raise ValueError("Cannot ingest an empty pyarrow.Table")
+        # Convert Table to a single RecordBatch. combine_chunks() merges each
+        # column's chunks, so a non-empty Table yields exactly one batch from
+        # to_batches(). (pa.concat_batches would be simpler, but it requires
+        # pyarrow>=16 while this SDK supports pyarrow>=14.)
+        batches = batch.combine_chunks().to_batches()
+        batch = batches[0]
+    elif not isinstance(batch, pa.RecordBatch):
+        raise TypeError(f"Expected pyarrow.RecordBatch or pyarrow.Table, got {type(batch).__name__}")
+
+    sink = pa.BufferOutputStream()
+    writer = pa.ipc.new_stream(sink, batch.schema)
+    writer.write_batch(batch)
+    writer.close()
+    return sink.getvalue().to_pybytes()
+
+
+def _deserialize_batch(ipc_bytes):
+    """Deserialize IPC bytes back to a pyarrow.RecordBatch."""
+    pa = _check_pyarrow()
+    reader = pa.ipc.open_stream(ipc_bytes)
+    return reader.read_next_batch()
+
+
+# Re-export configuration from Rust core
+ArrowStreamConfigurationOptions = _core.arrow.ArrowStreamConfigurationOptions
diff --git a/python/zerobus/sdk/sync/__init__.py b/python/zerobus/sdk/sync/__init__.py
index a75cb75..a2fe871 100644
--- a/python/zerobus/sdk/sync/__init__.py
+++ b/python/zerobus/sdk/sync/__init__.py
@@ -11,6 +11,7 @@
     RecordType,
     StreamConfigurationOptions,
     TableProperties,
+    ZerobusArrowStream,
     ZerobusException,
     ZerobusSdk,
     ZerobusStream,
@@ -19,6 +20,7 @@
 __all__ = [
     "ZerobusSdk",
     "ZerobusStream",
+    "ZerobusArrowStream",
     "RecordAcknowledgment",
     "TableProperties",
     "StreamConfigurationOptions",
diff --git a/python/zerobus/sdk/sync/zerobus_sdk.py b/python/zerobus/sdk/sync/zerobus_sdk.py
index fd09d33..ca7fedc 100644
--- a/python/zerobus/sdk/sync/zerobus_sdk.py
+++ b/python/zerobus/sdk/sync/zerobus_sdk.py
@@ -122,12 +122,129 @@
     def get_state(self):
         return self._inner.get_state() if 
hasattr(self._inner, "get_state") else 1 +class ZerobusArrowStream: + """ + Synchronous Arrow Flight stream for ingesting pyarrow RecordBatches. + + **Experimental/Unsupported**: Arrow Flight support is experimental and not yet + supported for production use. The API may change in future releases. + + Example: + >>> import pyarrow as pa + >>> schema = pa.schema([("temp", pa.int32())]) + >>> stream = sdk.create_arrow_stream("catalog.schema.table", schema, client_id, client_secret) + >>> batch = pa.record_batch({"temp": [22, 23]}, schema=schema) + >>> offset = stream.ingest_batch(batch) + >>> stream.wait_for_offset(offset) + >>> stream.close() + """ + + def __init__(self, rust_stream): + self._inner = rust_stream + + def ingest_batch(self, batch) -> int: + """ + Ingest a pyarrow.RecordBatch or pyarrow.Table. + + Args: + batch: A pyarrow.RecordBatch or pyarrow.Table to ingest. + + Returns: + The offset ID assigned to this batch. + """ + from zerobus.sdk.shared.arrow import _serialize_batch + + ipc_bytes = _serialize_batch(batch) + return self._inner.ingest_batch(ipc_bytes) + + def wait_for_offset(self, offset: int): + """Wait for a specific offset to be acknowledged.""" + return self._inner.wait_for_offset(offset) + + def flush(self): + """Flush all pending batches, waiting for acknowledgment.""" + return self._inner.flush() + + def close(self): + """Close the stream gracefully.""" + return self._inner.close() + + @property + def is_closed(self) -> bool: + """Check if the stream has been closed.""" + return self._inner.is_closed + + @property + def table_name(self) -> str: + """Get the table name.""" + return self._inner.table_name + + def get_unacked_batches(self) -> list: + """ + Get unacknowledged batches as a list of pyarrow.RecordBatch. + + The stream must be closed before calling this method. + + Returns: + List of pyarrow.RecordBatch objects. 
+ """ + from zerobus.sdk.shared.arrow import _deserialize_batch + + ipc_list = self._inner.get_unacked_batches() + return [_deserialize_batch(ipc_bytes) for ipc_bytes in ipc_list] + + class ZerobusSdk: """Python wrapper around Rust ZerobusSdk that provides unified create_stream API.""" def __init__(self, host: str, unity_catalog_url: str): self._inner = _RustZerobusSdk(host, unity_catalog_url) + def create_arrow_stream( + self, table_name: str, schema, client_id: str, client_secret: str, options=None, headers_provider=None + ) -> ZerobusArrowStream: + """ + Create an Arrow Flight stream for ingesting pyarrow RecordBatches. + + **Experimental/Unsupported**: Arrow Flight support is experimental. + + Args: + table_name: Fully qualified table name (catalog.schema.table). + schema: A pyarrow.Schema defining the table schema. + client_id: OAuth client ID. + client_secret: OAuth client secret. + options: Optional ArrowStreamConfigurationOptions. + headers_provider: Optional custom headers provider (if set, overrides OAuth). + + Returns: + A ZerobusArrowStream ready for ingesting RecordBatches. + """ + from zerobus.sdk.shared.arrow import _serialize_schema + + schema_bytes = _serialize_schema(schema) + + if headers_provider is not None: + rust_stream = self._inner.create_arrow_stream_with_headers_provider( + table_name, schema_bytes, headers_provider, options + ) + else: + rust_stream = self._inner.create_arrow_stream(table_name, schema_bytes, client_id, client_secret, options) + return ZerobusArrowStream(rust_stream) + + def recreate_arrow_stream(self, old_stream: ZerobusArrowStream) -> ZerobusArrowStream: + """ + Recreate a closed Arrow stream with the same configuration, + re-ingesting unacknowledged batches. + + Args: + old_stream: The closed Arrow stream to recreate. + + Returns: + A new ZerobusArrowStream. 
+ """ + rust_stream = self._inner.recreate_arrow_stream(old_stream._inner) + return ZerobusArrowStream(rust_stream) + def create_stream(self, client_id: str, client_secret: str, table_properties, options=None, headers_provider=None): """ Create a stream with OAuth authentication or custom headers provider. @@ -168,6 +285,7 @@ def recreate_stream(self, old_stream: ZerobusStream): __all__ = [ "ZerobusSdk", "ZerobusStream", + "ZerobusArrowStream", "RecordAcknowledgment", "TableProperties", "StreamConfigurationOptions",