-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathstream_manager.py
More file actions
170 lines (140 loc) · 6.11 KB
/
stream_manager.py
File metadata and controls
170 lines (140 loc) · 6.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
Stream Manager for Zerobus Streams
Manages persistent Zerobus streams for each configured table.
Handles stream lifecycle, health checks, and graceful cleanup.
"""
import asyncio
import logging
from importlib import import_module
from typing import Any, Dict

from zerobus_sdk.aio import ZerobusSdk
from zerobus_sdk.shared import StreamConfigurationOptions, TableProperties, get_zerobus_token
logger = logging.getLogger(__name__)
class StreamManager:
    """
    Manages persistent Zerobus streams for multiple tables.

    Each table gets one persistent stream that is kept alive for performance.
    Streams are created on demand (``get_stream``) and cleaned up when no
    longer needed (``close_all`` / ``remove_table``).
    """

    def __init__(self, server_endpoint: str, workspace_id: str, workspace_url: str,
                 client_id: str, client_secret: str):
        """
        Initialize the StreamManager.

        Args:
            server_endpoint: Zerobus gRPC endpoint
            workspace_id: Databricks workspace ID
            workspace_url: Databricks workspace URL
            client_id: OAuth client ID (service principal)
            client_secret: OAuth client secret
        """
        self.server_endpoint = server_endpoint
        self.workspace_id = workspace_id
        self.workspace_url = workspace_url
        self.client_id = client_id
        self.client_secret = client_secret
        self.sdk = ZerobusSdk(server_endpoint)
        # Active streams keyed by table_key. Fix: the annotation previously
        # used the builtin `any` function instead of typing.Any.
        self.streams: Dict[str, Any] = {}
        # Per-table locks so concurrent callers don't create duplicate streams.
        self._locks: Dict[str, asyncio.Lock] = {}

    async def get_stream(self, table_key: str, table_name: str,
                         proto_module: str, message_name: str):
        """
        Get or create a stream for a table.

        An existing stream is reused only while it reports an OPENED state;
        otherwise it is closed and a fresh one is created in its place.

        Args:
            table_key: Unique key for the table (e.g., "station_one")
            table_name: Fully qualified table name in Databricks
            proto_module: Python module path for the protobuf
                (e.g., "tables.station_one.schema_pb2")
            message_name: Name of the protobuf message (e.g., "StationOne")

        Returns:
            The Zerobus stream for this table.

        Raises:
            ImportError, AttributeError: if the protobuf module or message
                class cannot be resolved.
            Exception: propagated from the SDK if stream creation fails.
        """
        # No await between the membership check and the assignment, so this
        # is race-free on a single-threaded event loop.
        if table_key not in self._locks:
            self._locks[table_key] = asyncio.Lock()

        async with self._locks[table_key]:
            # Reuse a healthy existing stream; tear down an unhealthy one.
            if table_key in self.streams:
                stream = self.streams[table_key]
                state = stream.get_state()
                # str() comparison tolerates both a plain-string and an
                # enum repr of the SDK's stream state.
                if str(state) in ("OPENED", "StreamState.OPENED"):
                    return stream
                logger.warning("Stream %s is in state %s, recreating...", table_key, state)
                await self._close_stream(table_key)

            logger.info("Creating new stream for table %s (%s)", table_key, table_name)

            # Resolve the protobuf message class dynamically from config.
            try:
                pb_module = import_module(proto_module)
                message_class = getattr(pb_module, message_name)
                descriptor = message_class.DESCRIPTOR
            except (ImportError, AttributeError) as e:
                logger.error("Failed to import protobuf for %s: %s", table_key, e)
                raise

            def token_factory():
                # Called by the SDK whenever it needs a fresh ingest token.
                return get_zerobus_token(
                    table_name=table_name,
                    workspace_id=self.workspace_id,
                    workspace_url=self.workspace_url,
                    client_id=self.client_id,
                    client_secret=self.client_secret
                )

            options = StreamConfigurationOptions(
                max_inflight_records=50_000,
                recovery=True,  # let the SDK transparently recover the stream
                token_factory=token_factory,
                ack_callback=self._create_ack_callback(table_key)
            )
            table_properties = TableProperties(table_name, descriptor)

            try:
                stream = await self.sdk.create_stream(table_properties, options)
                self.streams[table_key] = stream
                logger.info("✓ Stream created for %s: %s", table_key, stream.stream_id)
                return stream
            except Exception as e:
                logger.error("Failed to create stream for %s: %s", table_key, e)
                raise

    def _create_ack_callback(self, table_key: str):
        """Create an acknowledgment callback bound to one table.

        The callback logs only on offsets that are multiples of 1000 to
        avoid log spam on high-throughput streams.
        """
        def callback(response):
            offset = response.durability_ack_up_to_offset
            if offset % 1000 == 0:
                logger.info("[%s] Acknowledged up to offset: %s", table_key, offset)
        return callback

    async def ingest_record(self, table_key: str, record):
        """
        Ingest a record into the specified table's stream.

        Args:
            table_key: Table identifier
            record: Protobuf message to ingest

        Returns:
            Future that resolves when the record is acknowledged.

        Raises:
            ValueError: if no stream exists for table_key
                (call get_stream first).
        """
        if table_key not in self.streams:
            raise ValueError(f"No stream available for table {table_key}")
        stream = self.streams[table_key]
        return await stream.ingest_record(record)

    async def _close_stream(self, table_key: str):
        """Close a specific stream; always drop it from the registry."""
        if table_key not in self.streams:
            return
        stream = self.streams[table_key]
        try:
            await stream.close()
            logger.info("✓ Stream closed for %s", table_key)
        except Exception as e:
            # Best-effort close: log but never propagate, so bulk cleanup
            # (close_all) keeps going past a single failing stream.
            logger.error("Error closing stream for %s: %s", table_key, e)
        finally:
            del self.streams[table_key]

    async def close_all(self):
        """Close all active streams gracefully."""
        logger.info("Closing all streams...")
        # Snapshot the keys: _close_stream mutates self.streams while we loop.
        for table_key in list(self.streams.keys()):
            await self._close_stream(table_key)
        logger.info("✓ All streams closed")

    async def remove_table(self, table_key: str):
        """
        Remove a table's stream (called when the table is removed from config).

        Args:
            table_key: Table identifier to remove
        """
        logger.info("Removing stream for table %s", table_key)
        await self._close_stream(table_key)

    def get_active_tables(self) -> list:
        """Get the list of table keys with active streams."""
        return list(self.streams.keys())