diff --git a/database/__init__.py b/database/__init__.py index 30b7d19..ccf6e66 100644 --- a/database/__init__.py +++ b/database/__init__.py @@ -9,6 +9,7 @@ from .jixia_db import load_data from .vector_db import create_vector_db from .create_schema import create_schema +from .migrate import migrate def main(): @@ -23,6 +24,9 @@ def main(): "prefixes", help="Comma-separated list of module prefixes to be included in the index; e.g., Init,Mathlib", ) + jixia_parser.add_argument("--project-name") + migrate_parser = subparser.add_parser("migrate") + migrate_parser.set_defaults(command="migrate") informal_parser = subparser.add_parser("informal") informal_parser.set_defaults(command="informal") informal_parser.add_argument("--batch-size", type=int, default=50) @@ -39,16 +43,19 @@ def main(): vector_db_parser = subparser.add_parser("vector-db") vector_db_parser.set_defaults(command="vector-db") vector_db_parser.add_argument("--batch-size", type=int, default=8) + vector_db_parser.add_argument("--project-name") args = parser.parse_args() with psycopg.connect(os.environ["CONNECTION_STRING"], autocommit=True) as conn: + if args.command == "migrate": + migrate(conn) if args.command == "schema": create_schema(conn) elif args.command == "jixia": project = LeanProject(args.project_root) prefixes = [parse_name(p) for p in args.prefixes.split(",")] - load_data(project, prefixes, conn) + load_data(project, prefixes, conn, project_name=args.project_name) elif args.command == "informal": generate_informal( conn, @@ -57,4 +64,4 @@ def main(): limit_num_per_level=args.limit_num_per_level, ) elif args.command == "vector-db": - create_vector_db(conn, os.environ["CHROMA_PATH"], batch_size=args.batch_size) + create_vector_db(conn, os.environ["CHROMA_PATH"], batch_size=args.batch_size, project_name=args.project_name) diff --git a/database/informalize.py b/database/informalize.py index 11123d7..2fbbb29 100644 --- a/database/informalize.py +++ b/database/informalize.py @@ -16,9 +16,12 @@ def find_neighbor(conn: Connection, module_name: LeanName, index: int, num_neigh with
conn.cursor(row_factory=args_row(TranslatedItem)) as cursor: cursor.execute( """ - SELECT d.name, d.signature, i.name, i.description + SELECT s.name, d.signature, d.value, d.docstring, d.kind, i.name, i.description, d.signature FROM - declaration d + symbol s + INNER JOIN declaration d ON s.name = d.name + INNER JOIN module m ON s.module_name = m.name + INNER JOIN level l ON s.name = l.symbol_name LEFT JOIN informal i ON d.name = i.symbol_name WHERE d.module_name = %s AND d.index >= %s AND d.index <= %s @@ -32,11 +35,14 @@ def find_dependency(conn: Connection, name: LeanName) -> list[TranslatedItem]: with conn.cursor(row_factory=args_row(TranslatedItem)) as cursor: cursor.execute( """ - SELECT d.name, d.signature, i.name, i.description + SELECT s.name, d.signature, d.value, d.docstring, d.kind, i.name, i.description, d.signature FROM - declaration d - INNER JOIN dependency e ON d.name = e.target + symbol s + INNER JOIN declaration d ON s.name = d.name + INNER JOIN module m ON s.module_name = m.name + INNER JOIN level l ON s.name = l.symbol_name LEFT JOIN informal i ON d.name = i.symbol_name + INNER JOIN dependency e ON d.name = e.target WHERE e.source = %s """, @@ -57,7 +63,7 @@ def generate_informal(conn: Connection, batch_size: int = 50, limit_level: int | with conn.cursor() as cursor, conn.cursor() as insert_cursor: for level in range(max_level + 1): query = """ - SELECT s.name, d.signature, d.value, d.docstring, d.kind, m.docstring, d.module_name, d.index + SELECT s.name, d.signature, d.value, d.docstring, d.kind, m.docstring, d.module_name, d.index, m.project_name FROM symbol s INNER JOIN declaration d ON s.name = d.name @@ -92,7 +98,7 @@ async def translate_and_insert(name: LeanName, data: TranslationInput): tasks.clear() for row in batch: - name, signature, value, docstring, kind, header, module_name, index = row + name, signature, value, docstring, kind, header, module_name, index, project_name = row logger.info("translating %s", name) neighbor = 
find_neighbor(conn, module_name, index) @@ -107,6 +113,7 @@ async def translate_and_insert(name: LeanName, data: TranslationInput): header=header, neighbor=neighbor, dependency=dependency, + project_name=project_name, ) tasks.append(translate_and_insert(name, ti)) diff --git a/database/jixia_db.py b/database/jixia_db.py index 03cfc9a..be3da36 100644 --- a/database/jixia_db.py +++ b/database/jixia_db.py @@ -24,12 +24,12 @@ def _get_value(declaration: Declaration, module_content): else: return None -def load_data(project: LeanProject, prefixes: list[LeanName], conn: Connection): +def load_data(project: LeanProject, prefixes: list[LeanName], conn: Connection, project_name: str | None = None): def load_module(data: Iterable[LeanName], base_dir: Path): - values = ((Jsonb(m), project.path_of_module(m, base_dir).read_bytes(), project.load_module_info(m).docstring) for m in data) + values = ((Jsonb(m), project.path_of_module(m, base_dir).read_bytes(), project.load_module_info(m).docstring, project_name) for m in data) cursor.executemany( """ - INSERT INTO module (name, content, docstring) VALUES (%s, %s, %s) ON CONFLICT DO NOTHING + INSERT INTO module (name, content, docstring, project_name) VALUES (%s, %s, %s, %s) ON CONFLICT DO NOTHING """, values, ) diff --git a/database/migrate.py b/database/migrate.py new file mode 100644 index 0000000..81f33b2 --- /dev/null +++ b/database/migrate.py @@ -0,0 +1,6 @@ +from psycopg import Connection + +def migrate(conn: Connection): + with conn.cursor() as cursor: + cursor.execute("ALTER TABLE module ADD COLUMN IF NOT EXISTS project_name TEXT") + cursor.execute("UPDATE module SET project_name = 'mathlib' WHERE project_name IS NULL") diff --git a/database/translate.py b/database/translate.py index d2d9da4..df772e2 100644 --- a/database/translate.py +++ b/database/translate.py @@ -4,11 +4,13 @@ import re from dataclasses import dataclass from json import JSONDecodeError +import xml.etree.ElementTree as ET import jinja2 from jinja2 import Environment, FileSystemLoader from jixia.structs import
DeclarationKind, LeanName, pp_name from openai import AsyncOpenAI +from prompt.metaprogramming import metaprogramming_prompt logger = logging.getLogger(__name__) @@ -16,10 +18,16 @@ @dataclass class TranslatedItem: name: LeanName - description: str + signature: str + value: str | None + docstring: str | None + kind: DeclarationKind + informal_name: str | None informal_description: str | None + # The "description" field is a copypaste of the more appropriately named "signature" field, it's here for backwards compatibility, and can be removed when all prompts are switched to using the "signature" field. + description: str @dataclass class TranslationInput: @@ -28,11 +36,13 @@ class TranslationInput: value: str | None docstring: str | None kind: DeclarationKind - header: str neighbor: list[TranslatedItem] dependency: list[TranslatedItem] + header: str + project_name: str | None + @property def value_matters(self): return self.kind in ["classInductive", "definition", "inductive", "structure"] @@ -60,7 +70,12 @@ async def translate(self, data: TranslationInput) -> tuple[str, str] | None: kind = "instance" else: kind = "definition" if data.value_matters else "theorem" - prompt = await self.template[kind].render_async(input=data) + + if data.project_name == "metaprogramming": + prompt = metaprogramming_prompt(data) + else: + prompt = await self.template[kind].render_async(input=data) + if os.environ["DRY_RUN"] == "true": logger.info("DRY_RUN:skipped informalization: %s", data.name) return "Fake Name", f"Fake Description\nPrompt:\n{data}" @@ -80,10 +95,21 @@ async def translate(self, data: TranslationInput) -> tuple[str, str] | None: await asyncio.sleep(1) continue answer = response.choices[0].message.content - try: - name = self.pattern_name.search(answer).group(1) - description = self.pattern_description.search(answer).group(1) - except AttributeError: # unable to parse the result, at least one of the regex did not match - logger.info("while translating %s: unable to parse the
result; retrying", data.name) - continue + if data.project_name == "metaprogramming": + try: + root = ET.fromstring(f"<root>{answer}</root>") + except ET.ParseError: + logger.info("while translating %s: unable to parse the result; retrying", data.name) + continue + name = root.findtext('informal_name') + description = root.findtext('informal_description') + if name is None or description is None: + continue + else: + try: + name = self.pattern_name.search(answer).group(1) + description = self.pattern_description.search(answer).group(1) + except AttributeError: # unable to parse the result, at least one of the regex did not match + logger.info("while translating %s: unable to parse the result; retrying", data.name) + continue return name.strip(), description.strip() diff --git a/database/vector_db.py b/database/vector_db.py index 1b458eb..a80a949 100644 --- a/database/vector_db.py +++ b/database/vector_db.py @@ -2,6 +2,7 @@ import logging import chromadb +from chromadb.api.types import Metadata, ID, Embedding, Document from jixia.structs import pp_name from psycopg import Connection @@ -9,7 +10,7 @@ logger = logging.getLogger(__name__) -def create_vector_db(conn: Connection, path: str, batch_size: int): +def create_vector_db(conn: Connection, path: str, batch_size: int, project_name: str): with open("prompt/embedding_instruction.txt") as fp: instruction = fp.read() MistralEmbedding.setup_env() @@ -23,23 +24,32 @@ def create_vector_db(conn: Connection, path: str, batch_size: int): with conn.cursor() as cursor: - cursor.execute(""" - SELECT d.module_name, d.index, d.kind, d.name, d.signature, i.name, i.description - FROM - declaration d INNER JOIN informal i ON d.name = i.symbol_name - WHERE d.visible = TRUE - """) + cursor.execute( + """ + SELECT d.module_name, d.index, d.kind, d.name, d.signature, i.name, i.description, m.project_name + FROM + declaration d + INNER JOIN informal i ON d.name = i.symbol_name + INNER JOIN module m ON d.module_name = m.name + WHERE d.visible = TRUE AND m.project_name = %(project_name)s + """, + { + "project_name": project_name + } + ) while batch :=
cursor.fetchmany(batch_size): - batch_doc = [] - batch_id = [] - for module_name, index, kind, name, signature, informal_name, informal_description in batch: + batch_doc: list[Document] = [] + batch_id: list[ID] = [] + metadatas: list[Metadata] = [] + for module_name, index, kind, name, signature, informal_name, informal_description, this_row_project_name in batch: batch_doc.append(f"{kind} {name} {signature}\n{informal_name}: {informal_description}") # NOTE: we use module name + index as document id as they cannot contain special characters batch_id.append(f"{pp_name(module_name)}:{index}") + metadatas.append({ "project_name": this_row_project_name }) if os.environ["DRY_RUN"] == "true": logger.info("DRY_RUN:skipped embedding: %s", f"{kind} {name} {signature} {informal_name}") if os.environ["DRY_RUN"] == "true": return - batch_embedding = embedding.embed(batch_doc) - collection.add(embeddings=batch_embedding, ids=batch_id) + batch_embedding: list[Embedding] = embedding.embed(batch_doc) + collection.add(embeddings=batch_embedding, ids=batch_id, metadatas=metadatas) diff --git a/prompt/metaprogramming.py b/prompt/metaprogramming.py new file mode 100644 index 0000000..c15a603 --- /dev/null +++ b/prompt/metaprogramming.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from database.translate import TranslatedItem, TranslationInput + + +def format_declaration(declaration: TranslatedItem | TranslationInput): + result: list[str] = [] + + result.append("") + result.append(f"{declaration.docstring}") + result.append(f"{declaration.kind} {declaration.name}") + result.append(f"{declaration.signature}") + result.append(f"{declaration.value}") + if getattr(declaration, "informal_name", None) is not None: + result.append(f"{declaration.informal_name}") + if getattr(declaration, "informal_description", None) is not None: + result.append(f"{declaration.informal_description}") + result.append("") + + return "\n".join(result) + +def metaprogramming_prompt(input: TranslationInput): + result: list[str] = [] + + result.append("Your task is to create an informal
description of the following metaprogramming declaration in Lean 4. Later on, we will use them to create embedding vectors.") + + result.append(f"{input.header}") + + result.append("") + for neighbor in input.neighbor: + result.append(format_declaration(neighbor)) # type: ignore + result.append("") + + result.append("") + for dependency in input.dependency: + result.append(format_declaration(dependency)) # type: ignore + result.append("") + + result.append("Finally, here is the declaration that you should create the description of.") + result.append(format_declaration(input)) # type: ignore + + result.append("Please put your informal description into the following format: ... (this is where you put the informal name of this Lean 4 declaration), ... (this is where you put the informal description of this Lean 4 declaration). You can put your thinking into ... tags.") + + return "\n".join(result)