Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .jixia_db import load_data
from .vector_db import create_vector_db
from .create_schema import create_schema
from .migrate import migrate


def main():
Expand All @@ -23,6 +24,7 @@ def main():
"prefixes",
help="Comma-separated list of module prefixes to be included in the index; e.g., Init,Mathlib",
)
jixia_parser.add_argument("--project-name")
informal_parser = subparser.add_parser("informal")
informal_parser.set_defaults(command="informal")
informal_parser.add_argument("--batch-size", type=int, default=50)
Expand All @@ -39,16 +41,19 @@ def main():
vector_db_parser = subparser.add_parser("vector-db")
vector_db_parser.set_defaults(command="vector-db")
vector_db_parser.add_argument("--batch-size", type=int, default=8)
vector_db_parser.add_argument("--project-name")

args = parser.parse_args()

with psycopg.connect(os.environ["CONNECTION_STRING"], autocommit=True) as conn:
if args.command == "migrate":
migrate(conn)
if args.command == "schema":
create_schema(conn)
elif args.command == "jixia":
project = LeanProject(args.project_root)
prefixes = [parse_name(p) for p in args.prefixes.split(",")]
load_data(project, prefixes, conn)
load_data(project, prefixes, conn, project_name=args.project_name)
elif args.command == "informal":
generate_informal(
conn,
Expand All @@ -57,4 +62,4 @@ def main():
limit_num_per_level=args.limit_num_per_level,
)
elif args.command == "vector-db":
create_vector_db(conn, os.environ["CHROMA_PATH"], batch_size=args.batch_size)
create_vector_db(conn, os.environ["CHROMA_PATH"], batch_size=args.batch_size, project_name=args.project_name)
21 changes: 14 additions & 7 deletions database/informalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ def find_neighbor(conn: Connection, module_name: LeanName, index: int, num_neigh
with conn.cursor(row_factory=args_row(TranslatedItem)) as cursor:
cursor.execute(
"""
SELECT d.name, d.signature, i.name, i.description
SELECT s.name, d.signature, d.value, d.docstring, d.kind, i.name, i.description, d.signature
FROM
declaration d
symbol s
INNER JOIN declaration d ON s.name = d.name
INNER JOIN module m ON s.module_name = m.name
INNER JOIN level l ON s.name = l.symbol_name
LEFT JOIN informal i ON d.name = i.symbol_name
WHERE
d.module_name = %s AND d.index >= %s AND d.index <= %s
Expand All @@ -32,11 +35,14 @@ def find_dependency(conn: Connection, name: LeanName) -> list[TranslatedItem]:
with conn.cursor(row_factory=args_row(TranslatedItem)) as cursor:
cursor.execute(
"""
SELECT d.name, d.signature, i.name, i.description
SELECT s.name, d.signature, d.value, d.docstring, d.kind, i.name, i.description, d.signature
FROM
declaration d
INNER JOIN dependency e ON d.name = e.target
symbol s
INNER JOIN declaration d ON s.name = d.name
INNER JOIN module m ON s.module_name = m.name
INNER JOIN level l ON s.name = l.symbol_name
LEFT JOIN informal i ON d.name = i.symbol_name
INNER JOIN dependency e ON d.name = e.target
WHERE
e.source = %s
""",
Expand All @@ -57,7 +63,7 @@ def generate_informal(conn: Connection, batch_size: int = 50, limit_level: int |
with conn.cursor() as cursor, conn.cursor() as insert_cursor:
for level in range(max_level + 1):
query = """
SELECT s.name, d.signature, d.value, d.docstring, d.kind, m.docstring, d.module_name, d.index
SELECT s.name, d.signature, d.value, d.docstring, d.kind, m.docstring, d.module_name, d.index, m.project_name
FROM
symbol s
INNER JOIN declaration d ON s.name = d.name
Expand Down Expand Up @@ -92,7 +98,7 @@ async def translate_and_insert(name: LeanName, data: TranslationInput):

tasks.clear()
for row in batch:
name, signature, value, docstring, kind, header, module_name, index = row
name, signature, value, docstring, kind, header, module_name, index, project_name = row

logger.info("translating %s", name)
neighbor = find_neighbor(conn, module_name, index)
Expand All @@ -107,6 +113,7 @@ async def translate_and_insert(name: LeanName, data: TranslationInput):
header=header,
neighbor=neighbor,
dependency=dependency,
project_name=project_name,
)
tasks.append(translate_and_insert(name, ti))

Expand Down
6 changes: 3 additions & 3 deletions database/jixia_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ def _get_value(declaration: Declaration, module_content):
else:
return None

def load_data(project: LeanProject, prefixes: list[LeanName], conn: Connection):
def load_data(project: LeanProject, prefixes: list[LeanName], conn: Connection, project_name):
def load_module(data: Iterable[LeanName], base_dir: Path):
values = ((Jsonb(m), project.path_of_module(m, base_dir).read_bytes(), project.load_module_info(m).docstring) for m in data)
values = ((Jsonb(m), project.path_of_module(m, base_dir).read_bytes(), project.load_module_info(m).docstring, project_name) for m in data)
cursor.executemany(
"""
INSERT INTO module (name, content, docstring) VALUES (%s, %s, %s) ON CONFLICT DO NOTHING
INSERT INTO module (name, content, docstring, project_name) VALUES (%s, %s, %s, %s) ON CONFLICT DO NOTHING
""",
values,
)
Expand Down
6 changes: 6 additions & 0 deletions database/migrate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from psycopg import Connection

def migrate(conn: Connection):
    """Add the ``module.project_name`` column and backfill existing rows.

    Idempotent: safe to run against a database that has already been
    migrated (a second run is a no-op).

    :param conn: open psycopg connection; caller manages commit/autocommit.
    """
    with conn.cursor() as cursor:
        # IF NOT EXISTS makes a re-run a no-op instead of a duplicate-column
        # error (supported by PostgreSQL, which this project targets via psycopg).
        cursor.execute("ALTER TABLE module ADD COLUMN IF NOT EXISTS project_name TEXT")
        # Only backfill rows that have no project yet: every pre-migration row
        # is mathlib, and a re-run must not clobber names written afterwards.
        cursor.execute("UPDATE module SET project_name = 'mathlib' WHERE project_name IS NULL")
40 changes: 31 additions & 9 deletions database/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,30 @@
import re
from dataclasses import dataclass
from json import JSONDecodeError
import xml.etree.ElementTree as ET

import jinja2
from jinja2 import Environment, FileSystemLoader
from jixia.structs import DeclarationKind, LeanName, pp_name
from openai import AsyncOpenAI
from prompt.metaprogramming import metaprogramming_prompt

logger = logging.getLogger(__name__)


@dataclass
class TranslatedItem:
name: LeanName
description: str
signature: str
value: str | None
docstring: str | None
kind: DeclarationKind

informal_name: str | None
informal_description: str | None

# The "description" field is a copypaste of the more appropriately named "signature" field, it's here for backwards compatibility, and can be removed when all prompts are switch to using the "signature" field.
description: str

@dataclass
class TranslationInput:
Expand All @@ -28,11 +36,13 @@ class TranslationInput:
value: str | None
docstring: str | None
kind: DeclarationKind
header: str

neighbor: list[TranslatedItem]
dependency: list[TranslatedItem]

header: str
project_name: str

@property
def value_matters(self):
return self.kind in ["classInductive", "definition", "inductive", "structure"]
Expand Down Expand Up @@ -60,7 +70,12 @@ async def translate(self, data: TranslationInput) -> tuple[str, str] | None:
kind = "instance"
else:
kind = "definition" if data.value_matters else "theorem"
prompt = await self.template[kind].render_async(input=data)

if data.project_name == "metaprogramming":
prompt = metaprogramming_prompt(data)
else:
prompt = await self.template[kind].render_async(input=data)

if os.environ["DRY_RUN"] == "true":
logger.info("DRY_RUN:skipped informalization: %s", data.name)
return "Fake Name", f"Fake Description\nPrompt:\n{data}"
Expand All @@ -80,10 +95,17 @@ async def translate(self, data: TranslationInput) -> tuple[str, str] | None:
await asyncio.sleep(1)
continue
answer = response.choices[0].message.content
try:
name = self.pattern_name.search(answer).group(1)
description = self.pattern_description.search(answer).group(1)
except AttributeError: # unable to parse the result, at least one of the regex did not match
logger.info("while translating %s: unable to parse the result; retrying", data.name)
continue
if data.project_name == "metaprogramming":
root = ET.fromstring(f"<root>{answer}</root>")
name = root.findtext('informal_name')
description = root.findtext('informal_description')
if (name is None or description is None):
continue
else:
try:
name = self.pattern_name.search(answer).group(1)
description = self.pattern_description.search(answer).group(1)
except AttributeError: # unable to parse the result, at least one of the regex did not match
logger.info("while translating %s: unable to parse the result; retrying", data.name)
continue
return name.strip(), description.strip()
34 changes: 22 additions & 12 deletions database/vector_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import logging

import chromadb
from chromadb.api.types import Metadata, ID, Embedding, Document
from jixia.structs import pp_name
from psycopg import Connection

from .embedding import MistralEmbedding

logger = logging.getLogger(__name__)

def create_vector_db(conn: Connection, path: str, batch_size: int):
def create_vector_db(conn: Connection, path: str, batch_size: int, project_name: str):
with open("prompt/embedding_instruction.txt") as fp:
instruction = fp.read()
MistralEmbedding.setup_env()
Expand All @@ -23,23 +24,32 @@ def create_vector_db(conn: Connection, path: str, batch_size: int):
)

