diff --git a/database/__init__.py b/database/__init__.py
index 30b7d19..ccf6e66 100644
--- a/database/__init__.py
+++ b/database/__init__.py
@@ -9,6 +9,7 @@
from .jixia_db import load_data
from .vector_db import create_vector_db
from .create_schema import create_schema
+from .migrate import migrate
def main():
@@ -23,6 +24,7 @@ def main():
"prefixes",
help="Comma-separated list of module prefixes to be included in the index; e.g., Init,Mathlib",
)
+    jixia_parser.add_argument("--project-name", help="project name recorded for each loaded module")
informal_parser = subparser.add_parser("informal")
informal_parser.set_defaults(command="informal")
informal_parser.add_argument("--batch-size", type=int, default=50)
@@ -39,16 +41,19 @@ def main():
vector_db_parser = subparser.add_parser("vector-db")
vector_db_parser.set_defaults(command="vector-db")
vector_db_parser.add_argument("--batch-size", type=int, default=8)
+    vector_db_parser.add_argument("--project-name", help="only embed declarations from modules with this project name")
+    migrate_parser = subparser.add_parser("migrate")
+    migrate_parser.set_defaults(command="migrate")
args = parser.parse_args()
with psycopg.connect(os.environ["CONNECTION_STRING"], autocommit=True) as conn:
+ if args.command == "migrate":
+ migrate(conn)
-        if args.command == "schema":
+        elif args.command == "schema":
create_schema(conn)
elif args.command == "jixia":
project = LeanProject(args.project_root)
prefixes = [parse_name(p) for p in args.prefixes.split(",")]
- load_data(project, prefixes, conn)
+ load_data(project, prefixes, conn, project_name=args.project_name)
elif args.command == "informal":
generate_informal(
conn,
@@ -57,4 +62,4 @@ def main():
limit_num_per_level=args.limit_num_per_level,
)
elif args.command == "vector-db":
- create_vector_db(conn, os.environ["CHROMA_PATH"], batch_size=args.batch_size)
+ create_vector_db(conn, os.environ["CHROMA_PATH"], batch_size=args.batch_size, project_name=args.project_name)
diff --git a/database/informalize.py b/database/informalize.py
index 11123d7..2fbbb29 100644
--- a/database/informalize.py
+++ b/database/informalize.py
@@ -16,9 +16,12 @@ def find_neighbor(conn: Connection, module_name: LeanName, index: int, num_neigh
with conn.cursor(row_factory=args_row(TranslatedItem)) as cursor:
cursor.execute(
"""
- SELECT d.name, d.signature, i.name, i.description
+ SELECT s.name, d.signature, d.value, d.docstring, d.kind, i.name, i.description, d.signature
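+            -- the trailing d.signature duplicates the signature into the legacy "description" field of TranslatedItem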
FROM
- declaration d
+ symbol s
+ INNER JOIN declaration d ON s.name = d.name
+ INNER JOIN module m ON s.module_name = m.name
+ INNER JOIN level l ON s.name = l.symbol_name
LEFT JOIN informal i ON d.name = i.symbol_name
WHERE
d.module_name = %s AND d.index >= %s AND d.index <= %s
@@ -32,11 +35,14 @@ def find_dependency(conn: Connection, name: LeanName) -> list[TranslatedItem]:
with conn.cursor(row_factory=args_row(TranslatedItem)) as cursor:
cursor.execute(
"""
- SELECT d.name, d.signature, i.name, i.description
+ SELECT s.name, d.signature, d.value, d.docstring, d.kind, i.name, i.description, d.signature
FROM
- declaration d
- INNER JOIN dependency e ON d.name = e.target
+ symbol s
+ INNER JOIN declaration d ON s.name = d.name
+ INNER JOIN module m ON s.module_name = m.name
+ INNER JOIN level l ON s.name = l.symbol_name
LEFT JOIN informal i ON d.name = i.symbol_name
+ INNER JOIN dependency e ON d.name = e.target
WHERE
e.source = %s
""",
@@ -57,7 +63,7 @@ def generate_informal(conn: Connection, batch_size: int = 50, limit_level: int |
with conn.cursor() as cursor, conn.cursor() as insert_cursor:
for level in range(max_level + 1):
query = """
- SELECT s.name, d.signature, d.value, d.docstring, d.kind, m.docstring, d.module_name, d.index
+ SELECT s.name, d.signature, d.value, d.docstring, d.kind, m.docstring, d.module_name, d.index, m.project_name
FROM
symbol s
INNER JOIN declaration d ON s.name = d.name
@@ -92,7 +98,7 @@ async def translate_and_insert(name: LeanName, data: TranslationInput):
tasks.clear()
for row in batch:
- name, signature, value, docstring, kind, header, module_name, index = row
+ name, signature, value, docstring, kind, header, module_name, index, project_name = row
logger.info("translating %s", name)
neighbor = find_neighbor(conn, module_name, index)
@@ -107,6 +113,7 @@ async def translate_and_insert(name: LeanName, data: TranslationInput):
header=header,
neighbor=neighbor,
dependency=dependency,
+ project_name=project_name,
)
tasks.append(translate_and_insert(name, ti))
diff --git a/database/jixia_db.py b/database/jixia_db.py
index 03cfc9a..be3da36 100644
--- a/database/jixia_db.py
+++ b/database/jixia_db.py
@@ -24,12 +24,12 @@ def _get_value(declaration: Declaration, module_content):
else:
return None
-def load_data(project: LeanProject, prefixes: list[LeanName], conn: Connection):
+def load_data(project: LeanProject, prefixes: list[LeanName], conn: Connection, project_name: str | None):
def load_module(data: Iterable[LeanName], base_dir: Path):
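+        # one row per module: (name, source bytes, module docstring, project_name)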
- values = ((Jsonb(m), project.path_of_module(m, base_dir).read_bytes(), project.load_module_info(m).docstring) for m in data)
+ values = ((Jsonb(m), project.path_of_module(m, base_dir).read_bytes(), project.load_module_info(m).docstring, project_name) for m in data)
cursor.executemany(
"""
- INSERT INTO module (name, content, docstring) VALUES (%s, %s, %s) ON CONFLICT DO NOTHING
+ INSERT INTO module (name, content, docstring, project_name) VALUES (%s, %s, %s, %s) ON CONFLICT DO NOTHING
""",
values,
)
diff --git a/database/migrate.py b/database/migrate.py
new file mode 100644
index 0000000..81f33b2
--- /dev/null
+++ b/database/migrate.py
@@ -0,0 +1,6 @@
+from psycopg import Connection
+
+def migrate(conn: Connection):
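+    # one-off migration: add the module.project_name column and mark all existing modules as 'mathlib'
+    # NOTE: not idempotent; rerunning fails once the column already exists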
+ with conn.cursor() as cursor:
+ cursor.execute("ALTER TABLE module ADD COLUMN project_name TEXT")
+ cursor.execute("UPDATE module SET project_name = 'mathlib'")
diff --git a/database/translate.py b/database/translate.py
index d2d9da4..df772e2 100644
--- a/database/translate.py
+++ b/database/translate.py
@@ -4,11 +4,13 @@
import re
from dataclasses import dataclass
from json import JSONDecodeError
+import xml.etree.ElementTree as ET
import jinja2
from jinja2 import Environment, FileSystemLoader
from jixia.structs import DeclarationKind, LeanName, pp_name
from openai import AsyncOpenAI
+from prompt.metaprogramming import metaprogramming_prompt
logger = logging.getLogger(__name__)
@@ -16,10 +18,16 @@
@dataclass
class TranslatedItem:
name: LeanName
- description: str
+ signature: str
+ value: str | None
+ docstring: str | None
+ kind: DeclarationKind
+
informal_name: str | None
informal_description: str | None
+    # The "description" field is a copy of the more appropriately named "signature" field; it is kept for backwards compatibility and can be removed once all prompts are switched to using the "signature" field.
+ description: str
@dataclass
class TranslationInput:
@@ -28,11 +36,13 @@ class TranslationInput:
value: str | None
docstring: str | None
kind: DeclarationKind
- header: str
neighbor: list[TranslatedItem]
dependency: list[TranslatedItem]
+ header: str
+ project_name: str
+
@property
def value_matters(self):
return self.kind in ["classInductive", "definition", "inductive", "structure"]
@@ -60,7 +70,12 @@ async def translate(self, data: TranslationInput) -> tuple[str, str] | None:
kind = "instance"
else:
kind = "definition" if data.value_matters else "theorem"
- prompt = await self.template[kind].render_async(input=data)
+
+ if data.project_name == "metaprogramming":
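+            # metaprogramming declarations use a hand-written prompt instead of the Jinja templates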
+ prompt = metaprogramming_prompt(data)
+ else:
+ prompt = await self.template[kind].render_async(input=data)
+
if os.environ["DRY_RUN"] == "true":
logger.info("DRY_RUN:skipped informalization: %s", data.name)
return "Fake Name", f"Fake Description\nPrompt:\n{data}"
@@ -80,10 +95,17 @@ async def translate(self, data: TranslationInput) -> tuple[str, str] | None:
await asyncio.sleep(1)
continue
answer = response.choices[0].message.content
- try:
- name = self.pattern_name.search(answer).group(1)
- description = self.pattern_description.search(answer).group(1)
- except AttributeError: # unable to parse the result, at least one of the regex did not match
- logger.info("while translating %s: unable to parse the result; retrying", data.name)
- continue
+            if data.project_name == "metaprogramming":
+                # wrap the reply in a synthetic <root> element so the tagged answer parses as a single XML document
+                try:
+                    root = ET.fromstring(f"<root>{answer}</root>")
+                except ET.ParseError:
+                    logger.info("while translating %s: unable to parse the result; retrying", data.name)
+                    continue
+                name = root.findtext("informal_name")
+                description = root.findtext("informal_description")
+                if name is None or description is None:
+                    continue
+ else:
+ try:
+ name = self.pattern_name.search(answer).group(1)
+ description = self.pattern_description.search(answer).group(1)
+ except AttributeError: # unable to parse the result, at least one of the regex did not match
+ logger.info("while translating %s: unable to parse the result; retrying", data.name)
+ continue
return name.strip(), description.strip()
diff --git a/database/vector_db.py b/database/vector_db.py
index 1b458eb..a80a949 100644
--- a/database/vector_db.py
+++ b/database/vector_db.py
@@ -2,6 +2,7 @@
import logging
import chromadb
+from chromadb.api.types import Metadata, ID, Embedding, Document
from jixia.structs import pp_name
from psycopg import Connection
@@ -9,7 +10,7 @@
logger = logging.getLogger(__name__)
-def create_vector_db(conn: Connection, path: str, batch_size: int):
+def create_vector_db(conn: Connection, path: str, batch_size: int, project_name: str):
with open("prompt/embedding_instruction.txt") as fp:
instruction = fp.read()
MistralEmbedding.setup_env()
@@ -23,23 +24,32 @@ def create_vector_db(conn: Connection, path: str, batch_size: int):
)
with conn.cursor() as cursor:
- cursor.execute("""
- SELECT d.module_name, d.index, d.kind, d.name, d.signature, i.name, i.description
- FROM
- declaration d INNER JOIN informal i ON d.name = i.symbol_name
- WHERE d.visible = TRUE
- """)
+ cursor.execute(
+ """
+ SELECT d.module_name, d.index, d.kind, d.name, d.signature, i.name, i.description, m.project_name
+ FROM
+ declaration d
+ INNER JOIN informal i ON d.name = i.symbol_name
+ INNER JOIN module m ON d.module_name = m.name
+ WHERE d.visible = TRUE AND m.project_name = %(project_name)s
+ """,
+ {
+ "project_name": project_name
+ }
+ )
while batch := cursor.fetchmany(batch_size):
- batch_doc = []
- batch_id = []
- for module_name, index, kind, name, signature, informal_name, informal_description in batch:
+ batch_doc: list[Document] = []
+ batch_id: list[ID] = []
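+            # record the source project of each embedded declaration as Chroma metadata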
+ metadatas: list[Metadata] = []
+ for module_name, index, kind, name, signature, informal_name, informal_description, this_row_project_name in batch:
batch_doc.append(f"{kind} {name} {signature}\n{informal_name}: {informal_description}")
# NOTE: we use module name + index as document id as they cannot contain special characters
batch_id.append(f"{pp_name(module_name)}:{index}")
+                metadatas.append({"project_name": this_row_project_name})
if os.environ["DRY_RUN"] == "true":
logger.info("DRY_RUN:skipped embedding: %s", f"{kind} {name} {signature} {informal_name}")
if os.environ["DRY_RUN"] == "true":
return
- batch_embedding = embedding.embed(batch_doc)
- collection.add(embeddings=batch_embedding, ids=batch_id)
+            batch_embedding: list[Embedding] = embedding.embed(batch_doc)
+ collection.add(embeddings=batch_embedding, ids=batch_id, metadatas=metadatas)
diff --git a/prompt/metaprogramming.py b/prompt/metaprogramming.py
new file mode 100644
index 0000000..c15a603
--- /dev/null
+++ b/prompt/metaprogramming.py
@@ -0,0 +1,42 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    # imported under TYPE_CHECKING only, to break the circular import with database.translate,
+    # which imports metaprogramming_prompt from this module at runtime
+    from database.translate import TranslatedItem, TranslationInput
+
+
+def format_declaration(declaration: TranslatedItem | TranslationInput) -> str:
+    result: list[str] = []
+
+    result.append("")
+    if declaration.docstring is not None:
+        result.append(f"{declaration.docstring}")
+    result.append(f"{declaration.kind} {declaration.name}")
+    result.append(f"{declaration.signature}")
+    if declaration.value is not None:
+        result.append(f"{declaration.value}")
+    # neighbors and dependencies are TranslatedItem, but the target declaration is a
+    # TranslationInput, which has no informal fields; fall back to None instead of raising
+    informal_name = getattr(declaration, "informal_name", None)
+    informal_description = getattr(declaration, "informal_description", None)
+    if informal_name is not None:
+        result.append(f"{informal_name}")
+    if informal_description is not None:
+        result.append(f"{informal_description}")
+    result.append("")
+
+    return "\n".join(result)
+
+def metaprogramming_prompt(input: TranslationInput) -> str:
+    result: list[str] = []
+
+    result.append("Your task is to create an informal description of the following metaprogramming declaration in Lean 4. Later on, it will be used to create an embedding vector.")
+
+ result.append(f"{input.header}")
+
+    result.append("Here are some declarations that appear near the target declaration in the same module:")
+    for neighbor in input.neighbor:
+        result.append(format_declaration(neighbor))
+    result.append("")
+
+    result.append("Here are some declarations that the target declaration depends on:")
+    for dependency in input.dependency:
+        result.append(format_declaration(dependency))
+    result.append("")
+
+    result.append("Finally, here is the declaration for which you should create the description.")
+    result.append(format_declaration(input))
+
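+    # the <informal_name> and <informal_description> tags must match the XML parsing in database/translate.py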
+    result.append("Please put your answer into the following format: <informal_name>...</informal_name> (this is where you put the informal name of this Lean 4 declaration), <informal_description>...</informal_description> (this is where you put the informal description of this Lean 4 declaration). You can put your thinking into <think>...</think> tags.")
+
+ return "\n".join(result)