Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ requires-python = ">=3.12"
dependencies = [
"aiofiles>=25.1.0",
"aiosqlite>=0.21.0",
"fastapi>=0.120.2",
"httpx>=0.28.1",
"msgpack>=1.1.2",
"pydantic>=2.12.3",
"uvicorn>=0.38.0",
]


Expand Down
210 changes: 184 additions & 26 deletions tinystream/broker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections import defaultdict
import json
from pathlib import Path
from typing import Dict, Any, Optional, Type, Literal
import asyncio
Expand All @@ -9,10 +9,14 @@

from aiosqlite import Connection as DBConnection

from tinystream.metastore import Metastore
from tinystream.models import TopicMetadata, PartitionMetadata
from tinystream import DEFAULT_CONFIG_PATH
from tinystream.client.connection import TinyStreamAPI
from tinystream.client.base import BaseAsyncClient
from tinystream.client.topic_manager import TopicManager
from tinystream.config.parser import TinyStreamConfig
from tinystream.controller import BrokerInfo
from tinystream.partitions.base import BasePartition
from tinystream.partitions.partition import SingleLogPartition
from tinystream.serializer.base import AbstractSerializer
Expand Down Expand Up @@ -58,7 +62,7 @@ def __init__(self, config: TinyStreamConfig, broker_id: Optional[int]) -> None:
port=int(controller_config.get("port")), # type: ignore
serializer=self.serializer,
)
self.heartbeat_task: Optional[asyncio.Task] = None
self.heartbeat_task: Optional[asyncio.Task[Any]] = None

super().__init__(
prefix_size=self.prefix_size,
Expand All @@ -71,12 +75,38 @@ def __init__(self, config: TinyStreamConfig, broker_id: Optional[int]) -> None:
self.broker_config.get("partition_type", "singlelogpartition")
)

# In-memory mapping of:
self.brokers: Dict[int, BrokerInfo] = {}
self.topic_metadata: Dict[str, TopicMetadata] = {}

self.metastore_task: Optional[asyncio.Task[Any]] = None

# In-memory mapping of actual partition objects:
# { topic_name -> { partition_id -> BasePartition } }
self.topics: Dict[str, Dict[int, BasePartition]] = defaultdict(dict)
self.partitions: Dict[str, Dict[int, BasePartition]] = {}
self._lock = asyncio.Lock()
self._server: Optional[asyncio.Server] = None

if self.mode == "single":
print(
f"[Broker {broker_id}] Running in 'single' mode. Initializing topic manager."
)
self.topic_manager = TopicManager(
db_connection=None,
brokers=self.brokers,
topics=self.topic_metadata,
lock=self._lock,
)

metastore_config = self.config.metastore
self.metastore_http_port = int(metastore_config.get("http_port", 6000))
self.metastore = Metastore(
topic_manager=self.topic_manager,
topics=self.topic_metadata,
brokers=self.brokers,
lock=self._lock,
port=self.metastore_http_port,
)

@staticmethod
def init_partition_class(partition_name: str) -> Type[BasePartition]:
if partition_name == "singlelogpartition":
Expand Down Expand Up @@ -123,17 +153,21 @@ async def _create_new_partition(
)

await partition.load()
self.topics[topic_name][partition_id] = partition

if topic_name not in self.partitions:
self.partitions[topic_name] = {}
self.partitions[topic_name][partition_id] = partition

if self.db_conn:
try:
await self.db_conn.execute(
"""
INSERT
OR IGNORE INTO partitions (topic_name, partition_id)
VALUES (?, ?)
OR IGNORE
INTO partitions (topic_name, partition_id, replicas)
VALUES (?, ?, ?)
""",
(topic_name, partition_id),
(topic_name, partition_id, json.dumps([])),
)
await self.db_conn.commit()
except Exception as exception:
Expand Down Expand Up @@ -166,7 +200,7 @@ async def load_partitions(self) -> None:
except ValueError:
print(f"Skipping non-numeric log file: {log_file}")

print(f"Finished loading. Found {len(self.topics)} topics.")
print(f"Finished loading. Found {len(self.partitions)} topics.")

