Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 13 additions & 30 deletions database/create_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,57 +27,40 @@ def create_schema(conn: Connection):
)
""",
"""
CREATE TABLE symbol (
name JSONB PRIMARY KEY,
module_name JSONB REFERENCES module(name) NOT NULL,
type TEXT NOT NULL,
is_prop BOOLEAN NOT NULL
)
""",
"""
CREATE TABLE declaration (
module_name JSONB REFERENCES module(name) NOT NULL,
name JSONB,

index INTEGER NOT NULL,
name JSONB UNIQUE REFERENCES symbol(name),
visible BOOLEAN NOT NULL,
docstring TEXT,
kind declaration_kind NOT NULL,
signature TEXT NOT NULL,
value TEXT,
PRIMARY KEY (module_name, index)

symbol_type TEXT NOT NULL,
symbol_is_prop BOOLEAN NOT NULL,

informal_name TEXT,
informal_description TEXT,

PRIMARY KEY (name)
)
""",
"""
CREATE TABLE dependency (
source JSONB REFERENCES symbol(name) NOT NULL,
target JSONB REFERENCES symbol(name) NOT NULL,
source JSONB NOT NULL,
target JSONB NOT NULL,
on_type BOOLEAN NOT NULL,
PRIMARY KEY (source, target, on_type)
)
""",
"""
CREATE TABLE level (
symbol_name JSONB PRIMARY KEY REFERENCES symbol(name) NOT NULL,
symbol_name JSONB PRIMARY KEY REFERENCES declaration(name) NOT NULL,
level INTEGER NOT NULL
)
""",
"""
CREATE TABLE informal (
symbol_name JSONB PRIMARY KEY REFERENCES symbol(name) NOT NULL,
name TEXT NOT NULL,
description TEXT NOT NULL
)
""",
"""
CREATE VIEW record AS
SELECT
d.module_name, d.index, d.kind, d.name, d.signature, s.type, d.value, d.docstring,
i.name AS informal_name, i.description AS informal_description
FROM
declaration d
INNER JOIN informal i ON d.name = i.symbol_name
INNER JOIN symbol s ON d.name = s.name
""",
]

with conn.cursor() as cursor:
Expand Down
39 changes: 20 additions & 19 deletions database/informalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@ def find_neighbor(conn: Connection, module_name: LeanName, index: int, num_neigh
with conn.cursor(row_factory=args_row(TranslatedItem)) as cursor:
cursor.execute(
"""
SELECT d.name, d.signature, i.name, i.description
FROM
declaration d
LEFT JOIN informal i ON d.name = i.symbol_name
SELECT name, signature, informal_name, informal_description
FROM declaration
WHERE
d.module_name = %s AND d.index >= %s AND d.index <= %s
module_name = %s AND index >= %s AND index <= %s
""",
(Jsonb(module_name), index - num_neighbor, index + num_neighbor),
)
Expand All @@ -32,11 +30,10 @@ def find_dependency(conn: Connection, name: LeanName) -> list[TranslatedItem]:
with conn.cursor(row_factory=args_row(TranslatedItem)) as cursor:
cursor.execute(
"""
SELECT d.name, d.signature, i.name, i.description
SELECT name, signature, informal_name, informal_description
FROM
declaration d
INNER JOIN dependency e ON d.name = e.target
LEFT JOIN informal i ON d.name = i.symbol_name
WHERE
e.source = %s
""",
Expand All @@ -57,15 +54,14 @@ def generate_informal(conn: Connection, batch_size: int = 50, limit_level: int |
with conn.cursor() as cursor, conn.cursor() as insert_cursor:
for level in range(max_level + 1):
query = """
SELECT s.name, d.signature, d.value, d.docstring, d.kind, m.docstring, d.module_name, d.index
FROM
symbol s
INNER JOIN declaration d ON s.name = d.name
INNER JOIN module m ON s.module_name = m.name
INNER JOIN level l ON s.name = l.symbol_name
WHERE
l.level = %s AND
(NOT EXISTS(SELECT 1 FROM informal i WHERE i.symbol_name = s.name))
SELECT d.name, d.signature, d.value, d.docstring, d.kind, m.docstring, d.module_name, d.index
FROM
declaration d
INNER JOIN module m ON d.module_name = m.name
INNER JOIN level l ON d.name = l.symbol_name
WHERE
l.level = %s AND
(d.informal_description IS NULL OR d.informal_name IS NULL)
"""
if limit_num_per_level:
cursor.execute(query + " LIMIT %s", (level, limit_num_per_level))
Expand All @@ -84,10 +80,15 @@ async def translate_and_insert(name: LeanName, data: TranslationInput):
informal_name, informal_description = result
insert_cursor.execute(
"""
INSERT INTO informal (symbol_name, name, description)
VALUES (%s, %s, %s)
UPDATE declaration
SET informal_name = %(informal_name)s, informal_description = %(informal_description)s
WHERE name = %(name)s
""",
(Jsonb(name), informal_name, informal_description),
{
"name": Jsonb(name),
"informal_name": informal_name,
"informal_description": informal_description
},
)

tasks.clear()
Expand Down
188 changes: 97 additions & 91 deletions database/jixia_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,101 +10,109 @@

logger = logging.getLogger(__name__)

def _get_signature(declaration: Declaration, module_content):
if declaration.signature.pp is not None:
return declaration.signature.pp
elif declaration.signature.range is not None:
return module_content[declaration.signature.range.as_slice()].decode()
else:
return ''

def load_data(project: LeanProject, prefixes: list[LeanName], conn: Connection):
def load_module(data: Iterable[LeanName], base_dir: Path):
values = ((Jsonb(m), project.path_of_module(m, base_dir).read_bytes(), project.load_module_info(m).docstring) for m in data)
cursor.executemany(
"""
INSERT INTO module (name, content, docstring) VALUES (%s, %s, %s) ON CONFLICT DO NOTHING
""",
values,
)
def _get_value(declaration: Declaration, module_content):
if declaration.value is not None and declaration.value.range is not None:
return module_content[declaration.value.range.as_slice()].decode()
else:
return None

def load_symbol(module: LeanName):
symbols = [s for s in project.load_info(module, Symbol) if not is_internal(s.name)]
values = ((Jsonb(s.name), Jsonb(module), s.type, s.is_prop) for s in symbols)
cursor.executemany(
"""
INSERT INTO symbol (name, module_name, type, is_prop) VALUES (%s, %s, %s, %s) ON CONFLICT DO NOTHING
""",
values,
)
for s in symbols:
values = (
{
"source": Jsonb(s.name),
"target": Jsonb(t),
}
for t in s.type_references
if not is_internal(t)
)
cursor.executemany(
def _find_declaration(declarations: list[Declaration], target_name):
for index, declaration in enumerate(declarations):
if declaration.name == target_name:
return declaration, index
return None, None

def load_data(project: LeanProject, prefixes: list[LeanName], conn: Connection):
def load_module(module_names: Iterable[LeanName], base_dir: Path):
for module_name in module_names:
db_module = {
"name": Jsonb(module_name),
"content": project.path_of_module(module_name, base_dir).read_bytes(),
"docstring": project.load_module_info(module_name).docstring
}
cursor.execute(
"""
INSERT INTO dependency (source, target, on_type)
SELECT %(source)s, %(target)s, TRUE
WHERE EXISTS(SELECT 1 FROM symbol WHERE name = %(target)s)
ON CONFLICT DO NOTHING
INSERT INTO module (name, content, docstring) VALUES (%(name)s, %(content)s, %(docstring)s)
""",
values,
db_module
)

if s.value_references is not None:
values = (
symbols = project.load_info(module_name, Symbol)
declarations = project.load_info(module_name, Declaration)
for symbol in symbols:
declaration, index = _find_declaration(declarations, symbol.name)
if (
is_internal(symbol.name) or
declaration is None or
declaration.kind == "proofWanted"
):
continue

cursor.execute(
"""
INSERT INTO declaration (module_name, index, name, visible, docstring, kind, signature, value, symbol_type, symbol_is_prop)
VALUES (%(module_name)s, %(index)s, %(name)s, %(visible)s, %(docstring)s, %(kind)s, %(signature)s, %(value)s, %(symbol_type)s, %(symbol_is_prop)s)
""",
{
"source": Jsonb(s.name),
"target": Jsonb(t),
"module_name": Jsonb(module_name),
"name" : Jsonb(declaration.name) if declaration.kind != "example" else None,
"index" : index,
"visible" : declaration.modifiers.visibility != "private" and declaration.kind != "example",
"docstring" : declaration.modifiers.docstring,
"kind" : declaration.kind,
"signature" : _get_signature(declaration, db_module["content"]),
"value" : _get_value(declaration, db_module["content"]),
"symbol_type": symbol.type,
"symbol_is_prop": symbol.is_prop
}
for t in s.value_references
if not is_internal(t)
)

db_deps = []
for ref_name in symbol.type_references:
if is_internal(ref_name):
continue
db_deps.append({
"source": Jsonb(symbol.name),
"target": Jsonb(ref_name),
"on_type": True
})
for ref_name in (symbol.value_references or []):
if is_internal(ref_name):
continue
db_deps.append({
"source": Jsonb(symbol.name),
"target": Jsonb(ref_name),
"on_type": False
})
cursor.executemany(
"""
INSERT INTO dependency (source, target, on_type)
SELECT %(source)s, %(target)s, FALSE
WHERE EXISTS(SELECT 1 FROM symbol WHERE name = %(target)s)
ON CONFLICT DO NOTHING
VALUES (%(source)s, %(target)s, %(on_type)s)
""",
values,
db_deps,
)

def load_declaration(module: LeanName):
declarations = project.load_info(module, Declaration)
cursor.execute(
"""
SELECT content FROM module WHERE name = %s
""",
(Jsonb(module),),
)
(source,) = cursor.fetchone()
values = (
(
Jsonb(module),
i,
Jsonb(d.name) if d.kind != "example" else None,
d.modifiers.visibility != "private" and d.kind != "example",
d.modifiers.docstring,
d.kind,
d.signature.pp if d.signature.pp is not None else source[d.signature.range.as_slice()].decode(),
source[d.value.range.as_slice()].decode() if d.value is not None else None,
)
for i, d in enumerate(declarations)
if not is_internal(d.name) and d.kind != "proofWanted"
)
cursor.executemany(
"""
INSERT INTO declaration (module_name, index, name, visible, docstring, kind, signature, value)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s) ON CONFLICT DO NOTHING
""",
values,
)

def topological_sort():
logger.info("performing topological sort")
# Delete dependencies where target doesn't exist in declaration table
cursor.execute("""
DELETE FROM dependency d
WHERE NOT EXISTS (SELECT 1 FROM declaration dec WHERE dec.name = d.target)
""")
logger.info("Deleted %d invalid dependencies", cursor.rowcount)

cursor.execute("""
INSERT INTO level (symbol_name, level)
SELECT name, 0
FROM symbol v
FROM declaration v
WHERE NOT EXISTS (SELECT 1 FROM dependency e WHERE e.source = v.name)
""")
while cursor.rowcount:
Expand All @@ -122,21 +130,19 @@ def topological_sort():
""")

