From 46f6b791fc82283106faf347346ab39178e3d461 Mon Sep 17 00:00:00 2001
From: Evgenia Karunus
Date: Fri, 16 May 2025 11:32:48 +0500
Subject: [PATCH 1/7] migrate.py - add "add module.project column" migration

---
 database/__init__.py | 3 +++
 database/migrate.py | 6 ++++++
 meta_instructions.md | 4 ++++
 3 files changed, 13 insertions(+)
 create mode 100644 database/migrate.py
 create mode 100644 meta_instructions.md

diff --git a/database/__init__.py b/database/__init__.py
index 30b7d19..1fefa6a 100644
--- a/database/__init__.py
+++ b/database/__init__.py
@@ -9,6 +9,7 @@
 from .jixia_db import load_data
 from .vector_db import create_vector_db
 from .create_schema import create_schema
+from .migrate import migrate
 
 
 def main():
@@ -43,6 +44,8 @@ def main():
     args = parser.parse_args()
 
     with psycopg.connect(os.environ["CONNECTION_STRING"], autocommit=True) as conn:
+        if args.command == "migrate":
+            migrate(conn)
         if args.command == "schema":
             create_schema(conn)
         elif args.command == "jixia":
diff --git a/database/migrate.py b/database/migrate.py
new file mode 100644
index 0000000..81f33b2
--- /dev/null
+++ b/database/migrate.py
@@ -0,0 +1,6 @@
+from psycopg import Connection
+
+def migrate(conn: Connection):
+    with conn.cursor() as cursor:
+        cursor.execute("ALTER TABLE module ADD COLUMN project_name TEXT")
+        cursor.execute("UPDATE module SET project_name = 'mathlib'")
diff --git a/meta_instructions.md b/meta_instructions.md
new file mode 100644
index 0000000..c291432
--- /dev/null
+++ b/meta_instructions.md
@@ -0,0 +1,4 @@
+
+# Indexing metaprogramming functions
+
+1. Run `python -m database migrate`

From 9212d772cfc5bccc42183beedfa139fa57911d82 Mon Sep 17 00:00:00 2001
From: Evgenia Karunus
Date: Mon, 19 May 2025 15:14:19 +0500
Subject: [PATCH 2/7] prompt/metaprogramming.py - create the prompt

---
 prompt/metaprogramming.py | 39 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 39 insertions(+)
 create mode 100644 prompt/metaprogramming.py

diff --git a/prompt/metaprogramming.py b/prompt/metaprogramming.py
new file mode 100644
index 0000000..ddcef48
--- /dev/null
+++ b/prompt/metaprogramming.py
@@ -0,0 +1,39 @@
+def format_declaration(declaration):
+    result : list[str] = []
+
+    result.append("")
+    result.append(f"{declaration['docstring']}")
+    result.append(f"{declaration['name']}")
+    result.append(f"{declaration['description']}")
+    result.append(f"{declaration['value']}")
+    if declaration['informal_name'] is not None:
+        result.append(f"{declaration['informal_name']}")
+    if declaration['informal_description'] is not None:
+        result.append(f"{declaration['informal_description']}")
+    result.append("")
+
+    return "\n".join(result)
+
+def format_input(input_data):
+    result : list[str] = []
+
+    result.append("Your task is to create an informal description of the following metaprogramming declaration in Lean 4. Later on, we will use them to create embedding vectors.")
+
+    result.append(f"{input_data['header']}")
+
+    result.append("")
+    for neighbor in input_data["neighbor"]:
+        result.append(format_declaration(neighbor))
+    result.append("")
+
+    result.append("")
+    for dependency in input_data["dependency"]:
+        result.append(format_declaration(dependency))
+    result.append("")
+
+    result.append("Finally, here is the declaration that you should create the description of.")
+    result.append(format_declaration(input_data))
+
+    result.append("Please put your informal description into the following format: ... (this is where you put the informal name of this Lean 4 declaration), ... (this is where you put the informal description of this Lean 4 declaration). You can put your thinking into ... tags.")
+
+    return "\n".join(result)

From 692b72947875ef83287737ca41c578c2fd007018 Mon Sep 17 00:00:00 2001
From: Evgenia Karunus
Date: Tue, 20 May 2025 11:05:48 +0500
Subject: [PATCH 3/7] informalization - make `.neighbors` and `.dependencies` equivalent to `TranslationInput`

---
 database/informalize.py | 16 +++++++++++-----
 database/translate.py | 10 +++++++++-
 prompt/metaprogramming.py | 35 +++++++++++++++++++----------------
 3 files changed, 39 insertions(+), 22 deletions(-)

diff --git a/database/informalize.py b/database/informalize.py
index 11123d7..c34f210 100644
--- a/database/informalize.py
+++ b/database/informalize.py
@@ -16,9 +16,12 @@ def find_neighbor(conn: Connection, module_name: LeanName, index: int, num_neigh
     with conn.cursor(row_factory=args_row(TranslatedItem)) as cursor:
         cursor.execute(
             """
-            SELECT d.name, d.signature, i.name, i.description
+            SELECT s.name, d.signature, d.value, d.docstring, d.kind, i.name, i.description, d.signature
             FROM
-                declaration d
+                symbol s
+                INNER JOIN declaration d ON s.name = d.name
+                INNER JOIN module m ON s.module_name = m.name
+                INNER JOIN level l ON s.name = l.symbol_name
                 LEFT JOIN informal i ON d.name = i.symbol_name
             WHERE
                 d.module_name = %s AND d.index >= %s AND d.index <= %s
@@ -32,11 +35,14 @@ def find_dependency(conn: Connection, name: LeanName) -> list[TranslatedItem]:
     with conn.cursor(row_factory=args_row(TranslatedItem)) as cursor:
         cursor.execute(
             """
-            SELECT d.name, d.signature, i.name, i.description
+            SELECT s.name, d.signature, d.value, d.docstring, d.kind, i.name, i.description, d.signature
            FROM
-                declaration d
-                INNER JOIN dependency e ON d.name = e.target
+                symbol s
+                INNER JOIN declaration d ON s.name = d.name
+                INNER JOIN module m ON s.module_name = m.name
+                INNER JOIN level l ON s.name = l.symbol_name
                 LEFT JOIN informal i ON d.name = i.symbol_name
+                INNER JOIN dependency e ON d.name = e.target
             WHERE
                 e.source = %s
             """,
diff --git a/database/translate.py b/database/translate.py
index d2d9da4..f17c7d0 100644
--- a/database/translate.py
+++ b/database/translate.py
@@ -4,11 +4,13 @@
 import re
 from dataclasses import dataclass
 from json import JSONDecodeError
+import xml.etree.ElementTree as ET
 
 import jinja2
 from jinja2 import Environment, FileSystemLoader
 from jixia.structs import DeclarationKind, LeanName, pp_name
 from openai import AsyncOpenAI
+from prompt.metaprogramming import metaprogramming_prompt
 
 
 logger = logging.getLogger(__name__)
@@ -16,10 +18,16 @@
 @dataclass
 class TranslatedItem:
     name: LeanName
-    description: str
+    signature: str
+    value: str | None
+    docstring: str | None
+    kind: DeclarationKind
+
     informal_name: str | None
     informal_description: str | None
+    # The "description" field is a copy of the more appropriately named "signature" field; it's here for backwards compatibility and can be removed once all prompts are switched to using the "signature" field.
+    description: str
 
 
 @dataclass
 class TranslationInput:
diff --git a/prompt/metaprogramming.py b/prompt/metaprogramming.py
index ddcef48..c15a603 100644
--- a/prompt/metaprogramming.py
+++ b/prompt/metaprogramming.py
@@ -1,38 +1,41 @@
-def format_declaration(declaration):
+from database.translate import TranslatedItem, TranslationInput
+
+
+def format_declaration(declaration: TranslatedItem):
     result : list[str] = []
 
     result.append("")
-    result.append(f"{declaration['docstring']}")
-    result.append(f"{declaration['name']}")
-    result.append(f"{declaration['description']}")
-    result.append(f"{declaration['value']}")
-    if declaration['informal_name'] is not None:
-        result.append(f"{declaration['informal_name']}")
-    if declaration['informal_description'] is not None:
-        result.append(f"{declaration['informal_description']}")
+    result.append(f"{declaration.docstring}")
+    result.append(f"{declaration.kind} {declaration.name}")
+    result.append(f"{declaration.signature}")
+    result.append(f"{declaration.value}")
+    if declaration.informal_name is not None:
+        result.append(f"{declaration.informal_name}")
+    if declaration.informal_description is not None:
+        result.append(f"{declaration.informal_description}")
     result.append("")
 
     return "\n".join(result)
 
-def format_input(input_data):
+def metaprogramming_prompt(input : TranslationInput):
     result : list[str] = []
 
     result.append("Your task is to create an informal description of the following metaprogramming declaration in Lean 4. Later on, we will use them to create embedding vectors.")
 
-    result.append(f"{input_data['header']}")
+    result.append(f"{input.header}")
 
     result.append("")
-    for neighbor in input_data["neighbor"]:
-        result.append(format_declaration(neighbor))
+    for neighbor in input.neighbor:
+        result.append(format_declaration(neighbor)) # type: ignore
     result.append("")
 
     result.append("")
-    for dependency in input_data["dependency"]:
-        result.append(format_declaration(dependency))
+    for dependency in input.dependency:
+        result.append(format_declaration(dependency)) # type: ignore
     result.append("")
 
     result.append("Finally, here is the declaration that you should create the description of.")
-    result.append(format_declaration(input_data))
+    result.append(format_declaration(input)) # type: ignore
 
     result.append("Please put your informal description into the following format: ... (this is where you put the informal name of this Lean 4 declaration), ... (this is where you put the informal description of this Lean 4 declaration). You can put your thinking into ... tags.")
 
tags.") From e3adc5e20b8e0af25e8adf69dcfe45ab8290f47f Mon Sep 17 00:00:00 2001 From: Evgenia Karunus Date: Tue, 20 May 2025 12:34:01 +0500 Subject: [PATCH 4/7] everywhere - index based on `--project-name` --- database/__init__.py | 8 ++++++-- database/informalize.py | 3 ++- database/jixia_db.py | 6 +++--- database/translate.py | 30 ++++++++++++++++++++++-------- database/vector_db.py | 13 ++++++++----- 5 files changed, 41 insertions(+), 19 deletions(-) diff --git a/database/__init__.py b/database/__init__.py index 1fefa6a..04a4753 100644 --- a/database/__init__.py +++ b/database/__init__.py @@ -24,6 +24,7 @@ def main(): "prefixes", help="Comma-separated list of module prefixes to be included in the index; e.g., Init,Mathlib", ) + jixia_parser.add_argument("--project-name") informal_parser = subparser.add_parser("informal") informal_parser.set_defaults(command="informal") informal_parser.add_argument("--batch-size", type=int, default=50) @@ -37,9 +38,11 @@ def main(): type=int, help="Limit max number of items per level. Used for testing.", ) + informal_parser.add_argument("--project-name") vector_db_parser = subparser.add_parser("vector-db") vector_db_parser.set_defaults(command="vector-db") vector_db_parser.add_argument("--batch-size", type=int, default=8) + vector_db_parser.add_argument("--project-name") args = parser.parse_args() @@ -51,13 +54,14 @@ def main(): elif args.command == "jixia": project = LeanProject(args.project_root) prefixes = [parse_name(p) for p in args.prefixes.split(",")] - load_data(project, prefixes, conn) + load_data(project, prefixes, conn, project_name=args.project_name) elif args.command == "informal": generate_informal( conn, + project_name=args.project_name, batch_size=args.batch_size, limit_level=args.limit_level, limit_num_per_level=args.limit_num_per_level, ) elif args.command == "vector-db": - create_vector_db(conn, os.environ["CHROMA_PATH"], batch_size=args.batch_size) + create_vector_db(conn, os.environ["CHROMA_PATH"], batch_size=args.batch_size, project_name=args.project_name) diff --git a/database/informalize.py b/database/informalize.py index c34f210..dfe58be 100644 --- a/database/informalize.py +++ b/database/informalize.py @@ -51,7 +51,7 @@ def find_dependency(conn: Connection, name: LeanName) -> list[TranslatedItem]: return cursor.fetchall() -def generate_informal(conn: Connection, batch_size: int = 50, limit_level: int | None = None, limit_num_per_level: int | None = None): +def generate_informal(conn: Connection, project_name: str, batch_size: int = 50, limit_level: int | None = None, limit_num_per_level: int | None = None): max_level: int if limit_level is None: with conn.cursor(row_factory=scalar_row) as cursor: @@ -113,6 +113,7 @@ async def translate_and_insert(name: LeanName, data: TranslationInput): header=header, neighbor=neighbor, dependency=dependency, + project_name=project_name, ) tasks.append(translate_and_insert(name, ti)) diff --git a/database/jixia_db.py b/database/jixia_db.py index 03cfc9a..be3da36 100644 --- a/database/jixia_db.py +++ b/database/jixia_db.py @@ -24,12 +24,12 @@ def _get_value(declaration: Declaration, module_content): else: return None -def load_data(project: LeanProject, prefixes: list[LeanName], conn: Connection): +def load_data(project: LeanProject, prefixes: list[LeanName], conn: Connection, project_name): def load_module(data: Iterable[LeanName], base_dir: Path): - values = ((Jsonb(m), project.path_of_module(m, base_dir).read_bytes(), project.load_module_info(m).docstring) for m in data) + values = ((Jsonb(m), 
diff --git a/database/translate.py b/database/translate.py
index f17c7d0..df772e2 100644
--- a/database/translate.py
+++ b/database/translate.py
@@ -36,11 +36,13 @@ class TranslationInput:
     value: str | None
     docstring: str | None
     kind: DeclarationKind
-    header: str
 
     neighbor: list[TranslatedItem]
     dependency: list[TranslatedItem]
 
+    header: str
+    project_name: str
+
     @property
     def value_matters(self):
         return self.kind in ["classInductive", "definition", "inductive", "structure"]
@@ -68,7 +70,12 @@ async def translate(self, data: TranslationInput) -> tuple[str, str] | None:
             kind = "instance"
         else:
             kind = "definition" if data.value_matters else "theorem"
-        prompt = await self.template[kind].render_async(input=data)
+
+        if data.project_name == "metaprogramming":
+            prompt = metaprogramming_prompt(data)
+        else:
+            prompt = await self.template[kind].render_async(input=data)
+
        if os.environ["DRY_RUN"] == "true":
             logger.info("DRY_RUN:skipped informalization: %s", data.name)
             return "Fake Name", f"Fake Description\nPrompt:\n{data}"
@@ -88,10 +95,17 @@ async def translate(self, data: TranslationInput) -> tuple[str, str] | None:
                 await asyncio.sleep(1)
                 continue
             answer = response.choices[0].message.content
-            try:
-                name = self.pattern_name.search(answer).group(1)
-                description = self.pattern_description.search(answer).group(1)
-            except AttributeError: # unable to parse the result, at least one of the regex did not match
-                logger.info("while translating %s: unable to parse the result; retrying", data.name)
-                continue
+            if data.project_name == "metaprogramming":
+                root = ET.fromstring(f"{answer}")
+                name = root.findtext('informal_name')
+                description = root.findtext('informal_description')
+                if (name is None or description is None):
+                    continue
+            else:
+                try:
+                    name = self.pattern_name.search(answer).group(1)
+                    description = self.pattern_description.search(answer).group(1)
+                except AttributeError: # unable to parse the result, at least one of the regex did not match
+                    logger.info("while translating %s: unable to parse the result; retrying", data.name)
+                    continue
             return name.strip(), description.strip()
diff --git a/database/vector_db.py b/database/vector_db.py
index 1b458eb..e5a2d91 100644
--- a/database/vector_db.py
+++ b/database/vector_db.py
@@ -2,6 +2,7 @@
 import logging
 
 import chromadb
+from chromadb.api.types import Metadata, ID, Embedding, Document
 from jixia.structs import pp_name
 from psycopg import Connection
 
@@ -9,7 +10,7 @@
 logger = logging.getLogger(__name__)
 
 
-def create_vector_db(conn: Connection, path: str, batch_size: int):
+def create_vector_db(conn: Connection, path: str, batch_size: int, project_name: str):
     with open("prompt/embedding_instruction.txt") as fp:
         instruction = fp.read()
     MistralEmbedding.setup_env()
@@ -31,15 +32,17 @@ def create_vector_db(conn: Connection, path: str, batch_size: int):
         """)
 
         while batch := cursor.fetchmany(batch_size):
-            batch_doc = []
-            batch_id = []
+            batch_doc: list[Document] = []
+            batch_id: list[ID] = []
+            metadatas: list[Metadata] = []
             for module_name, index, kind, name, signature, informal_name, informal_description in batch:
                 batch_doc.append(f"{kind} {name} {signature}\n{informal_name}: {informal_description}")
                 # NOTE: we use module name + index as document id as they cannot contain special characters
                 batch_id.append(f"{pp_name(module_name)}:{index}")
+                metadatas.append({ "project_name": project_name })
                 if os.environ["DRY_RUN"] == "true":
                     logger.info("DRY_RUN:skipped embedding: %s", f"{kind} {name} {signature} {informal_name}")
             if os.environ["DRY_RUN"] == "true":
                 return
-            batch_embedding = embedding.embed(batch_doc)
-            collection.add(embeddings=batch_embedding, ids=batch_id)
+            batch_embedding : list[Embedding] = embedding.embed(batch_doc)
+            collection.add(embeddings=batch_embedding, ids=batch_id, metadatas=metadatas)
{informal_description}") # NOTE: we use module name + index as document id as they cannot contain special characters batch_id.append(f"{pp_name(module_name)}:{index}") + metadatas.append({ "project_name": project_name }) if os.environ["DRY_RUN"] == "true": logger.info("DRY_RUN:skipped embedding: %s", f"{kind} {name} {signature} {informal_name}") if os.environ["DRY_RUN"] == "true": return - batch_embedding = embedding.embed(batch_doc) - collection.add(embeddings=batch_embedding, ids=batch_id) + batch_embedding : list[Embedding] = embedding.embed(batch_doc) + collection.add(embeddings=batch_embedding, ids=batch_id, metadatas=metadatas) From 14ef3f0182a770e9f0203f1ad50c2f695ec5aa09 Mon Sep 17 00:00:00 2001 From: Evgenia Karunus Date: Tue, 20 May 2025 12:53:27 +0500 Subject: [PATCH 5/7] informalize.py - remove `--project-name` cli argument --- database/__init__.py | 2 -- database/informalize.py | 6 +++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/database/__init__.py b/database/__init__.py index 04a4753..ccf6e66 100644 --- a/database/__init__.py +++ b/database/__init__.py @@ -38,7 +38,6 @@ def main(): type=int, help="Limit max number of items per level. Used for testing.", ) - informal_parser.add_argument("--project-name") vector_db_parser = subparser.add_parser("vector-db") vector_db_parser.set_defaults(command="vector-db") vector_db_parser.add_argument("--batch-size", type=int, default=8) @@ -58,7 +57,6 @@ def main(): elif args.command == "informal": generate_informal( conn, - project_name=args.project_name, batch_size=args.batch_size, limit_level=args.limit_level, limit_num_per_level=args.limit_num_per_level, diff --git a/database/informalize.py b/database/informalize.py index dfe58be..2fbbb29 100644 --- a/database/informalize.py +++ b/database/informalize.py @@ -51,7 +51,7 @@ def find_dependency(conn: Connection, name: LeanName) -> list[TranslatedItem]: return cursor.fetchall() -def generate_informal(conn: Connection, project_name: str, batch_size: int = 50, limit_level: int | None = None, limit_num_per_level: int | None = None): +def generate_informal(conn: Connection, batch_size: int = 50, limit_level: int | None = None, limit_num_per_level: int | None = None): max_level: int if limit_level is None: with conn.cursor(row_factory=scalar_row) as cursor: @@ -63,7 +63,7 @@ def generate_informal(conn: Connection, project_name: str, batch_size: int = 50, with conn.cursor() as cursor, conn.cursor() as insert_cursor: for level in range(max_level + 1): query = """ - SELECT s.name, d.signature, d.value, d.docstring, d.kind, m.docstring, d.module_name, d.index + SELECT s.name, d.signature, d.value, d.docstring, d.kind, m.docstring, d.module_name, d.index, m.project_name FROM symbol s INNER JOIN declaration d ON s.name = d.name @@ -98,7 +98,7 @@ async def translate_and_insert(name: LeanName, data: TranslationInput): tasks.clear() for row in batch: - name, signature, value, docstring, kind, header, module_name, index = row + name, signature, value, docstring, kind, header, module_name, index, project_name = row logger.info("translating %s", name) neighbor = find_neighbor(conn, module_name, index) From 1cf33344e0f83752cb18e299023adc3e28dc0a4b Mon Sep 17 00:00:00 2001 From: Evgenia Karunus Date: Thu, 22 May 2025 13:21:36 +0500 Subject: [PATCH 6/7] vector_db.py - only put cli-project_name to chromadb --- database/vector_db.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/database/vector_db.py b/database/vector_db.py index e5a2d91..a80a949 
--- a/database/vector_db.py
+++ b/database/vector_db.py
@@ -24,22 +24,29 @@ def create_vector_db(conn: Connection, path: str, batch_size: int, project_name:
     )
 
     with conn.cursor() as cursor:
-        cursor.execute("""
-            SELECT d.module_name, d.index, d.kind, d.name, d.signature, i.name, i.description
-            FROM
-                declaration d INNER JOIN informal i ON d.name = i.symbol_name
-            WHERE d.visible = TRUE
-        """)
+        cursor.execute(
+            """
+            SELECT d.module_name, d.index, d.kind, d.name, d.signature, i.name, i.description, m.project_name
+            FROM
+                declaration d
+                INNER JOIN informal i ON d.name = i.symbol_name
+                INNER JOIN module m ON d.module_name = m.name
+            WHERE d.visible = TRUE AND m.project_name = %(project_name)s
+            """,
+            {
+                "project_name": project_name
+            }
+        )
 
         while batch := cursor.fetchmany(batch_size):
             batch_doc: list[Document] = []
             batch_id: list[ID] = []
             metadatas: list[Metadata] = []
-            for module_name, index, kind, name, signature, informal_name, informal_description in batch:
+            for module_name, index, kind, name, signature, informal_name, informal_description, this_row_project_name in batch:
                 batch_doc.append(f"{kind} {name} {signature}\n{informal_name}: {informal_description}")
                 # NOTE: we use module name + index as document id as they cannot contain special characters
                 batch_id.append(f"{pp_name(module_name)}:{index}")
-                metadatas.append({ "project_name": project_name })
+                metadatas.append({ "project_name": this_row_project_name })
                 if os.environ["DRY_RUN"] == "true":
                     logger.info("DRY_RUN:skipped embedding: %s", f"{kind} {name} {signature} {informal_name}")
             if os.environ["DRY_RUN"] == "true":

From 511c56853c9f68ccfc3776e27ad716bd8b987483 Mon Sep 17 00:00:00 2001
From: Evgenia Karunus
Date: Thu, 22 May 2025 13:24:49 +0500
Subject: [PATCH 7/7] meta_instructions.md - remove file, moved this to pr description

---
 meta_instructions.md | 4 ----
 1 file changed, 4 deletions(-)
 delete mode 100644 meta_instructions.md

diff --git a/meta_instructions.md b/meta_instructions.md
deleted file mode 100644
index c291432..0000000
--- a/meta_instructions.md
+++ /dev/null
@@ -1,4 +0,0 @@
-
-# Indexing metaprogramming functions
-
-1. Run `python -m database migrate`
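
For reference, below is a minimal sketch of how a model reply is expected to be consumed by the metaprogramming branch added to database/translate.py in PATCH 4. The informal_name and informal_description element names are taken from the root.findtext calls in that patch; the sample reply text and the synthetic root wrapper used before parsing are assumptions made here for illustration, not something the patches themselves pin down.

import xml.etree.ElementTree as ET

# Hypothetical model reply; the two tag names mirror the findtext calls in
# database/translate.py, and the surrounding prose stands in for the free-form
# thinking output that the prompt allows.
answer = """
Some free-form reasoning from the model.
<informal_name>Lift a tactic computation</informal_name>
<informal_description>Runs a metaprogram inside the tactic monad while
elaborating the current goal.</informal_description>
"""

# Assumption: wrap the reply in a synthetic root element so that mixed text
# and multiple top-level tags still parse as a single XML document.
root = ET.fromstring(f"<root>{answer}</root>")
name = root.findtext("informal_name")
description = root.findtext("informal_description")

if name is None or description is None:
    # translate() retries the request in this case.
    raise ValueError("reply did not contain the expected tags")

print(name.strip())
print(description.strip())

With DRY_RUN=true none of this runs: translate() returns the fake name and description before any model call or parsing happens.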