Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 42 additions & 24 deletions src/disco/events/gen_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class Field:
description: str
variants: Optional[Dict[str, Variant]] = None
fields: Optional[Dict[str, "Field"]] = None
shared_name: Optional[str] = None

@dataclass
class Schema:
Expand All @@ -78,13 +79,15 @@ class Schema:
description: str
fields: Dict[str, Field]

def parse_field(f: dict) -> Field:
return self._TO_PROTO.value[self.name]

def parse_field(f: dict) -> Field:
def parse_field(f: dict, shared_types: Dict[str, dict]) -> Field:
if f["type"].startswith("ref:"):
field = parse_field(shared_types[f["type"][4:]], shared_types)
field.shared_name = f["type"][4:]
return field

fields = None
if f["type"] == "Flatten":
fields = {k: parse_field(v) for k, v in f["fields"].items()}
fields = {k: parse_field(v, shared_types) for k, v in f["fields"].items()}

return Field(
chtype=ClickHouseType.from_str(f["type"]),
Expand All @@ -93,62 +96,75 @@ def parse_field(f: dict) -> Field:
fields=fields
)

def parse_schema(path: Path) -> Schema:
def parse_schema(path: Path, shared_types: Dict[str, dict]) -> Schema:
data = json.loads(path.read_text())

fields = {k: parse_field(v) for k, v in data["fields"].items()}
fields = {k: parse_field(v, shared_types) for k, v in data["fields"].items()}
return Schema(data["name"], data["id"], data["description"], fields)

def collect_nested_messages(schema_name: str, fields: Dict[str, Field], prefix: str = "") -> List[tuple]:
def collect_nested_messages(fields: Dict[str, Field], prefix: str = "") -> List[tuple]:
msgs = []

for name, f in fields.items():
if f.chtype == ClickHouseType.Flatten:
prefix = f"{prefix}{to_pascal_case(name)}"
msgs.append((f"{to_pascal_case(schema_name)}{prefix}", f.fields, f.description))
msgs += collect_nested_messages(schema_name, f.fields, prefix)
new_prefix = f.shared_name or f"{prefix}{to_pascal_case(name)}"
msgs.append((new_prefix, f.fields, f.description))
msgs += collect_nested_messages(f.fields, new_prefix)

return msgs

def generate_message_fields(schema_name: str, fields: Dict[str, Field], prefix: str = "") -> List[str]:
def generate_message_fields(fields: Dict[str, Field], prefix: str = "") -> List[str]:
lines = []

for i, (name, f) in enumerate(fields.items(), 1):
if f.chtype == ClickHouseType.Flatten or f.variants:
proto_type = f"{to_pascal_case(schema_name)}{prefix}{to_pascal_case(name)}"
proto_type = f.shared_name or f"{prefix}{to_pascal_case(name)}"
else:
proto_type = f.chtype.to_protobuf_type()
lines += [f" // {f.description}", f" {proto_type} {name} = {i};"]

return lines

def generate_enums(schema_name: str, fields: Dict[str, Field], prefix: str = "") -> List[str]:
def generate_enums(fields: Dict[str, Field], prefix: str, generated: set) -> List[str]:
lines = []
for name, f in fields.items():
if f.variants:
enum = f"{to_pascal_case(schema_name)}{prefix}{to_pascal_case(name)}"
enum = f.shared_name or f"{prefix}{to_pascal_case(name)}"
if enum in generated:
continue

generated.add(enum)
ep = to_screaming_snake_case(enum)
lines += [f"// {f.description}", f"enum {enum} {{", f" {ep}_UNSPECIFIED = 0;"]
for i, (vn, v) in enumerate(f.variants.items(), 1):
lines.append(f" {ep}_{to_screaming_snake_case(vn)} = {i}; // {v.description}")
lines += ["}", ""]
if f.chtype == ClickHouseType.Flatten:
lines += generate_enums(schema_name, f.fields, f"{prefix}{to_pascal_case(name)}")
nested_prefix = f.shared_name or f"{prefix}{to_pascal_case(name)}"
lines += generate_enums(f.fields, nested_prefix, generated)
return lines

def generate_protobuf(schemas: List[Schema]) -> str:
lines = ['syntax = "proto3";', "", "package events.v1;", ""]

generated_enums = set()
for s in schemas:
lines += generate_enums(s.name, s.fields)
schema_prefix = to_pascal_case(s.name)
lines += generate_enums(s.fields, schema_prefix, generated_enums)

generated_msgs = set()
for s in schemas:
for msg, flds, desc in reversed(collect_nested_messages(s.name, s.fields)):
prefix = msg[len(to_pascal_case(s.name)):]
lines += [f"// {desc}", f"message {msg} {{"] + generate_message_fields(s.name, flds, prefix) + ["}", ""]
schema_prefix = to_pascal_case(s.name)
for msg, flds, desc in reversed(collect_nested_messages(s.fields, schema_prefix)):
if msg in generated_msgs:
continue

generated_msgs.add(msg)
lines += [f"// {desc}", f"message {msg} {{"] + generate_message_fields(flds, msg) + ["}", ""]

for s in schemas:
lines += [f"// {s.description}", f"message {to_pascal_case(s.name)} {{"] + generate_message_fields(s.name, s.fields) + ["}", ""]
schema_prefix = to_pascal_case(s.name)
lines += [f"// {s.description}", f"message {schema_prefix} {{"] + generate_message_fields(s.fields, schema_prefix) + ["}", ""]

lines += ["// Combined event type", "message Event {", " oneof event {"]
for s in schemas:
Expand All @@ -171,7 +187,7 @@ def check_breaking_changes(schema_dir: Path) -> None:
check=True
)

print("No breaking changes detected")
print("No breaking changes detected")

def main() -> None:
parser = argparse.ArgumentParser(description="Generate protobuf from JSON schemas")
Expand All @@ -181,10 +197,12 @@ def main() -> None:
schema_dir = Path(__file__).parent / "schema"
proto_path = schema_dir / "events.proto"

schemas = sorted([parse_schema(f) for f in schema_dir.glob("*.json")], key=lambda s: s.id)
shared_types = json.loads((schema_dir / "shared.json").read_text())
schema_files = [f for f in schema_dir.glob("*.json") if f.name != "shared.json"]
schemas = sorted([parse_schema(f, shared_types) for f in schema_files], key=lambda s: s.id)
proto_path.write_text(generate_protobuf(schemas))

print(f"Protobuf generated successfully for {len(schemas)} schemas")
print(f"Protobuf generated successfully from {len(schemas)} schemas")

if not args.skip_check:
check_breaking_changes(schema_dir)
Expand Down
Loading
Loading