diff --git a/database/create_schema.py b/database/create_schema.py index c8d9f04..bb37ca6 100644 --- a/database/create_schema.py +++ b/database/create_schema.py @@ -27,57 +27,40 @@ def create_schema(conn: Connection): ) """, """ - CREATE TABLE symbol ( - name JSONB PRIMARY KEY, - module_name JSONB REFERENCES module(name) NOT NULL, - type TEXT NOT NULL, - is_prop BOOLEAN NOT NULL - ) - """, - """ CREATE TABLE declaration ( module_name JSONB REFERENCES module(name) NOT NULL, + name JSONB, + index INTEGER NOT NULL, - name JSONB UNIQUE REFERENCES symbol(name), visible BOOLEAN NOT NULL, docstring TEXT, kind declaration_kind NOT NULL, signature TEXT NOT NULL, value TEXT, - PRIMARY KEY (module_name, index) + + symbol_type TEXT NOT NULL, + symbol_is_prop BOOLEAN NOT NULL, + + informal_name TEXT, + informal_description TEXT, + + PRIMARY KEY (name) ) """, """ CREATE TABLE dependency ( - source JSONB REFERENCES symbol(name) NOT NULL, - target JSONB REFERENCES symbol(name) NOT NULL, + source JSONB NOT NULL, + target JSONB NOT NULL, on_type BOOLEAN NOT NULL, PRIMARY KEY (source, target, on_type) ) """, """ CREATE TABLE level ( - symbol_name JSONB PRIMARY KEY REFERENCES symbol(name) NOT NULL, + symbol_name JSONB PRIMARY KEY REFERENCES declaration(name) NOT NULL, level INTEGER NOT NULL ) - """, """ - CREATE TABLE informal ( - symbol_name JSONB PRIMARY KEY REFERENCES symbol(name) NOT NULL, - name TEXT NOT NULL, - description TEXT NOT NULL - ) - """, - """ - CREATE VIEW record AS - SELECT - d.module_name, d.index, d.kind, d.name, d.signature, s.type, d.value, d.docstring, - i.name AS informal_name, i.description AS informal_description - FROM - declaration d - INNER JOIN informal i ON d.name = i.symbol_name - INNER JOIN symbol s ON d.name = s.name - """, ] with conn.cursor() as cursor: diff --git a/database/informalize.py b/database/informalize.py index 11123d7..d590d11 100644 --- a/database/informalize.py +++ b/database/informalize.py @@ -16,12 +16,10 @@ def find_neighbor(conn: Connection, module_name: LeanName, index: int, num_neigh with conn.cursor(row_factory=args_row(TranslatedItem)) as cursor: cursor.execute( """ - SELECT d.name, d.signature, i.name, i.description - FROM - declaration d - LEFT JOIN informal i ON d.name = i.symbol_name + SELECT name, signature, informal_name, informal_description + FROM declaration WHERE - d.module_name = %s AND d.index >= %s AND d.index <= %s + module_name = %s AND index >= %s AND index <= %s """, (Jsonb(module_name), index - num_neighbor, index + num_neighbor), ) @@ -32,11 +30,10 @@ def find_dependency(conn: Connection, name: LeanName) -> list[TranslatedItem]: with conn.cursor(row_factory=args_row(TranslatedItem)) as cursor: cursor.execute( """ - SELECT d.name, d.signature, i.name, i.description + SELECT name, signature, informal_name, informal_description FROM declaration d INNER JOIN dependency e ON d.name = e.target - LEFT JOIN informal i ON d.name = i.symbol_name WHERE e.source = %s """, @@ -57,15 +54,14 @@ def generate_informal(conn: Connection, batch_size: int = 50, limit_level: int | with conn.cursor() as cursor, conn.cursor() as insert_cursor: for level in range(max_level + 1): query = """ - SELECT s.name, d.signature, d.value, d.docstring, d.kind, m.docstring, d.module_name, d.index - FROM - symbol s - INNER JOIN declaration d ON s.name = d.name - INNER JOIN module m ON s.module_name = m.name - INNER JOIN level l ON s.name = l.symbol_name - WHERE - l.level = %s AND - (NOT EXISTS(SELECT 1 FROM informal i WHERE i.symbol_name = s.name)) + SELECT d.name, d.signature, d.value, d.docstring, d.kind, m.docstring, d.module_name, d.index + FROM + declaration d + INNER JOIN module m ON d.module_name = m.name + INNER JOIN level l ON d.name = l.symbol_name + WHERE + l.level = %s AND + (d.informal_description IS NULL OR d.informal_name IS NULL) """ if limit_num_per_level: cursor.execute(query + " LIMIT %s", (level, limit_num_per_level)) @@ -84,10 +80,15 @@ async def translate_and_insert(name: LeanName, data: TranslationInput): informal_name, informal_description = result insert_cursor.execute( """ - INSERT INTO informal (symbol_name, name, description) - VALUES (%s, %s, %s) + UPDATE declaration + SET informal_name = %(informal_name)s, informal_description = %(informal_description)s + WHERE name = %(name)s """, - (Jsonb(name), informal_name, informal_description), + { + "name": Jsonb(name), + "informal_name": informal_name, + "informal_description": informal_description + }, ) tasks.clear() diff --git a/database/jixia_db.py b/database/jixia_db.py index 94d0e46..4d83272 100644 --- a/database/jixia_db.py +++ b/database/jixia_db.py @@ -10,101 +10,109 @@ logger = logging.getLogger(__name__) +def _get_signature(declaration: Declaration, module_content): + if declaration.signature.pp is not None: + return declaration.signature.pp + elif declaration.signature.range is not None: + return module_content[declaration.signature.range.as_slice()].decode() + else: + return '' -def load_data(project: LeanProject, prefixes: list[LeanName], conn: Connection): - def load_module(data: Iterable[LeanName], base_dir: Path): - values = ((Jsonb(m), project.path_of_module(m, base_dir).read_bytes(), project.load_module_info(m).docstring) for m in data) - cursor.executemany( - """ - INSERT INTO module (name, content, docstring) VALUES (%s, %s, %s) ON CONFLICT DO NOTHING - """, - values, - ) +def _get_value(declaration: Declaration, module_content): + if declaration.value is not None and declaration.value.range is not None: + return module_content[declaration.value.range.as_slice()].decode() + else: + return None - def load_symbol(module: LeanName): - symbols = [s for s in project.load_info(module, Symbol) if not is_internal(s.name)] - values = ((Jsonb(s.name), Jsonb(module), s.type, s.is_prop) for s in symbols) - cursor.executemany( - """ - INSERT INTO symbol (name, module_name, type, is_prop) VALUES (%s, %s, %s, %s) ON CONFLICT DO NOTHING - """, - values, - ) - for s in symbols: - values = ( - { - "source": Jsonb(s.name), - "target": Jsonb(t), - } - for t in s.type_references - if not is_internal(t) - ) - cursor.executemany( +def _find_declaration(declarations: list[Declaration], target_name): + for index, declaration in enumerate(declarations): + if declaration.name == target_name: + return declaration, index + return None, None + +def load_data(project: LeanProject, prefixes: list[LeanName], conn: Connection): + def load_module(module_names: Iterable[LeanName], base_dir: Path): + for module_name in module_names: + db_module = { + "name": Jsonb(module_name), + "content": project.path_of_module(module_name, base_dir).read_bytes(), + "docstring": project.load_module_info(module_name).docstring + } + cursor.execute( """ - INSERT INTO dependency (source, target, on_type) - SELECT %(source)s, %(target)s, TRUE - WHERE EXISTS(SELECT 1 FROM symbol WHERE name = %(target)s) - ON CONFLICT DO NOTHING + INSERT INTO module (name, content, docstring) VALUES (%(name)s, %(content)s, %(docstring)s) """, - values, + db_module ) - if s.value_references is not None: - values = ( + symbols = project.load_info(module_name, Symbol) + declarations = project.load_info(module_name, Declaration) + for symbol in symbols: + declaration, index = _find_declaration(declarations, symbol.name) + if ( + is_internal(symbol.name) or + declaration is None or + declaration.kind == "proofWanted" + ): + continue + + cursor.execute( + """ + INSERT INTO declaration (module_name, index, name, visible, docstring, kind, signature, value, symbol_type, symbol_is_prop) + VALUES (%(module_name)s, %(index)s, %(name)s, %(visible)s, %(docstring)s, %(kind)s, %(signature)s, %(value)s, %(symbol_type)s, %(symbol_is_prop)s) + """, { - "source": Jsonb(s.name), - "target": Jsonb(t), + "module_name": Jsonb(module_name), + "name" : Jsonb(declaration.name) if declaration.kind != "example" else None, + "index" : index, + "visible" : declaration.modifiers.visibility != "private" and declaration.kind != "example", + "docstring" : declaration.modifiers.docstring, + "kind" : declaration.kind, + "signature" : _get_signature(declaration, db_module["content"]), + "value" : _get_value(declaration, db_module["content"]), + "symbol_type": symbol.type, + "symbol_is_prop": symbol.is_prop } - for t in s.value_references - if not is_internal(t) ) + + db_deps = [] + for ref_name in symbol.type_references: + if is_internal(ref_name): + continue + db_deps.append({ + "source": Jsonb(symbol.name), + "target": Jsonb(ref_name), + "on_type": True + }) + for ref_name in (symbol.value_references or []): + if is_internal(ref_name): + continue + db_deps.append({ + "source": Jsonb(symbol.name), + "target": Jsonb(ref_name), + "on_type": False + }) cursor.executemany( """ INSERT INTO dependency (source, target, on_type) - SELECT %(source)s, %(target)s, FALSE - WHERE EXISTS(SELECT 1 FROM symbol WHERE name = %(target)s) - ON CONFLICT DO NOTHING + VALUES (%(source)s, %(target)s, %(on_type)s) """, - values, + db_deps, ) - def load_declaration(module: LeanName): - declarations = project.load_info(module, Declaration) - cursor.execute( - """ - SELECT content FROM module WHERE name = %s - """, - (Jsonb(module),), - ) - (source,) = cursor.fetchone() - values = ( - ( - Jsonb(module), - i, - Jsonb(d.name) if d.kind != "example" else None, - d.modifiers.visibility != "private" and d.kind != "example", - d.modifiers.docstring, - d.kind, - d.signature.pp if d.signature.pp is not None else source[d.signature.range.as_slice()].decode(), - source[d.value.range.as_slice()].decode() if d.value is not None else None, - ) - for i, d in enumerate(declarations) - if not is_internal(d.name) and d.kind != "proofWanted" - ) - cursor.executemany( - """ - INSERT INTO declaration (module_name, index, name, visible, docstring, kind, signature, value) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s) ON CONFLICT DO NOTHING - """, - values, - ) - def topological_sort(): logger.info("performing topological sort") + # Delete dependencies where target doesn't exist in declaration table + cursor.execute(""" + DELETE FROM dependency d + WHERE NOT EXISTS (SELECT 1 FROM declaration dec WHERE dec.name = d.target) + """) + logger.info("Deleted %d invalid dependencies", cursor.rowcount) + cursor.execute(""" INSERT INTO level (symbol_name, level) SELECT name, 0 - FROM symbol v + FROM declaration v WHERE NOT EXISTS (SELECT 1 FROM dependency e WHERE e.source = v.name) """) while cursor.rowcount: @@ -122,21 +130,19 @@ def topological_sort(): """) with conn.cursor() as cursor: - lean_sysroot = Path(os.environ["LEAN_SYSROOT"]) - lean_src = lean_sysroot / "src" / "lean" - all_modules = [] - for d in project.root, lean_src: - results = project.batch_run_jixia( - base_dir=d, - prefixes=prefixes, - plugins=["module", "declaration", "symbol"], - ) - modules = [r[0] for r in results] - load_module(modules, d) - all_modules += modules + path_to_project = project.root + project_modules = [r[0] for r in project.batch_run_jixia( + base_dir=path_to_project, + prefixes=prefixes, + plugins=["module", "declaration", "symbol"], + )] + load_module(project_modules, path_to_project) - for m in all_modules: - load_symbol(m) - for m in all_modules: - load_declaration(m) + path_to_lean = Path(os.environ["LEAN_SYSROOT"]) / "src" / "lean" + lean_modules = [r[0] for r in project.batch_run_jixia( + base_dir=path_to_lean, + prefixes=prefixes, + plugins=["module", "declaration", "symbol"], + )] + load_module(lean_modules, path_to_lean) topological_sort() diff --git a/database/vector_db.py b/database/vector_db.py index d70c4e4..153e2e6 100644 --- a/database/vector_db.py +++ b/database/vector_db.py @@ -22,10 +22,9 @@ def create_vector_db(conn: Connection, path: str, batch_size: int): with conn.cursor() as cursor: cursor.execute(""" - SELECT d.module_name, d.index, d.kind, d.name, d.signature, i.name, i.description - FROM - declaration d INNER JOIN informal i ON d.name = i.symbol_name - WHERE d.visible = TRUE + SELECT module_name, index, kind, name, signature, informal_name, informal_description + FROM declaration + WHERE visible = TRUE """) batch_doc = []