with conn.cursor() as cursor:
cursor.execute("""
SELECT d.module_name, d.index, d.kind, d.name, d.signature, i.name, i.description
FROM
declaration d INNER JOIN informal i ON d.name = i.symbol_name
WHERE d.visible = TRUE
""")
cursor.execute(
"""
SELECT d.module_name, d.index, d.kind, d.name, d.signature, i.name, i.description, m.project_name
FROM
declaration d
INNER JOIN informal i ON d.name = i.symbol_name
INNER JOIN module m ON d.module_name = m.name
WHERE d.visible = TRUE AND m.project_name = %(project_name)s
""",
{
"project_name": project_name
}
)

while batch := cursor.fetchmany(batch_size):
batch_doc = []
batch_id = []
for module_name, index, kind, name, signature, informal_name, informal_description in batch:
batch_doc: list[Document] = []
batch_id: list[ID] = []
metadatas: list[Metadata] = []
for module_name, index, kind, name, signature, informal_name, informal_description, this_row_project_name in batch:
batch_doc.append(f"{kind} {name} {signature}\n{informal_name}: {informal_description}")
# NOTE: we use module name + index as document id as they cannot contain special characters
batch_id.append(f"{pp_name(module_name)}:{index}")
metadatas.append({ "project_name": this_row_project_name })
if os.environ["DRY_RUN"] == "true":
logger.info("DRY_RUN:skipped embedding: %s", f"{kind} {name} {signature} {informal_name}")
if os.environ["DRY_RUN"] == "true":
return
batch_embedding = embedding.embed(batch_doc)
collection.add(embeddings=batch_embedding, ids=batch_id)
batch_embedding : list[Embedding] = embedding.embed(batch_doc)
collection.add(embeddings=batch_embedding, ids=batch_id, metadatas=metadatas)
42 changes: 42 additions & 0 deletions prompt/metaprogramming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from database.translate import TranslatedItem, TranslationInput