with conn.cursor() as cursor:
lean_sysroot = Path(os.environ["LEAN_SYSROOT"])
lean_src = lean_sysroot / "src" / "lean"
all_modules = []
for d in project.root, lean_src:
results = project.batch_run_jixia(
base_dir=d,
prefixes=prefixes,
plugins=["module", "declaration", "symbol"],
)
modules = [r[0] for r in results]
load_module(modules, d)
all_modules += modules
path_to_project = project.root
project_modules = [r[0] for r in project.batch_run_jixia(
base_dir=path_to_project,
prefixes=prefixes,
plugins=["module", "declaration", "symbol"],
)]
load_module(project_modules, path_to_project)

for m in all_modules:
load_symbol(m)
for m in all_modules:
load_declaration(m)
path_to_lean = Path(os.environ["LEAN_SYSROOT"]) / "src" / "lean"
lean_modules = [r[0] for r in project.batch_run_jixia(
base_dir=path_to_lean,
prefixes=prefixes,
plugins=["module", "declaration", "symbol"],
)]
load_module(lean_modules, path_to_lean)
topological_sort()
7 changes: 3 additions & 4 deletions database/vector_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@ def create_vector_db(conn: Connection, path: str, batch_size: int):

with conn.cursor() as cursor:
cursor.execute("""
SELECT d.module_name, d.index, d.kind, d.name, d.signature, i.name, i.description
FROM
declaration d INNER JOIN informal i ON d.name = i.symbol_name
WHERE d.visible = TRUE
SELECT module_name, index, kind, name, signature, informal_name, informal_description
FROM declaration
WHERE visible = TRUE
""")

batch_doc = []
Expand Down