async def get_or_create_partition(
self, topic_name: str, partition_id: int
Expand All @@ -175,31 +209,60 @@ async def get_or_create_partition(
Retrieves a partition, creating it if it doesn't exist.
This is a central part of the broker's logic.
"""
if partition := self.topics.get(topic_name, {}).get(partition_id):
if partition := self.partitions.get(topic_name, {}).get(partition_id):
return partition

async with self._lock:
if partition := self.topics.get(topic_name, {}).get(partition_id):
if partition := self.partitions.get(topic_name, {}).get(partition_id):
return partition

return await self._create_new_partition(topic_name, partition_id)

async def start(self) -> None:
"""Starts the main broker server."""
await self.init_metastore(db_path=self.metastore_db_path)

if self.mode == "cluster":
await self.controller_client.ensure_connected()
print(f"[Broker {self.broker_id}] Registering with controller...")
await self.controller_client.send_request(

response = await self.controller_client.send_request(
{
"command": "register_broker",
"broker_id": self.broker_id,
"host": self.host,
"port": self.port,
}
)

if response and response.get("status") == "ok":
print(f"[Broker {self.broker_id}] Registered successfully.")
assignments = response.get("assignments", [])
await self._reconcile_partitions(assignments)
else:
raise Exception(f"Could not register with controller: {response}")

self.heartbeat_task = asyncio.create_task(self._heartbeat_loop())

else:
await self.init_metastore(db_path=self.metastore_db_path)
self.topic_manager.db_connection = self.db_conn
await self._load_metadata_from_db()
if self.broker_id is None:
raise ValueError("broker_id must not be None")
self_info = BrokerInfo(
broker_id=self.broker_id, host=self.host, port=self.port, is_alive=True
)
async with self._lock:
self.brokers[self.broker_id] = self_info

self.metastore_task = asyncio.create_task( # type: ignore
self.metastore.start(), name="broker-metastore-api"
)
print(
f"[Broker {self.broker_id}] Metastore API docs at http://localhost:{self.port}/docs"
)

print("Metastore initialized successfully.")

await self.load_partitions()
await self.start_server()

Expand Down Expand Up @@ -248,6 +311,33 @@ async def send_request(self, payload_bytes: bytes) -> Dict[str, Any]:
elif command == "commit_offset":
return await self._handle_commit_offset(request)

elif command == "create_topic":
if self.mode == "single" and self.topic_manager:
try:
await self.topic_manager.create_topic(
request["name"],
request["partitions"],
request["replication_factor"],
)
return {
"status": "success",
"message": f"Topic {request['name']} created.",
}
except ValueError as exception:
return {"status": "error", "message": str(exception)}
except Exception as exception:
print(
f"FATAL error in handle_create_topic_request: {exception}"
)
return {
"status": "error",
"message": f"Internal server error: {exception}",
}
else:
return {
"status": "error",
"message": "create_topic is only allowed in single mode. Use admin client.",
}
else:
return {"status": "error", "message": "Unknown command"}

Expand All @@ -257,6 +347,35 @@ async def send_request(self, payload_bytes: bytes) -> Dict[str, Any]:
"message": f"Failed to process request: {exception}",
}

async def _load_metadata_from_db(self) -> None:
    """Rebuild the in-memory topic metadata from the SQLite metastore.

    Reads every row of the ``topics`` and ``partitions`` tables via
    ``self.db_conn`` and populates ``self.topic_metadata`` with
    ``TopicMetadata``/``PartitionMetadata`` objects. Holds ``self._lock``
    for the whole rebuild so concurrent readers never observe a
    half-loaded mapping.

    NOTE(review): ``row["topic_name"]``-style access assumes the
    connection's ``row_factory`` yields mapping-like rows (e.g.
    ``aiosqlite.Row``) — confirm where ``self.db_conn`` is created.
    """
    print(f"[Broker {self.broker_id}] Loading metadata from metastore...")
    async with self._lock:
        # First pass: create an empty TopicMetadata entry per topic row.
        async with self.db_conn.execute("SELECT * FROM topics") as cursor:
            async for row in cursor:
                topic_name = row["topic_name"]
                self.topic_metadata[topic_name] = TopicMetadata(
                    name=topic_name, partitions={}
                )

        # Second pass: attach partition metadata to its parent topic.
        async with self.db_conn.execute("SELECT * FROM partitions") as cursor:
            async for row in cursor:
                topic_name = row["topic_name"]
                # Silently skip orphaned partition rows whose topic row
                # is missing — only known topics are populated.
                if topic_name in self.topic_metadata:
                    # 'replicas' is stored as a JSON-encoded list in SQLite.
                    replicas_list = json.loads(row["replicas"])
                    part_meta = PartitionMetadata(
                        partition_id=row["partition_id"],
                        leader=row["leader"],
                        replicas=replicas_list,
                    )
                    self.topic_metadata[topic_name].partitions[
                        part_meta.partition_id
                    ] = part_meta

    print(
        f"[Broker {self.broker_id}] Loaded {len(self.topic_metadata)} topics from metastore."
    )

async def _handle_append(self, request: Dict[str, Any]) -> Dict[str, Any]:
topic = request["topic"]
partition_id = request["partition"]
Expand Down Expand Up @@ -286,23 +405,56 @@ async def _handle_read(self, request: Dict[str, Any]) -> Dict[str, Any]:
except KeyError:
return {"status": "error", "message": "Topic or partition not found"}

async def _reconcile_partitions(self, assignments: list[dict]) -> None:
    """
    Compares the controller's assignments with local state and creates
    any missing partitions. This is the core of the "pull" model.

    Each assignment dict must carry at least ``"topic"`` and
    ``"partition_id"`` keys; creation is delegated to
    ``self.get_or_create_partition``, which is a no-op for partitions
    that already exist locally. A failure on one assignment is logged
    and does not stop reconciliation of the remaining assignments.
    """
    print(
        f"[Broker {self.broker_id}] Reconciling {len(assignments)} assignments..."
    )
    for assignment in assignments:
        topic = assignment["topic"]
        part_id = assignment["partition_id"]

        try:
            await self.get_or_create_partition(topic, part_id)

            # TODO: A full implementation would also:
            # 1. Store the 'role' (leader/follower)
            # 2. Trigger follower fetching logic if role == 'follower'
            # 3. Handle partition *removals* (i.e., assignments that
            #    disappeared from the controller's list)

        except Exception as e:
            # Best-effort: log and continue with the next assignment.
            print(f"Error reconciling partition {topic}/{part_id}: {e}")

async def _heartbeat_loop(self):
"""Runs in the background, sending heartbeats to the controller."""
"""Runs in the background, sending heartbeats AND processing assignments."""
while True:
try:
await self.controller_client.send_request(
response = await self.controller_client.send_request(
{"command": "heartbeat", "broker_id": self.broker_id}
)
print(f"[Broker {self.broker_id}] Heartbeat sent.")

if response and response.get("status") == "ok":
assignments = response.get("assignments", [])
await self._reconcile_partitions(assignments)
print(f"[Broker {self.broker_id}] Heartbeat sent and processed.")
else:
print(
f"[Broker {self.broker_id}] Invalid heartbeat response: {response}"
)

except Exception as exception:
print(
f"[Broker {self.broker_id}] Failed to send heartbeat: {exception}"
)

await asyncio.sleep(3)
await asyncio.sleep(3) # Configurable interval

async def _handle_get_hwm(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""Gets the High Watermark (next write offset) for a partition."""
"""Next write offset for a partition."""
topic = request["topic"]
partition_id = request["partition"]

Expand Down Expand Up @@ -365,7 +517,7 @@ async def main(
base_config_path = config or DEFAULT_CONFIG_PATH
try:
base_config = TinyStreamConfig.from_ini(base_config_path)
base_port = int(base_config.broker_config.get("port", "909"))
base_port = int(base_config.broker_config.get("port", "9090"))
except Exception as exception:
print(
f"FATAL: Could not load base config from {base_config_path}: {exception}"
Expand All @@ -386,7 +538,6 @@ async def main(
broker_id=i,
)
)

start_tasks = [b.start() for b in broker_instances]
try:
await asyncio.gather(*start_tasks)
Expand All @@ -409,9 +560,15 @@ async def main(

config_obj.mode = mode # type: ignore

base_port = int(config_obj.broker_config.get("port", "909"))
broker_port = base_port + broker_id_to_use # type: ignore
config_obj.broker_config["port"] = f"{broker_port}"
try:
port_value = config_obj.broker_config.get("port")
if port_value is None:
print(f"FATAL: 'port' not found or is invalid in config file: {config}")
return
broker_port = int(port_value)
except (TypeError, ValueError):
print(f"FATAL: 'port' not found or is invalid in config file: {config}")
return

broker = Broker(
config=config_obj,
Expand Down Expand Up @@ -442,7 +599,6 @@ async def main(
if __name__ == "__main__":

def print_usage(parser_instance, message):
"""Prints a validation error and the parser's help message."""
print(f"Error: {message}\n")
parser_instance.print_help()
sys.exit(1)
Expand Down Expand Up @@ -503,5 +659,7 @@ def print_usage(parser_instance, message):
)
except Exception as e:
print(f"FATAL: Broker main loop crashed: {e}")
# (Consider adding `import traceback; traceback.print_exc()` for debug)
import traceback

traceback.print_exc()
sys.exit(1)
Loading