def format_declaration(declaration: "TranslatedItem"):
    """Render one declaration as an XML-style ``<declaration>`` prompt fragment.

    Optional pieces (docstring, value, informal name/description) are omitted
    entirely when absent, instead of leaking the literal string "None" into
    the prompt. The informal fields are read with ``getattr`` because
    ``metaprogramming_prompt`` also passes a ``TranslationInput``, which does
    not declare them — attribute access would raise ``AttributeError``.

    :param declaration: object exposing name/kind/signature and the optional
        docstring/value/informal fields (``TranslatedItem`` or compatible).
    :return: the fragment as a newline-joined string.
    """
    parts: list[str] = ["<declaration>"]
    if declaration.docstring is not None:
        parts.append(f"<docstring>{declaration.docstring}</docstring>")
    parts.append(f"<name>{declaration.kind} {declaration.name}</name>")
    parts.append(f"<signature>{declaration.signature}</signature>")
    if declaration.value is not None:
        parts.append(f"<definition>{declaration.value}</definition>")
    informal_name = getattr(declaration, "informal_name", None)
    if informal_name is not None:
        parts.append(f"<informal_name>{informal_name}</informal_name>")
    informal_description = getattr(declaration, "informal_description", None)
    if informal_description is not None:
        parts.append(f"<informal_description>{informal_description}</informal_description>")
    parts.append("</declaration>")

    return "\n".join(parts)

def metaprogramming_prompt(input: "TranslationInput"):
    """Build the XML-style informalization prompt for a Lean 4 metaprogramming declaration.

    The answer produced from this prompt is later parsed as XML and read via
    ``findtext('informal_name')`` / ``findtext('informal_description')``, so
    the format instructions must show well-formed closing tags.

    :param input: the declaration to informalize plus its context
        (file header, neighboring declarations, dependencies).
    :return: the full prompt as a newline-joined string.
    """
    result: list[str] = []

    result.append("<instructions>Your task is to create an informal description of the following metaprogramming declaration in Lean 4. Later on, we will use them to create embedding vectors.</instructions>")

    result.append(f"<file_docstring explanation='This is a docstring in the beginning of the .lean file where the declaration you are informalizing is located'>{input.header}</file_docstring>")

    result.append("<neighbor_declarations explanation='These are the descriptions of the declarations that are located nearby in this Lean file.'>")
    for neighbor in input.neighbor:
        result.append(format_declaration(neighbor))
    result.append("</neighbor_declarations>")

    result.append("<dependent_declarations explanation='These are the descriptions of the declarations that our declaration depends on.'>")
    for dependency in input.dependency:
        result.append(format_declaration(dependency))
    result.append("</dependent_declarations>")

    result.append("<instructions>Finally, here is the declaration that you should create the description of.</instructions>")
    # TranslationInput has no informal_name/informal_description fields, so
    # render the target declaration inline rather than via format_declaration
    # (which reads those fields and would raise AttributeError here).
    result.append("<declaration>")
    if input.docstring is not None:
        result.append(f"<docstring>{input.docstring}</docstring>")
    result.append(f"<name>{input.kind} {input.name}</name>")
    result.append(f"<signature>{input.signature}</signature>")
    if input.value is not None:
        result.append(f"<definition>{input.value}</definition>")
    result.append("</declaration>")

    # NOTE: the closing tags shown here must be well-formed (</informal_name>,
    # </informal_description>) — the model imitates them, and the reply is
    # parsed with xml.etree.ElementTree.
    result.append("<instructions>Please put your informal description into the following format: <informal_name>...</informal_name> (this is where you put the informal name of this Lean 4 declaration), <informal_description>...</informal_description> (this is where you put the informal description of this Lean 4 declaration). You can put your thinking into <thinking>...</thinking> tags.</instructions>")

    return "\n".join(result)