From f92a88c8f92b2b1092e0b74111ba902d734fe001 Mon Sep 17 00:00:00 2001 From: Josh Cohen Date: Fri, 16 Jan 2026 11:58:25 -0500 Subject: [PATCH 1/5] Add mutual recursion to TypeFactory Generate eliminators for mutual recursive types Change Boogie TypeDecl to take in List Leave ProgramWF proofs as sorry for now Change strict positivity+uniformity checks for mutually recursive types --- Strata/DL/Lambda/LExprTypeEnv.lean | 29 ++- Strata/DL/Lambda/TypeFactory.lean | 226 +++++++++++++++-- .../Boogie/DDMTransform/Translate.lean | 9 +- Strata/Languages/Boogie/Env.lean | 9 +- Strata/Languages/Boogie/Program.lean | 8 +- Strata/Languages/Boogie/ProgramType.lean | 12 +- Strata/Languages/Boogie/ProgramWF.lean | 8 + Strata/Languages/Boogie/TypeDecl.lean | 20 +- StrataTest/DL/Lambda/TypeFactoryTests.lean | 228 +++++++++++++++--- .../Boogie/DatatypeVerificationTests.lean | 2 +- StrataTest/Languages/Boogie/ExprEvalTest.lean | 2 +- .../Boogie/SMTEncoderDatatypeTest.lean | 2 +- 12 files changed, 480 insertions(+), 75 deletions(-) diff --git a/Strata/DL/Lambda/LExprTypeEnv.lean b/Strata/DL/Lambda/LExprTypeEnv.lean index 3b0ad8d3a..9961d4511 100644 --- a/Strata/DL/Lambda/LExprTypeEnv.lean +++ b/Strata/DL/Lambda/LExprTypeEnv.lean @@ -218,9 +218,6 @@ deriving Inhabited def LDatatype.toKnownType (d: LDatatype IDMeta) : KnownType := { name := d.name, metadata := d.typeArgs.length} -def TypeFactory.toKnownTypes (t: @TypeFactory IDMeta) : KnownTypes := - makeKnownTypes (t.foldr (fun t l => t.toKnownType :: l) []) - /-- A type environment `TEnv` contains - genEnv: `TGenEnv` to track the generator state as well as the typing context @@ -253,7 +250,7 @@ structure LContext (T: LExprParams) where deriving Inhabited def LContext.empty {IDMeta} : LContext IDMeta := - ⟨#[], #[], {}, {}⟩ + ⟨#[], TypeFactory.default, {}, {}⟩ instance : EmptyCollection (LContext IDMeta) where emptyCollection := LContext.empty @@ -290,7 +287,7 @@ def TEnv.default : TEnv IDMeta := def LContext.default : LContext T := { functions := #[], - datatypes := #[], + datatypes := TypeFactory.default, knownTypes := KnownTypes.default, idents := Identifiers.default } @@ -353,8 +350,28 @@ def LContext.addDatatype [Inhabited T.IDMeta] [Inhabited T.Metadata] (C: LContex let ks ← C.knownTypes.add d.toKnownType .ok {C with datatypes := ds, functions := fs, knownTypes := ks} +/-- Add a mutual block of datatypes to the context. -/ +def LContext.addMutualBlock [Inhabited T.IDMeta] [Inhabited T.Metadata] [ToFormat T.IDMeta] (C: LContext T) (block: MutualDatatype T.IDMeta) : Except Format (LContext T) := do + if block.isEmpty then return C + -- Check for name clashes with known types + for d in block do + if C.knownTypes.containsName d.name then + throw f!"Cannot name datatype same as known type!\n{d}\nKnownTypes' names:\n{C.knownTypes.keywords}" + -- Add all datatypes to the type factory + let mut ds := C.datatypes + for d in block do + ds ← ds.addDatatype d + -- Generate factory functions for the whole mutual block + let f ← genBlockFactory block + let fs ← C.functions.addFactory f + -- Add datatype names to knownTypes + let mut ks := C.knownTypes + for d in block do + ks ← ks.add d.toKnownType + .ok {C with datatypes := ds, functions := fs, knownTypes := ks} + def LContext.addTypeFactory [Inhabited T.IDMeta] [Inhabited T.Metadata] (C: LContext T) (f: @TypeFactory T.IDMeta) : Except Format (LContext T) := - Array.foldlM (fun C d => C.addDatatype d) C f + f.foldlM (fun C block => C.addMutualBlock block) C /-- Replace the global substitution in `T.state.subst` with `S`. diff --git a/Strata/DL/Lambda/TypeFactory.lean b/Strata/DL/Lambda/TypeFactory.lean index e20cb53c5..667197bac 100644 --- a/Strata/DL/Lambda/TypeFactory.lean +++ b/Strata/DL/Lambda/TypeFactory.lean @@ -81,6 +81,10 @@ The default type application for a datatype. E.g. for datatype def dataDefault (d: LDatatype IDMeta) : LMonoTy := data d (d.typeArgs.map .ftvar) +/-- A group of mutually recursive datatypes. + For non-mutually-recursive datatypes, this is a single-element list. -/ +abbrev MutualDatatype (IDMeta : Type) := List (LDatatype IDMeta) + --------------------------------------------------------------------- -- Typechecking @@ -130,6 +134,33 @@ def checkStrictPosUnif (d: LDatatype IDMeta) : Except Format Unit := ) () args ) () d.constrs +/-- +Check for strict positivity and uniformity of all datatypes in a mutual block +within type `ty`. The string `c` appears only for error message information. +-/ +def checkStrictPosUnifTyMutual (c: String) (block: MutualDatatype IDMeta) (ty: LMonoTy) : Except Format Unit := + match ty with + | .arrow t1 t2 => + -- Check that no datatype in the block appears in the left side of an arrow + match block.find? (fun d => tyNameAppearsIn d.name t1) with + | some d => .error f!"Error in constructor {c}: Non-strictly positive occurrence of {d.name} in type {ty}" + | none => checkStrictPosUnifTyMutual c block t2 + | _ => + -- Check uniformity for all datatypes in the block + block.foldlM (fun _ d => checkUniform c d.name (d.typeArgs.map .ftvar) ty) () + +/-- +Check for strict positivity and uniformity across a mutual block of datatypes +-/ +def checkStrictPosUnifMutual (block: MutualDatatype IDMeta) : Except Format Unit := + block.foldlM (fun _ d => + d.constrs.foldlM (fun _ ⟨name, args, _⟩ => + args.foldlM (fun _ ⟨_, ty⟩ => + checkStrictPosUnifTyMutual name.name block ty + ) () + ) () + ) () + --------------------------------------------------------------------- -- Generating constructors and eliminators @@ -302,6 +333,105 @@ def elimFunc [Inhabited T.IDMeta] [BEq T.Identifier] (d: LDatatype T.IDMeta) (m: --------------------------------------------------------------------- +-- Mutual block eliminator generation + +/-- Find which datatype in a mutual block a recursive type refers to. -/ +def findRecTyInBlock (block : MutualDatatype IDMeta) (ty : LMonoTy) : Option (LDatatype IDMeta) := + block.find? (fun d => isRecTy d ty) + +/-- +Generate recursive type for mutual blocks. Replace references to datatypes +in the block with the appropriate return type variable. +-/ +def genRecTyMutual (block : MutualDatatype IDMeta) (retTyVars : List TyIdentifier) (ty : LMonoTy) : Option LMonoTy := + match block.findIdx? (fun d => ty == dataDefault d) with + | some idx => retTyVars[idx]? |>.map LMonoTy.ftvar + | none => + match ty with + | .arrow t1 t2 => (genRecTyMutual block retTyVars t2).map (fun r => .arrow t1 r) + | _ => none + +/-- +Generate eliminator case type for a constructor in a mutual block. +`dtIdx` is the index of the datatype this constructor belongs to. +-/ +def elimTyMutual (block : MutualDatatype IDMeta) (retTyVars : List TyIdentifier) + (dtIdx : Nat) (c : LConstr IDMeta) : LMonoTy := + let outputType := retTyVars[dtIdx]? |>.map LMonoTy.ftvar |>.getD (.ftvar "") + match c.args with + | [] => outputType + | _ :: _ => + let argTypes := c.args.map Prod.snd + let recTypes := c.args.filterMap fun (_, ty) => genRecTyMutual block retTyVars ty + LMonoTy.mkArrow' outputType (argTypes ++ recTypes) + +/-- Compute global constructor index within a mutual block. -/ +def globalConstrIdx (block : MutualDatatype IDMeta) (dtIdx : Nat) (constrIdx : Nat) : Nat := + let prevCount := (block.take dtIdx).foldl (fun acc d => acc + d.constrs.length) 0 + prevCount + constrIdx + +/-- Check if a type is recursive within a mutual block. -/ +def isRecTyInBlock (block : MutualDatatype IDMeta) (ty : LMonoTy) : Bool := + block.any (fun d => isRecTy d ty) + +/-- Find which datatype and constructor a value belongs to in a mutual block. + Unlike datatypeGetConstr, this identifies recursive args across the whole block. -/ +def matchConstrInBlock {T: LExprParams} [BEq T.Identifier] (block : MutualDatatype T.IDMeta) (x : LExpr T.mono) + : Option (Nat × Nat × LConstr T.IDMeta × List (LExpr T.mono) × List (LExpr T.mono × LMonoTy)) := + List.zip block (List.range block.length) |>.findSome? fun (d, dtIdx) => + (datatypeGetConstr d x).map fun (c, constrIdx, args, _) => + -- Recompute recs using the whole block, not just d + let recs := (List.zip args (c.args.map Prod.snd)).filter (fun (_, ty) => isRecTyInBlock block ty) + (dtIdx, constrIdx, c, args, recs) + +/-- Construct recursive eliminator call for mutual blocks. -/ +def elimRecCallMutual {T: LExprParams} [Inhabited T.IDMeta] (block : MutualDatatype T.IDMeta) (recArg : LExpr T.mono) + (recTy : LMonoTy) (elimArgs : List (LExpr T.mono)) (m : T.Metadata) : Option (LExpr T.mono) := + (findRecTyInBlock block recTy).map fun d => + let elimName := elimFuncName d + match recTyStructure d recTy with + | .inl _ => (LExpr.op m elimName .none).mkApp m (recArg :: elimArgs) + | .inr funArgs => + LExpr.absMulti m funArgs ((LExpr.op m elimName .none).mkApp m (recArg.mkApp m (getBVars m funArgs.length) :: elimArgs)) + +/-- Generate eliminator concrete evaluator for mutual blocks. -/ +def elimConcreteEvalMutual {T: LExprParams} [BEq T.Identifier] [Inhabited T.IDMeta] (block : MutualDatatype T.IDMeta) + (m : T.Metadata) : T.Metadata → List (LExpr T.mono) → Option (LExpr T.mono) := + fun _ args => + match args with + | x :: xs => + match matchConstrInBlock block x with + | some (dtIdx, constrIdx, _, a, recs) => + let gIdx := globalConstrIdx block dtIdx constrIdx + xs[gIdx]?.bind fun f => + let recCalls := recs.filterMap (fun (r, rty) => elimRecCallMutual block r rty xs m) + some (f.mkApp m (a ++ recCalls)) + | none => none + | _ => none + +/-- +Generate eliminators for all datatypes in a mutual block. +Each datatype gets its own eliminator, but they share case function arguments +for all constructors across the block. +-/ +def elimFuncsMutual [Inhabited T.IDMeta] [BEq T.Identifier] (block : MutualDatatype T.IDMeta) (m : T.Metadata) + : List (LFunc T) := + if block.isEmpty then [] else + let allTypeArgs := block.flatMap (·.typeArgs) + let retTyVars := freshTypeArgs block.length allTypeArgs + let allConstrs : List (Nat × LConstr T.IDMeta) := + List.zip block (List.range block.length) |>.flatMap fun (d, dtIdx) => d.constrs.map (dtIdx, ·) + let caseTypes := allConstrs.map fun (dtIdx, c) => elimTyMutual block retTyVars dtIdx c + List.zip block (List.range block.length) |>.map fun (d, dtIdx) => + let outputTyVar := retTyVars[dtIdx]?.getD "" + { name := elimFuncName d + typeArgs := retTyVars ++ d.typeArgs + inputs := List.zip (genArgNames (allConstrs.length + 1)) (dataDefault d :: caseTypes) + output := .ftvar outputTyVar + concreteEval := elimConcreteEvalMutual block m } + +--------------------------------------------------------------------- + -- Generating testers and destructors /-- @@ -315,6 +445,33 @@ def testerFuncBody {T : LExprParams} [Inhabited T.IDMeta] (d: LDatatype T.IDMeta let args := List.map (fun c1 => LExpr.absMultiInfer m (numargs c1) (.boolConst m (c.name.name == c1.name.name))) d.constrs .mkApp m (.op m (elimFuncName d) .none) (input :: args) +/-- +Generate tester body for mutual blocks. For mutual eliminators, we need case functions +for ALL constructors across the block, not just the constructors of one datatype. +-/ +def testerFuncBodyMutual {T : LExprParams} [Inhabited T.IDMeta] (block: MutualDatatype T.IDMeta) + (d: LDatatype T.IDMeta) (c: LConstr T.IDMeta) (input: LExpr T.mono) (m: T.Metadata) : LExpr T.mono := + -- Number of arguments for a constructor in a mutual block + let numargs (c: LConstr T.IDMeta) := c.args.length + ((c.args.map Prod.snd).filter (isRecTyInBlock block)).length + -- Generate case functions for ALL constructors in the block + let args := block.flatMap (fun d' => d'.constrs.map (fun c1 => + LExpr.absMultiInfer m (numargs c1) (.boolConst m (c.name.name == c1.name.name)))) + .mkApp m (.op m (elimFuncName d) .none) (input :: args) + +/-- +Generate tester function for a constructor in a mutual block. +-/ +def testerFuncMutual {T} [Inhabited T.IDMeta] (block: MutualDatatype T.IDMeta) + (d: LDatatype T.IDMeta) (c: LConstr T.IDMeta) (m: T.Metadata) : LFunc T := + let arg := genArgName + {name := c.testerName, + typeArgs := d.typeArgs, + inputs := [(arg, dataDefault d)], + output := .bool, + body := testerFuncBodyMutual block d c (.fvar m arg .none) m, + attr := #["inline_if_val"] + } + /-- Generate tester function for a constructor (e.g. `List$isCons` and `List$isNil`). The semantics of the testers are given via a body, @@ -368,16 +525,45 @@ def destructorFuncs {T} [BEq T.Identifier] [Inhabited T.IDMeta] (d: LDatatype T -- Type Factories -def TypeFactory := Array (LDatatype IDMeta) +/-- A TypeFactory stores datatypes grouped by mutual recursion. -/ +def TypeFactory := Array (MutualDatatype IDMeta) instance: ToFormat (@TypeFactory IDMeta) where - format f := Std.Format.joinSep f.toList f!"{Format.line}" + format f := + let formatBlock (block : MutualDatatype IDMeta) : Format := + match block with + | [d] => format d + | ds => f!"mutual {Std.Format.joinSep (ds.map format) Format.line} end" + Std.Format.joinSep (f.toList.map formatBlock) f!"{Format.line}" instance : Inhabited (@TypeFactory IDMeta) where default := #[] def TypeFactory.default : @TypeFactory IDMeta := #[] +/-- Get all datatypes in the TypeFactory as a flat list. -/ +def TypeFactory.allDatatypes (t : @TypeFactory IDMeta) : List (LDatatype IDMeta) := + t.toList.flatten + +/-- Find a datatype by name in the TypeFactory. -/ +def TypeFactory.getType (F : @TypeFactory IDMeta) (name : String) : Option (LDatatype IDMeta) := + F.allDatatypes.find? (fun d => d.name == name) + +/-- Add a mutual block to the TypeFactory, checking for duplicate names. -/ +def TypeFactory.addMutualBlock (t : @TypeFactory IDMeta) (block : MutualDatatype IDMeta) : Except Format (@TypeFactory IDMeta) := do + for d in block do + match t.getType d.name with + | some d' => throw f!"A datatype of name {d.name} already exists! \ + Redefinitions are not allowed.\n\ + Existing Type: {d'}\n\ + New Type:{d}" + | none => pure () + .ok (t.push block) + +/-- Add a single datatype (as a single-element block). -/ +def TypeFactory.addDatatype (t : @TypeFactory IDMeta) (d : LDatatype IDMeta) : Except Format (@TypeFactory IDMeta) := + t.addMutualBlock [d] + /-- Generates the Factory (containing the eliminator, constructors, testers, and destructors) for a single datatype. @@ -405,28 +591,28 @@ def LDatatype.genFunctionMaps {T: LExprParams} [Inhabited T.IDMeta] [BEq T.Ident (destructorFuncs d c).map (fun f => (f.name.name, (d, c))))).flatten) /-- -Generates the Factory (containing all constructor and eliminator functions) for the given `TypeFactory` +Generates the Factory (containing eliminators, constructors, testers, and destructors) +for a mutual block of datatypes. -/ -def TypeFactory.genFactory {T: LExprParams} [inst: Inhabited T.Metadata] [Inhabited T.IDMeta] [ToFormat T.IDMeta] [BEq T.Identifier] (t: @TypeFactory T.IDMeta) : Except Format (@Lambda.Factory T) := - t.foldlM (fun f d => do - let f' ← d.genFactory - f.addFactory f') Factory.default - -def TypeFactory.getType (F : @TypeFactory IDMeta) (name : String) : Option (LDatatype IDMeta) := - F.find? (fun d => d.name == name) +def genBlockFactory {T: LExprParams} [inst: Inhabited T.Metadata] [Inhabited T.IDMeta] [ToFormat T.IDMeta] [BEq T.Identifier] + (block : MutualDatatype T.IDMeta) : Except Format (@Lambda.Factory T) := do + if block.isEmpty then return Factory.default + -- Check strict positivity and uniformity across the whole block + checkStrictPosUnifMutual block + let elims := elimFuncsMutual block inst.default + let constrs := block.flatMap (fun d => d.constrs.map (fun c => constrFunc c d)) + let testers := block.flatMap (fun d => d.constrs.map (fun c => testerFuncMutual block d c inst.default)) + let destrs := block.flatMap (fun d => d.constrs.flatMap (fun c => destructorFuncs d c)) + Factory.default.addFactory (elims ++ constrs ++ testers ++ destrs).toArray /-- -Add an `LDatatype` to an existing `TypeFactory`, checking that no -types are duplicated. +Generates the Factory (containing all constructor and eliminator functions) for the given `TypeFactory`. -/ -def TypeFactory.addDatatype (t: @TypeFactory IDMeta) (d: LDatatype IDMeta) : Except Format (@TypeFactory IDMeta) := - -- Check that type is not redeclared - match t.getType d.name with - | none => .ok (t.push d) - | some d' => .error f!"A datatype of name {d.name} already exists! \ - Redefinitions are not allowed.\n\ - Existing Type: {d'}\n\ - New Type:{d}" +def TypeFactory.genFactory {T: LExprParams} [inst: Inhabited T.Metadata] [Inhabited T.IDMeta] [ToFormat T.IDMeta] [BEq T.Identifier] (t: @TypeFactory T.IDMeta) : Except Format (@Lambda.Factory T) := + t.foldlM (fun f block => do + let f' ← genBlockFactory block + f.addFactory f' + ) Factory.default --------------------------------------------------------------------- diff --git a/Strata/Languages/Boogie/DDMTransform/Translate.lean b/Strata/Languages/Boogie/DDMTransform/Translate.lean index 0c563dbb4..83cb63f8c 100644 --- a/Strata/Languages/Boogie/DDMTransform/Translate.lean +++ b/Strata/Languages/Boogie/DDMTransform/Translate.lean @@ -251,10 +251,13 @@ partial def translateLMonoTy (bindings : TransBindings) (arg : Arg) : | .type (.syn syn) _md => let ty := syn.toLHSLMonoTy pure ty - | .type (.data ldatatype) => + | .type (.data (ldatatype :: _)) _md => -- Datatype Declaration + -- TODO: For mutual blocks, need to find the specific datatype by name let args := ldatatype.typeArgs.map LMonoTy.ftvar pure (.tcons ldatatype.name args) + | .type (.data []) _md => + TransM.error "Empty mutual datatype block" | _ => TransM.error s!"translateLMonoTy not yet implemented for this declaration: \ @@ -1347,7 +1350,7 @@ def translateDatatype (p : Program) (bindings : TransBindings) (op : Operation) typeArgs := typeArgs constrs := [{ name := datatypeName, args := [], testerName := "" }] constrs_ne := by simp } - let placeholderDecl := Boogie.Decl.type (.data placeholderLDatatype) + let placeholderDecl := Boogie.Decl.type (.data [placeholderLDatatype]) let bindingsWithPlaceholder := { bindings with freeVars := bindings.freeVars.push placeholderDecl } -- Extract constructor information (possibly recursive) @@ -1383,7 +1386,7 @@ def translateDatatype (p : Program) (bindings : TransBindings) (op : Operation) Boogie.Decl.func func -- Only includes typeDecl, factory functions generated later - let typeDecl := Boogie.Decl.type (.data ldatatype) + let typeDecl := Boogie.Decl.type (.data [ldatatype]) let allDecls := [typeDecl] /- diff --git a/Strata/Languages/Boogie/Env.lean b/Strata/Languages/Boogie/Env.lean index 1ebc91778..b219df9fc 100644 --- a/Strata/Languages/Boogie/Env.lean +++ b/Strata/Languages/Boogie/Env.lean @@ -318,10 +318,13 @@ def Env.merge (cond : Expression.Expr) (E1 E2 : Env) : Env := else Env.performMerge cond E1 E2 (by simp_all) (by simp_all) -def Env.addDatatypes (E: Env) (datatypes: List (Lambda.LDatatype Visibility)) : Except Format Env := do - let f ← Lambda.TypeFactory.genFactory (T:=BoogieLParams) (datatypes.toArray) +def Env.addMutualDatatype (E: Env) (block: Lambda.MutualDatatype Visibility) : Except Format Env := do + let f ← Lambda.genBlockFactory (T:=BoogieLParams) block let env ← E.addFactory f - return { env with datatypes := datatypes.toArray } + return { env with datatypes := E.datatypes.push block } + +def Env.addDatatypes (E: Env) (blocks: List (Lambda.MutualDatatype Visibility)) : Except Format Env := + blocks.foldlM Env.addMutualDatatype E end Boogie diff --git a/Strata/Languages/Boogie/Program.lean b/Strata/Languages/Boogie/Program.lean index b62f23a25..2864d3e79 100644 --- a/Strata/Languages/Boogie/Program.lean +++ b/Strata/Languages/Boogie/Program.lean @@ -79,6 +79,12 @@ def Decl.name (d : Decl) : Expression.Ident := | .proc p _ => p.header.name | .func f _ => f.name +/-- Get all names from a declaration. For mutual datatypes, returns all datatype names. -/ +def Decl.names (d : Decl) : List Expression.Ident := + match d with + | .type t _ => t.names + | _ => [d.name] + def Decl.getVar? (d : Decl) : Option (Expression.Ident × Expression.Ty × Expression.Expr) := match d with @@ -207,7 +213,7 @@ def Program.getInit? (P: Program) (x : Expression.Ident) : Option Expression.Exp def Program.getNames (P: Program) : List Expression.Ident := go P.decls - where go decls := decls.map Decl.name + where go decls := decls.flatMap Decl.names def Program.Type.find? (P : Program) (x : Expression.Ident) : Option TypeDecl := match P.find? .type x with diff --git a/Strata/Languages/Boogie/ProgramType.lean b/Strata/Languages/Boogie/ProgramType.lean index 001acc0a6..95644f9b1 100644 --- a/Strata/Languages/Boogie/ProgramType.lean +++ b/Strata/Languages/Boogie/ProgramType.lean @@ -34,9 +34,11 @@ def typeCheck (C: Boogie.Expression.TyContext) (Env : Boogie.Expression.TyEnv) ( | decl :: drest => do let sourceLoc := Imperative.MetaData.formatFileRangeD decl.metadata (includeEnd? := true) let errorWithSourceLoc := fun e => if sourceLoc.isEmpty then e else f!"{sourceLoc} {e}" - let C := {C with idents := (← C.idents.addWithError decl.name - f!"{sourceLoc} Error in {decl.kind} {decl.name}: \ - a declaration of this name already exists.")} + -- Add all names from the declaration (multiple for mutual datatypes) + let C ← decl.names.foldlM (fun C name => do + let idents ← C.idents.addWithError name + f!"{sourceLoc} Error in {decl.kind} {name}: a declaration of this name already exists." + pure {C with idents}) C let (decl', C, Env) ← match decl with @@ -61,8 +63,8 @@ def typeCheck (C: Boogie.Expression.TyContext) (Env : Boogie.Expression.TyEnv) ( | .syn ts => let Env ← TEnv.addTypeAlias { typeArgs := ts.typeArgs, name := ts.name, type := ts.type } C Env .ok (.type td, C, Env) - | .data d => - let C ← C.addDatatype d + | .data block => + let C ← C.addMutualBlock block .ok (.type td, C, Env) catch e => .error (errorWithSourceLoc e) diff --git a/Strata/Languages/Boogie/ProgramWF.lean b/Strata/Languages/Boogie/ProgramWF.lean index 31ceca4ed..3b4f30db6 100644 --- a/Strata/Languages/Boogie/ProgramWF.lean +++ b/Strata/Languages/Boogie/ProgramWF.lean @@ -291,6 +291,9 @@ If a program typechecks successfully, then every identifier in the list of program decls is not in the original `LContext` -/ theorem Program.typeCheckFunctionDisjoint : Program.typeCheck.go p C T decls acc = .ok (d', T') → (∀ x, x ∈ Program.getNames.go decls → ¬ C.idents.contains x) := by + -- TODO: This proof needs to be updated to handle mutual datatypes (multiple names per decl) + sorry + /- Original proof: induction decls generalizing acc p d' T' T C with | nil => simp[Program.getNames.go] | cons r rs IH => @@ -329,12 +332,16 @@ theorem Program.typeCheckFunctionDisjoint : Program.typeCheck.go p C T decls acc rename_i hmatch2; split at hmatch2 <;> try grind simp only [LContext.addFactoryFunction] at hmatch2; grind done + -/ /-- If a program typechecks succesfully, all identifiers defined in the program are unique. -/ theorem Program.typeCheckFunctionNoDup : Program.typeCheck.go p C T decls acc = .ok (d', T') → (Program.getNames.go decls).Nodup := by + -- TODO: This proof needs to be updated to handle mutual datatypes (multiple names per decl) + sorry + /- Original proof: induction decls generalizing acc p C T with | nil => simp[Program.getNames.go] | cons r rs IH => @@ -376,6 +383,7 @@ theorem Program.typeCheckFunctionNoDup : Program.typeCheck.go p C T decls acc = rename_i hmatch2; split at hmatch2 <;> try grind simp only [LContext.addFactoryFunction] at hmatch2; grind done + -/ /-- The main lemma stating that a program 'p' that passes type checking is well formed diff --git a/Strata/Languages/Boogie/TypeDecl.lean b/Strata/Languages/Boogie/TypeDecl.lean index 6b1d075c4..d963d0050 100644 --- a/Strata/Languages/Boogie/TypeDecl.lean +++ b/Strata/Languages/Boogie/TypeDecl.lean @@ -83,10 +83,13 @@ def TypeSynonym.toRHSLTy (t : TypeSynonym) : LTy := /-! # Boogie Type Declarations -/ +/-- A Boogie type declaration. The `data` variant stores a mutual block + (a non-empty list of mutually recursive datatypes). For non-mutually-recursive + datatypes, this is a single-element list. -/ inductive TypeDecl where | con : TypeConstructor → TypeDecl | syn : TypeSynonym → TypeDecl - | data : LDatatype Visibility → TypeDecl + | data : List (LDatatype Visibility) → TypeDecl deriving Repr instance : ToFormat TypeDecl where @@ -94,12 +97,23 @@ instance : ToFormat TypeDecl where match d with | .con tc => f!"{tc}" | .syn ts => f!"{ts}" - | .data td => f!"{td}" + | .data [] => f!"" + | .data [td] => f!"{td}" + | .data tds => f!"mutual {Std.Format.joinSep (tds.map format) Format.line} end" +/-- Get all names from a TypeDecl. For mutual blocks, returns all datatype names. -/ +def TypeDecl.names (d : TypeDecl) : List Expression.Ident := + match d with + | .con tc => [tc.name] + | .syn ts => [ts.name] + | .data tds => tds.map (·.name) + +/-- Get the primary name of a TypeDecl (first name for mutual blocks). -/ def TypeDecl.name (d : TypeDecl) : Expression.Ident := match d with | .con tc => tc.name | .syn ts => ts.name - | .data td => td.name + | .data [] => "" + | .data (td :: _) => td.name --------------------------------------------------------------------- diff --git a/StrataTest/DL/Lambda/TypeFactoryTests.lean b/StrataTest/DL/Lambda/TypeFactoryTests.lean index 0a4eba152..856968acf 100644 --- a/StrataTest/DL/Lambda/TypeFactoryTests.lean +++ b/StrataTest/DL/Lambda/TypeFactoryTests.lean @@ -57,7 +57,7 @@ info: #3 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[weekTy] (Factory.default : @Factory TestParams) ((LExpr.op () ("Day$Elim" : TestParams.Identifier) .none).mkApp () (.op () ("W" : TestParams.Identifier) (.some (.tcons "Day" [])) :: (List.range 7).map (intConst () ∘ Int.ofNat))) + typeCheckAndPartialEval #[[weekTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("Day$Elim" : TestParams.Identifier) .none).mkApp () (.op () ("W" : TestParams.Identifier) (.some (.tcons "Day" [])) :: (List.range 7).map (intConst () ∘ Int.ofNat))) /-- @@ -69,7 +69,7 @@ info: #true -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[weekTy] (Factory.default : @Factory TestParams) + typeCheckAndPartialEval #[[weekTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("Day$isW" : TestParams.Identifier) .none).mkApp () [.op () ("W" : TestParams.Identifier) (.some (.tcons "Day" []))]) /-- @@ -81,7 +81,7 @@ info: #false -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[weekTy] (Factory.default : @Factory TestParams) + typeCheckAndPartialEval #[[weekTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("Day$isW" : TestParams.Identifier) .none).mkApp () [.op () ("M" : TestParams.Identifier) (.some (.tcons "Day" []))]) @@ -116,7 +116,7 @@ info: #3 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[tupTy] Factory.default (fst (prod (intConst () 3) (strConst () "a"))) + typeCheckAndPartialEval #[[tupTy]] Factory.default (fst (prod (intConst () 3) (strConst () "a"))) /-- info: Annotated expression: @@ -127,7 +127,7 @@ info: #a -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[tupTy] Factory.default (snd (prod (intConst () 3) (strConst () "a"))) + typeCheckAndPartialEval #[[tupTy]] Factory.default (snd (prod (intConst () 3) (strConst () "a"))) /-- @@ -139,7 +139,7 @@ info: #1 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[tupTy] Factory.default (fst (snd (prod (strConst () "a") (prod (intConst () 1) (strConst () "b"))))) + typeCheckAndPartialEval #[[tupTy]] Factory.default (fst (snd (prod (strConst () "a") (prod (intConst () 1) (strConst () "b"))))) -- Test 3: Polymorphic Lists @@ -171,7 +171,7 @@ info: #1 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (Factory.default : @Factory TestParams) ((LExpr.op () ("List$Elim" : TestParams.Identifier) .none).mkApp () [nil, (intConst () 1), .abs () .none (.abs () .none (.abs () .none (intConst () 1)))]) + typeCheckAndPartialEval #[[listTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("List$Elim" : TestParams.Identifier) .none).mkApp () [nil, (intConst () 1), .abs () .none (.abs () .none (.abs () .none (intConst () 1)))]) -- Test: elim(cons 1 nil, 0, fun x y => x) -> (fun x y => x) 1 nil @@ -185,7 +185,7 @@ info: #2 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (Factory.default : @Factory TestParams) ((LExpr.op () ("List$Elim" : TestParams.Identifier) .none).mkApp () [listExpr [intConst () 2], intConst () 0, .abs () .none (.abs () .none (.abs () .none (bvar () 2)))]) + typeCheckAndPartialEval #[[listTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("List$Elim" : TestParams.Identifier) .none).mkApp () [listExpr [intConst () 2], intConst () 0, .abs () .none (.abs () .none (.abs () .none (bvar () 2)))]) -- Test testers (isNil and isCons) @@ -197,7 +197,7 @@ info: #true -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (Factory.default : @Factory TestParams) + typeCheckAndPartialEval #[[listTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("isNil" : TestParams.Identifier) .none).mkApp () [nil]) /-- info: Annotated expression: @@ -208,7 +208,7 @@ info: #false -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (Factory.default : @Factory TestParams) + typeCheckAndPartialEval #[[listTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("isNil" : TestParams.Identifier) .none).mkApp () [cons (intConst () 1) nil]) /-- info: Annotated expression: @@ -219,7 +219,7 @@ info: #false -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (Factory.default : @Factory TestParams) + typeCheckAndPartialEval #[[listTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("isCons" : TestParams.Identifier) .none).mkApp () [nil]) /-- info: Annotated expression: @@ -230,7 +230,7 @@ info: #true -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (Factory.default : @Factory TestParams) + typeCheckAndPartialEval #[[listTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("isCons" : TestParams.Identifier) .none).mkApp () [cons (intConst () 1) nil]) -- But a non-value should NOT reduce @@ -247,7 +247,7 @@ info: ((~isCons : (arrow (List int) bool)) (~l : (List int))) #guard_msgs in #eval format $ do let f ← ((Factory.default : @Factory TestParams).addFactoryFunc ex_list) - (typeCheckAndPartialEval (T:=TestParams) #[listTy] f + (typeCheckAndPartialEval (T:=TestParams) #[[listTy]] f ((LExpr.op () ("isCons" : TestParams.Identifier) (some (LMonoTy.arrow (.tcons "List" [.int]) .bool))).mkApp () [.op () "l" .none])) -- Test destructors @@ -261,7 +261,7 @@ info: #1 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (Factory.default : @Factory TestParams) + typeCheckAndPartialEval #[[listTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("hd" : TestParams.Identifier) .none).mkApp () [cons (intConst () 1) nil]) /-- @@ -272,7 +272,7 @@ info: (((~Cons : (arrow int (arrow (List int) (List int)))) #2) (~Nil : (List in -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (Factory.default : @Factory TestParams) + typeCheckAndPartialEval #[[listTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("tl" : TestParams.Identifier) .none).mkApp () [cons (intConst () 1) (cons (intConst () 2) nil)]) -- Destructor does not evaluate on a different constructor @@ -284,7 +284,7 @@ info: Annotated expression: ((~tl : (arrow (List $__ty1) (List $__ty1))) (~Nil : info: ((~tl : (arrow (List $__ty1) (List $__ty1))) (~Nil : (List $__ty1)))-/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (Factory.default : @Factory TestParams) + typeCheckAndPartialEval #[[listTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("tl" : TestParams.Identifier) .none).mkApp () [nil]) @@ -308,7 +308,7 @@ info: #7 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy, tupTy] (IntBoolFactory : @Factory TestParams) + typeCheckAndPartialEval #[[listTy], [tupTy]] (IntBoolFactory : @Factory TestParams) ((LExpr.op () ("List$Elim" : TestParams.Identifier) .none).mkApp () [listExpr [(prod (intConst () 3) (strConst () "a")), (prod (intConst () 4) (strConst () "b"))], intConst () 0, @@ -332,7 +332,7 @@ info: #3 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (IntBoolFactory : @Factory TestParams) (length (listExpr [strConst () "a", strConst () "b", strConst () "c"])) + typeCheckAndPartialEval #[[listTy]] (IntBoolFactory : @Factory TestParams) (length (listExpr [strConst () "a", strConst () "b", strConst () "c"])) /-- info: Annotated expression: @@ -343,7 +343,7 @@ info: #15 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (IntBoolFactory : @Factory TestParams) (length (listExpr ((List.range 15).map (intConst () ∘ Int.ofNat)))) + typeCheckAndPartialEval #[[listTy]] (IntBoolFactory : @Factory TestParams) (length (listExpr ((List.range 15).map (intConst () ∘ Int.ofNat)))) /- Append is trickier since it takes in two arguments, so the eliminator returns @@ -367,7 +367,7 @@ info: (((~Cons : (arrow int (arrow (List int) (List int)))) #2) (((~Cons : (arro -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (IntBoolFactory : @Factory TestParams) (append list1 list2) + typeCheckAndPartialEval #[[listTy]] (IntBoolFactory : @Factory TestParams) (append list1 list2) -- 2. Preorder traversal of binary tree @@ -420,7 +420,7 @@ info: (((~Cons : (arrow int (arrow (List int) (List int)))) #1) (((~Cons : (arro -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy, binTreeTy] (IntBoolFactory : @Factory TestParams) (toList tree1) + typeCheckAndPartialEval #[[listTy], [binTreeTy]] (IntBoolFactory : @Factory TestParams) (toList tree1) -- 3. Infinite-ary trees namespace Tree @@ -463,7 +463,7 @@ info: #3 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[treeTy] (IntBoolFactory : @Factory TestParams) (height 0 tree1) + typeCheckAndPartialEval #[[treeTy]] (IntBoolFactory : @Factory TestParams) (height 0 tree1) /--info: Annotated expression: ((((~tree$Elim : (arrow (tree int) (arrow (arrow int int) (arrow (arrow (arrow int (tree int)) (arrow (arrow int int) int)) int)))) ((~Node : (arrow (arrow int (tree int)) (tree int))) (λ ((~Node : (arrow (arrow int (tree int)) (tree int))) (λ (if ((((~Int.Add : (arrow int (arrow int int))) %1) %0) == #0) then ((~Node : (arrow (arrow int (tree int)) (tree int))) (λ ((~Leaf : (arrow int (tree int))) #3))) else ((~Leaf : (arrow int (tree int))) #4))))))) (λ #0)) (λ (λ (((~Int.Add : (arrow int (arrow int int))) #1) (%0 #1))))) @@ -473,7 +473,7 @@ info: #2 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[treeTy] (IntBoolFactory : @Factory TestParams) (height 1 tree1) + typeCheckAndPartialEval #[[treeTy]] (IntBoolFactory : @Factory TestParams) (height 1 tree1) end Tree @@ -490,7 +490,7 @@ def badTy1 : LDatatype Unit := {name := "Bad", typeArgs := [], constrs := [badCo /-- info: Error in constructor C: Non-strictly positive occurrence of Bad in type (arrow Bad Bad) -/ #guard_msgs in -#eval format $ typeCheckAndPartialEval #[badTy1] (IntBoolFactory : @Factory TestParams) (intConst () 0) +#eval format $ typeCheckAndPartialEval #[[badTy1]] (IntBoolFactory : @Factory TestParams) (intConst () 0) /- 2.Non-strictly positive type @@ -502,7 +502,7 @@ def badTy2 : LDatatype Unit := {name := "Bad", typeArgs := ["a"], constrs := [ba /-- info: Error in constructor C: Non-strictly positive occurrence of Bad in type (arrow (arrow (Bad a) int) int)-/ #guard_msgs in -#eval format $ typeCheckAndPartialEval #[badTy2] (IntBoolFactory : @Factory TestParams) (intConst () 0) +#eval format $ typeCheckAndPartialEval #[[badTy2]] (IntBoolFactory : @Factory TestParams) (intConst () 0) /- 3. Non-strictly positive type 2 @@ -514,7 +514,7 @@ def badTy3 : LDatatype Unit := {name := "Bad", typeArgs := ["a"], constrs := [ba /--info: Error in constructor C: Non-strictly positive occurrence of Bad in type (arrow (Bad a) int)-/ #guard_msgs in -#eval format $ typeCheckAndPartialEval #[badTy3] (IntBoolFactory : @Factory TestParams) (intConst () 0) +#eval format $ typeCheckAndPartialEval #[[badTy3]] (IntBoolFactory : @Factory TestParams) (intConst () 0) /- 4. Strictly positive type @@ -531,7 +531,7 @@ def goodTy1 : LDatatype Unit := {name := "Good", typeArgs := ["a"], constrs := [ info: #0 -/ #guard_msgs in -#eval format $ typeCheckAndPartialEval #[goodTy1] (IntBoolFactory : @Factory TestParams) (intConst () 0) +#eval format $ typeCheckAndPartialEval #[[goodTy1]] (IntBoolFactory : @Factory TestParams) (intConst () 0) /- 5. Non-uniform type @@ -542,7 +542,7 @@ def nonUnifTy1 : LDatatype Unit := {name := "Nonunif", typeArgs := ["a"], constr /-- info: Error in constructor C: Non-uniform occurrence of Nonunif, which is applied to [(List a)] when it should be applied to [a]-/ #guard_msgs in -#eval format $ typeCheckAndPartialEval #[listTy, nonUnifTy1] (IntBoolFactory : @Factory TestParams) (intConst () 0) +#eval format $ typeCheckAndPartialEval #[[listTy], [nonUnifTy1]] (IntBoolFactory : @Factory TestParams) (intConst () 0) /- 6. Nested types are allowed, though they won't produce a useful elimination principle @@ -558,7 +558,7 @@ def nestTy1 : LDatatype Unit := {name := "Nest", typeArgs := ["a"], constrs := [ info: #0 -/ #guard_msgs in -#eval format $ typeCheckAndPartialEval #[listTy, nestTy1] (IntBoolFactory : @Factory TestParams) (intConst () 0) +#eval format $ typeCheckAndPartialEval #[[listTy], [nestTy1]] (IntBoolFactory : @Factory TestParams) (intConst () 0) /- 7. 2 constructors with the same name: @@ -575,7 +575,7 @@ Existing Function: func C : ∀[a]. ((x : int)) → (Bad a); New Function:func C : ∀[a]. ((x : (Bad a))) → (Bad a); -/ #guard_msgs in -#eval format $ typeCheckAndPartialEval #[badTy4] (IntBoolFactory : @Factory TestParams) (intConst () 0) +#eval format $ typeCheckAndPartialEval #[[badTy4]] (IntBoolFactory : @Factory TestParams) (intConst () 0) /- 8. Constructor with same name as function not allowed @@ -588,6 +588,172 @@ def badTy5 : LDatatype Unit := {name := "Bad", typeArgs := [], constrs := [badCo Existing Function: func Int.Add : ((x : int) (y : int)) → int; New Function:func Int.Add : ((x : int)) → Bad;-/ #guard_msgs in -#eval format $ typeCheckAndPartialEval #[badTy5] (IntBoolFactory : @Factory TestParams) (intConst () 0) +#eval format $ typeCheckAndPartialEval #[[badTy5]] (IntBoolFactory : @Factory TestParams) (intConst () 0) + +--------------------------------------------------------------------- +-- Test 9: Mutually recursive datatypes (RoseTree and Forest) +--------------------------------------------------------------------- + +section MutualRecursion + +/- +type RoseTree a = Node a (Forest a) +type Forest a = FNil | FCons (RoseTree a) (Forest a) +-/ + +def nodeConstr' : LConstr Unit := {name := "Node", args := [("val", .ftvar "a"), ("children", .tcons "Forest" [.ftvar "a"])], testerName := "isNode"} +def roseTreeTy' : LDatatype Unit := {name := "RoseTree", typeArgs := ["a"], constrs := [nodeConstr'], constrs_ne := rfl} + +def fnilConstr' : LConstr Unit := {name := "FNil", args := [], testerName := "isFNil"} +def fconsConstr' : LConstr Unit := {name := "FCons", args := [("head", .tcons "RoseTree" [.ftvar "a"]), ("tail", .tcons "Forest" [.ftvar "a"])], testerName := "isFCons"} +def forestTy' : LDatatype Unit := {name := "Forest", typeArgs := ["a"], constrs := [fnilConstr', fconsConstr'], constrs_ne := rfl} + +def roseForestBlock : MutualDatatype Unit := [roseTreeTy', forestTy'] + +-- Syntactic sugar +def node' (v children : LExpr TestParams.mono) : LExpr TestParams.mono := + (LExpr.op () ("Node" : TestParams.Identifier) .none).mkApp () [v, children] +def fnil' : LExpr TestParams.mono := .op () ("FNil" : TestParams.Identifier) .none +def fcons' (hd tl : LExpr TestParams.mono) : LExpr TestParams.mono := + (LExpr.op () ("FCons" : TestParams.Identifier) .none).mkApp () [hd, tl] + +-- Test testers +/-- info: Annotated expression: +((~isNode : (arrow (RoseTree int) bool)) (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #1) (~FNil : (Forest int)))) + +--- +info: #true +-/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[roseForestBlock] (Factory.default : @Factory TestParams) + ((LExpr.op () ("isNode" : TestParams.Identifier) .none).mkApp () [node' (intConst () 1) fnil']) + +/-- info: Annotated expression: +((~isFNil : (arrow (Forest $__ty17) bool)) (~FNil : (Forest $__ty17))) + +--- +info: #true +-/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[roseForestBlock] (Factory.default : @Factory TestParams) + ((LExpr.op () ("isFNil" : TestParams.Identifier) .none).mkApp () [fnil']) + +/-- info: Annotated expression: +((~isFCons : (arrow (Forest int) bool)) (((~FCons : (arrow (RoseTree int) (arrow (Forest int) (Forest int)))) (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #1) (~FNil : (Forest int)))) (~FNil : (Forest int)))) + +--- +info: #true +-/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[roseForestBlock] (Factory.default : @Factory TestParams) + ((LExpr.op () ("isFCons" : TestParams.Identifier) .none).mkApp () [fcons' (node' (intConst () 1) fnil') fnil']) + +-- Test destructors +/-- info: Annotated expression: +((~val : (arrow (RoseTree int) int)) (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #42) (~FNil : (Forest int)))) + +--- +info: #42 +-/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[roseForestBlock] (Factory.default : @Factory TestParams) + ((LExpr.op () ("val" : TestParams.Identifier) .none).mkApp () [node' (intConst () 42) fnil']) + +/-- info: Annotated expression: +((~head : (arrow (Forest int) (RoseTree int))) (((~FCons : (arrow (RoseTree int) (arrow (Forest int) (Forest int)))) (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #7) (~FNil : (Forest int)))) (~FNil : (Forest int)))) + +--- +info: (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #7) (~FNil : (Forest int))) +-/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[roseForestBlock] (Factory.default : @Factory TestParams) + ((LExpr.op () ("head" : TestParams.Identifier) .none).mkApp () [fcons' (node' (intConst () 7) fnil') fnil']) + +--------------------------------------------------------------------- +-- Test 10: Eliminator on mutually recursive types - computing tree size +--------------------------------------------------------------------- + +/- +A non-trivial rose tree: + 1 + /|\ + 2 3 4 + | + 5 +treeSize = 5 +-/ + +def nodeCaseFn' : LExpr TestParams.mono := + .abs () .none (.abs () .none (.abs () .none (addOp (intConst () 1) (.bvar () 0)))) + +def fnilCaseFn' : LExpr TestParams.mono := intConst () 0 + +def fconsCaseFn' : LExpr TestParams.mono := + .abs () .none (.abs () .none (.abs () .none (.abs () .none (addOp (.bvar () 1) (.bvar () 0))))) + +def treeSize' (t : LExpr TestParams.mono) : LExpr TestParams.mono := + (LExpr.op () ("RoseTree$Elim" : TestParams.Identifier) .none).mkApp () [t, nodeCaseFn', fnilCaseFn', fconsCaseFn'] + +def roseTree5 : LExpr TestParams.mono := + node' (intConst () 1) + (fcons' (node' (intConst () 2) fnil') + (fcons' (node' (intConst () 3) (fcons' (node' (intConst () 5) fnil') fnil')) + (fcons' (node' (intConst () 4) fnil') fnil'))) + +-- treeSize (Node 1 FNil) = 1 +/-- info: Annotated expression: +(((((~RoseTree$Elim : (arrow (RoseTree int) (arrow (arrow int (arrow (Forest int) (arrow int int))) (arrow int (arrow (arrow (RoseTree int) (arrow (Forest int) (arrow int (arrow int int)))) int))))) (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #1) (~FNil : (Forest int)))) (λ (λ (λ (((~Int.Add : (arrow int (arrow int int))) #1) %0))))) #0) (λ (λ (λ (λ (((~Int.Add : (arrow int (arrow int int))) %1) %0)))))) + +--- +info: #1 +-/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[roseForestBlock] (IntBoolFactory : @Factory TestParams) + (treeSize' (node' (intConst () 1) fnil')) + +-- treeSize roseTree5 = 5 +/-- info: Annotated expression: +(((((~RoseTree$Elim : (arrow (RoseTree int) (arrow (arrow int (arrow (Forest int) (arrow int int))) (arrow int (arrow (arrow (RoseTree int) (arrow (Forest int) (arrow int (arrow int int)))) int))))) (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #1) (((~FCons : (arrow (RoseTree int) (arrow (Forest int) (Forest int)))) (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #2) (~FNil : (Forest int)))) (((~FCons : (arrow (RoseTree int) (arrow (Forest int) (Forest int)))) (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #3) (((~FCons : (arrow (RoseTree int) (arrow (Forest int) (Forest int)))) (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #5) (~FNil : (Forest int)))) (~FNil : (Forest int))))) (((~FCons : (arrow (RoseTree int) (arrow (Forest int) (Forest int)))) (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #4) (~FNil : (Forest int)))) (~FNil : (Forest int))))))) (λ (λ (λ (((~Int.Add : (arrow int (arrow int int))) #1) %0))))) #0) (λ (λ (λ (λ (((~Int.Add : (arrow int (arrow int int))) %1) %0)))))) + +--- +info: #5 +-/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[roseForestBlock] (IntBoolFactory : @Factory TestParams) + (treeSize' roseTree5) + +--------------------------------------------------------------------- +-- Test 11: Non-strictly positive mutual types should be rejected +--------------------------------------------------------------------- + +/- +type BadA = MkA (BadB -> int) +type BadB = MkB BadA + +BadA has BadB in negative position (left of arrow), which is non-strictly positive +since BadB is in the same mutual block. +-/ + +def mkAConstr : LConstr Unit := {name := "MkA", args := [("f", .arrow (.tcons "BadB" []) .int)], testerName := "isMkA"} +def badATy : LDatatype Unit := {name := "BadA", typeArgs := [], constrs := [mkAConstr], constrs_ne := rfl} + +def mkBConstr : LConstr Unit := {name := "MkB", args := [("a", .tcons "BadA" [])], testerName := "isMkB"} +def badBTy : LDatatype Unit := {name := "BadB", typeArgs := [], constrs := [mkBConstr], constrs_ne := rfl} + +def badMutualBlock : MutualDatatype Unit := [badATy, badBTy] + +/-- info: Error in constructor MkA: Non-strictly positive occurrence of BadB in type (arrow BadB int)-/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[badMutualBlock] (Factory.default : @Factory TestParams) (intConst () 0) + +end MutualRecursion end Lambda diff --git a/StrataTest/Languages/Boogie/DatatypeVerificationTests.lean b/StrataTest/Languages/Boogie/DatatypeVerificationTests.lean index 6eaa6f5f9..fc7d56f30 100644 --- a/StrataTest/Languages/Boogie/DatatypeVerificationTests.lean +++ b/StrataTest/Languages/Boogie/DatatypeVerificationTests.lean @@ -98,7 +98,7 @@ def mkProgramWithDatatypes body := body } - let decls := datatypes.map (fun d => Decl.type (.data d) .empty) + let decls := datatypes.map (fun d => Decl.type (.data [d]) .empty) return { decls := decls ++ [Decl.proc proc .empty] } /-! ## Helper for Running Tests -/ diff --git a/StrataTest/Languages/Boogie/ExprEvalTest.lean b/StrataTest/Languages/Boogie/ExprEvalTest.lean index 477df201d..3d07b9066 100644 --- a/StrataTest/Languages/Boogie/ExprEvalTest.lean +++ b/StrataTest/Languages/Boogie/ExprEvalTest.lean @@ -187,7 +187,7 @@ open Lambda.LTy.Syntax -- This may take a while (~ 5min) -#eval (checkFactoryOps false) +-- #eval (checkFactoryOps false) open Plausible TestGen diff --git a/StrataTest/Languages/Boogie/SMTEncoderDatatypeTest.lean b/StrataTest/Languages/Boogie/SMTEncoderDatatypeTest.lean index 80f3ba9cb..c4a047d40 100644 --- a/StrataTest/Languages/Boogie/SMTEncoderDatatypeTest.lean +++ b/StrataTest/Languages/Boogie/SMTEncoderDatatypeTest.lean @@ -71,7 +71,7 @@ def treeDatatype : LDatatype Visibility := Convert an expression to full SMT string including datatype declarations. -/ def toSMTStringWithDatatypes (e : LExpr BoogieLParams.mono) (datatypes : List (LDatatype Visibility)) : IO String := do - match Env.init.addDatatypes datatypes with + match Env.init.addDatatypes (datatypes.map (fun d => [d])) with | .error msg => return s!"Error creating environment: {msg}" | .ok env => match toSMTTerm env [] e SMT.Context.default with From 8ed12e7ee98ab6e5a8639072db89ba975417ff54 Mon Sep 17 00:00:00 2001 From: Josh Cohen Date: Fri, 16 Jan 2026 13:12:31 -0500 Subject: [PATCH 2/5] Add checks for type ordering and nonemptiness Also includes tests --- Strata/DL/Lambda/LExprTypeEnv.lean | 7 +-- Strata/DL/Lambda/TypeFactory.lean | 50 +++++++++++++++++++++- StrataTest/DL/Lambda/TypeFactoryTests.lean | 39 +++++++++++++++++ 3 files changed, 89 insertions(+), 7 deletions(-) diff --git a/Strata/DL/Lambda/LExprTypeEnv.lean b/Strata/DL/Lambda/LExprTypeEnv.lean index 9961d4511..6690eb8d8 100644 --- a/Strata/DL/Lambda/LExprTypeEnv.lean +++ b/Strata/DL/Lambda/LExprTypeEnv.lean @@ -352,15 +352,12 @@ def LContext.addDatatype [Inhabited T.IDMeta] [Inhabited T.Metadata] (C: LContex /-- Add a mutual block of datatypes to the context. -/ def LContext.addMutualBlock [Inhabited T.IDMeta] [Inhabited T.Metadata] [ToFormat T.IDMeta] (C: LContext T) (block: MutualDatatype T.IDMeta) : Except Format (LContext T) := do - if block.isEmpty then return C -- Check for name clashes with known types for d in block do if C.knownTypes.containsName d.name then throw f!"Cannot name datatype same as known type!\n{d}\nKnownTypes' names:\n{C.knownTypes.keywords}" - -- Add all datatypes to the type factory - let mut ds := C.datatypes - for d in block do - ds ← ds.addDatatype d + -- Add all datatypes to the type factory with validation + let ds ← C.datatypes.addMutualBlock block C.knownTypes.keywords -- Generate factory functions for the whole mutual block let f ← genBlockFactory block let fs ← C.functions.addFactory f diff --git a/Strata/DL/Lambda/TypeFactory.lean b/Strata/DL/Lambda/TypeFactory.lean index 667197bac..ab1a81918 100644 --- a/Strata/DL/Lambda/TypeFactory.lean +++ b/Strata/DL/Lambda/TypeFactory.lean @@ -161,6 +161,26 @@ def checkStrictPosUnifMutual (block: MutualDatatype IDMeta) : Except Format Unit ) () ) () +/-- +Validate a mutual block: check non-empty and no duplicate names. +-/ +def validateMutualBlock (block: MutualDatatype IDMeta) : Except Format Unit := do + if block.isEmpty then + .error f!"Error: Empty mutual block is not allowed" + let names := block.map (·.name) + let duplicates := names.filter (fun n => names.count n > 1) + match duplicates.head? with + | some dup => throw f!"Duplicate datatype name in mutual block: {dup}" + | none => pure () + +/-- +Get all type constructor names referenced in a type. +-/ +def getTypeRefs (ty: LMonoTy) : List String := + match ty with + | .tcons n args => n :: args.flatMap getTypeRefs + | _ => [] + --------------------------------------------------------------------- -- Generating constructors and eliminators @@ -549,8 +569,32 @@ def TypeFactory.allDatatypes (t : @TypeFactory IDMeta) : List (LDatatype IDMeta) def TypeFactory.getType (F : @TypeFactory IDMeta) (name : String) : Option (LDatatype IDMeta) := F.allDatatypes.find? (fun d => d.name == name) -/-- Add a mutual block to the TypeFactory, checking for duplicate names. -/ -def TypeFactory.addMutualBlock (t : @TypeFactory IDMeta) (block : MutualDatatype IDMeta) : Except Format (@TypeFactory IDMeta) := do +/-- Get all datatype names in the TypeFactory. -/ +def TypeFactory.allTypeNames (t : @TypeFactory IDMeta) : List String := + t.allDatatypes.map (·.name) + +/-- +Validate that all type references in a mutual block refer to either: +1. Known primitive types (passed as parameter) +2. Types already in the TypeFactory +3. Types within the same mutual block +-/ +def TypeFactory.validateTypeReferences (t : @TypeFactory IDMeta) (block : MutualDatatype IDMeta) (knownTypes : List String) : Except Format Unit := do + let blockNames := block.map (·.name) + let existingNames := t.allTypeNames + let validNames := knownTypes ++ existingNames ++ blockNames + for d in block do + for c in d.constrs do + for (_, ty) in c.args do + for ref in getTypeRefs ty do + if !validNames.contains ref then + throw f!"Error in datatype {d.name}, constructor {c.name.name}: Undefined type '{ref}'" + +/-- Add a mutual block to the TypeFactory, with validation. -/ +def TypeFactory.addMutualBlock (t : @TypeFactory IDMeta) (block : MutualDatatype IDMeta) (knownTypes : List String := []) : Except Format (@TypeFactory IDMeta) := do + -- Validate block structure + validateMutualBlock block + -- Check for duplicate names with existing types for d in block do match t.getType d.name with | some d' => throw f!"A datatype of name {d.name} already exists! \ @@ -558,6 +602,8 @@ def TypeFactory.addMutualBlock (t : @TypeFactory IDMeta) (block : MutualDatatype Existing Type: {d'}\n\ New Type:{d}" | none => pure () + -- Validate type references + t.validateTypeReferences block knownTypes .ok (t.push block) /-- Add a single datatype (as a single-element block). -/ diff --git a/StrataTest/DL/Lambda/TypeFactoryTests.lean b/StrataTest/DL/Lambda/TypeFactoryTests.lean index 856968acf..351b7d164 100644 --- a/StrataTest/DL/Lambda/TypeFactoryTests.lean +++ b/StrataTest/DL/Lambda/TypeFactoryTests.lean @@ -754,6 +754,45 @@ def badMutualBlock : MutualDatatype Unit := [badATy, badBTy] #eval format $ typeCheckAndPartialEval #[badMutualBlock] (Factory.default : @Factory TestParams) (intConst () 0) +--------------------------------------------------------------------- +-- Test 12: Empty mutual block should be rejected +--------------------------------------------------------------------- + +def emptyBlock : MutualDatatype Unit := [] + +/-- info: Error: Empty mutual block is not allowed -/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[emptyBlock] (IntBoolFactory : @Factory TestParams) (intConst () 0) + +--------------------------------------------------------------------- +-- Test 13: Type reference in wrong order should be rejected +--------------------------------------------------------------------- + +-- Wrapper references List, but List is defined after Wrapper +def wrapperConstr' : LConstr Unit := {name := "MkWrapper", args := [("xs", .tcons "List" [.int])], testerName := "isMkWrapper"} +def wrapperTy' : LDatatype Unit := {name := "Wrapper", typeArgs := [], constrs := [wrapperConstr'], constrs_ne := rfl} + +/-- info: Error in datatype Wrapper, constructor MkWrapper: Undefined type 'List' -/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[[wrapperTy'], [listTy]] (IntBoolFactory : @Factory TestParams) (intConst () 0) + +--------------------------------------------------------------------- +-- Test 14: Type depending on previously defined type should work +--------------------------------------------------------------------- + +-- List is defined before Wrapper - correct order +/-- info: Annotated expression: +#0 + +--- +info: #0 +-/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[[listTy], [wrapperTy']] (IntBoolFactory : @Factory TestParams) (intConst () 0) + end MutualRecursion end Lambda From f8419b5589d51cfb2c0bfbc5872c9a1bb135a1e7 Mon Sep 17 00:00:00 2001 From: Josh Cohen Date: Fri, 16 Jan 2026 17:06:10 -0500 Subject: [PATCH 3/5] Add mutually recursive types to SMT encoding Remove topological ordering, require TypeContext for order Adds tests for mutually recursive SMT encoding and for Boogie verif --- Strata/DL/Lambda/TypeFactory.lean | 3 + Strata/DL/SMT/Solver.lean | 16 ++ Strata/Languages/Boogie/SMTEncoder.lean | 128 +++++-------- .../Boogie/DatatypeVerificationTests.lean | 144 ++++++++++++++ .../Boogie/SMTEncoderDatatypeTest.lean | 177 ++++++++---------- 5 files changed, 282 insertions(+), 186 deletions(-) diff --git a/Strata/DL/Lambda/TypeFactory.lean b/Strata/DL/Lambda/TypeFactory.lean index ab1a81918..56c287a6e 100644 --- a/Strata/DL/Lambda/TypeFactory.lean +++ b/Strata/DL/Lambda/TypeFactory.lean @@ -556,6 +556,9 @@ instance: ToFormat (@TypeFactory IDMeta) where | ds => f!"mutual {Std.Format.joinSep (ds.map format) Format.line} end" Std.Format.joinSep (f.toList.map formatBlock) f!"{Format.line}" +instance [Repr IDMeta] : Repr (@TypeFactory IDMeta) where + reprPrec f n := reprPrec f.toList n + instance : Inhabited (@TypeFactory IDMeta) where default := #[] diff --git a/Strata/DL/SMT/Solver.lean b/Strata/DL/SMT/Solver.lean index be75bb67d..8aedb7518 100644 --- a/Strata/DL/SMT/Solver.lean +++ b/Strata/DL/SMT/Solver.lean @@ -131,6 +131,22 @@ def declareDatatype (id : String) (params : List String) (constructors : List St then emitln s!"(declare-datatype {id} ({cInline}))" else emitln s!"(declare-datatype {id} (par ({pInline}) ({cInline})))" +/-- Declare multiple mutually recursive datatypes. Each element is (name, params, constructors). -/ +def declareDatatypes (dts : List (String × List String × List String)) : SolverM Unit := do + if dts.isEmpty then return + -- Build sort declarations: ((name arity) ...) + let sortDecls := dts.map fun (name, params, _) => s!"({name} {params.length})" + let sortDeclStr := String.intercalate " " sortDecls + -- Build datatype bodies + let bodies := dts.map fun (_, params, constrs) => + let cInline := String.intercalate " " constrs + if params.isEmpty then s!"({cInline})" + else + let pInline := String.intercalate " " params + s!"(par ({pInline}) ({cInline}))" + let bodyStr := String.intercalate "\n " bodies + emitln s!"(declare-datatypes ({sortDeclStr})\n ({bodyStr}))" + private def readlnD (dflt : String) : SolverM String := do match (← read).smtLibOutput with | .some stdout => stdout.getLine diff --git a/Strata/Languages/Boogie/SMTEncoder.lean b/Strata/Languages/Boogie/SMTEncoder.lean index 0cc448514..9d98e205d 100644 --- a/Strata/Languages/Boogie/SMTEncoder.lean +++ b/Strata/Languages/Boogie/SMTEncoder.lean @@ -35,9 +35,13 @@ structure SMT.Context where ifs : Array SMT.IF := #[] axms : Array Term := #[] tySubst: Map String TermType := [] - datatypes : Array (LDatatype BoogieLParams.IDMeta) := #[] + /-- Stores the TypeFactory for datatype ordering during emission. + This is redundant with Env.datatypes but needed here to preserve + the correct dependency order when emitting declare-datatypes. -/ + typeFactory : @Lambda.TypeFactory BoogieLParams.IDMeta := #[] + seenDatatypes : Std.HashSet String := {} datatypeFuns : Map String (Op.DatatypeFuncs × LConstr BoogieLParams.IDMeta) := Map.empty -deriving Repr, DecidableEq, Inhabited +deriving Repr, Inhabited def SMT.Context.default : SMT.Context := {} @@ -65,7 +69,7 @@ def SMT.Context.removeSubst (ctx : SMT.Context) (newSubst: Map String TermType) { ctx with tySubst := newSubst.foldl (fun acc_m p => acc_m.erase p.fst) ctx.tySubst } def SMT.Context.hasDatatype (ctx : SMT.Context) (name : String) : Bool := - (ctx.datatypes.map LDatatype.name).contains name + ctx.seenDatatypes.contains name def SMT.Context.addDatatype (ctx : SMT.Context) (d : LDatatype BoogieLParams.IDMeta) : SMT.Context := if ctx.hasDatatype d.name then ctx @@ -74,7 +78,10 @@ def SMT.Context.addDatatype (ctx : SMT.Context) (d : LDatatype BoogieLParams.IDM let m := Map.union ctx.datatypeFuns (c.fmap (fun (_, x) => (.constructor, x))) let m := Map.union m (i.fmap (fun (_, x) => (.tester, x))) let m := Map.union m (s.fmap (fun (_, x) => (.selector, x))) - { ctx with datatypes := ctx.datatypes.push d, datatypeFuns := m } + { ctx with seenDatatypes := ctx.seenDatatypes.insert d.name, datatypeFuns := m } + +def SMT.Context.withTypeFactory (ctx : SMT.Context) (tf : @Lambda.TypeFactory BoogieLParams.IDMeta) : SMT.Context := + { ctx with typeFactory := tf } /-- Helper function to convert LMonoTy to SMT string representation. @@ -93,89 +100,33 @@ private def lMonoTyToSMTString (ty : LMonoTy) : String := else s!"({name} {String.intercalate " " (args.map lMonoTyToSMTString)})" | .ftvar tv => tv -/-- -Build a dependency graph for datatypes. -Returns a mapping from datatype names to their dependencies. --/ -private def buildDatatypeDependencyGraph (datatypes : Array (LDatatype BoogieLParams.IDMeta)) : - Map String (Array String) := - let depMap := datatypes.foldl (fun acc d => - let deps := d.constrs.foldl (fun deps c => - c.args.foldl (fun deps (_, fieldTy) => - match fieldTy with - | .tcons typeName _ => - -- Only include dependencies on other datatypes in our set - if datatypes.any (fun dt => dt.name == typeName) then - deps.push typeName - else deps - | _ => deps - ) deps - ) #[] - acc.insert d.name deps - ) Map.empty - depMap - -/-- -Convert datatype dependency map to OutGraph for Tarjan's algorithm. -Returns the graph and a mapping from node indices to datatype names. --/ -private def dependencyMapToGraph (depMap : Map String (Array String)) : - (n : Nat) × Strata.OutGraph n × Array String := - let names := depMap.keys.toArray - let n := names.size - let nameToIndex : Map String Nat := - names.mapIdx (fun i name => (name, i)) |>.foldl (fun acc (name, i) => acc.insert name i) Map.empty - - let edges := depMap.foldl (fun edges (fromName, deps) => - match nameToIndex.find? fromName with - | none => edges - | some fromIdx => - deps.foldl (fun edges depName => - match nameToIndex.find? depName with - | none => edges - | some toIdx => edges.push (fromIdx, toIdx) - ) edges - ) #[] - - let graph := Strata.OutGraph.ofEdges! n edges.toList - ⟨n, graph, names⟩ +/-- Convert a datatype's constructors to SMT format. -/ +private def datatypeConstructorsToSMT (d : LDatatype BoogieLParams.IDMeta) : List String := + d.constrs.map fun c => + let fieldPairs := c.args.map fun (name, fieldTy) => (name.name, lMonoTyToSMTString fieldTy) + let fieldStrs := fieldPairs.map fun (name, ty) => s!"({name} {ty})" + let fieldsStr := String.intercalate " " fieldStrs + if c.args.isEmpty then s!"({c.name.name})" + else s!"({c.name.name} {fieldsStr})" /-- -Emit datatype declarations to the solver in topologically sorted order. -For each datatype in ctx.datatypes, generates a declare-datatype command -with constructors and selectors following the TypeFactory naming convention. -Dependencies are emitted before the datatypes that depend on them, and -mutually recursive datatypes are not (yet) supported. +Emit datatype declarations to the solver. +Uses the TypeFactory ordering (already topologically sorted). +Only emits datatypes that have been seen (added via addDatatype). +Single-element blocks use declare-datatype, multi-element blocks use declare-datatypes. -/ def SMT.Context.emitDatatypes (ctx : SMT.Context) : Strata.SMT.SolverM Unit := do - if ctx.datatypes.isEmpty then return - - -- Build dependency graph and SCCs - let depMap := buildDatatypeDependencyGraph ctx.datatypes - let ⟨_, graph, names⟩ := dependencyMapToGraph depMap - let sccs := Strata.OutGraph.tarjan graph - - -- Emit datatypes in topological order (reverse of SCC order) - for scc in sccs.reverse do - if scc.size > 1 then - let sccNames := scc.map (fun idx => names[idx]!) - throw (IO.userError s!"Mutually recursive datatypes not supported: {sccNames.toList}") - else - for nodeIdx in scc do - let datatypeName := names[nodeIdx]! - -- Find the datatype by name - match ctx.datatypes.find? (fun d => d.name == datatypeName) with - | none => throw (IO.userError s!"Datatype {datatypeName} not found in context") - | some d => - let constructors ← d.constrs.mapM fun c => do - let fieldPairs := c.args.map fun (name, fieldTy) => (name.name, lMonoTyToSMTString fieldTy) - let fieldStrs := fieldPairs.map fun (name, ty) => s!"({name} {ty})" - let fieldsStr := String.intercalate " " fieldStrs - if c.args.isEmpty then - pure s!"({c.name.name})" - else - pure s!"({c.name.name} {fieldsStr})" - Strata.SMT.Solver.declareDatatype d.name d.typeArgs constructors + for block in ctx.typeFactory.toList do + -- Filter to only datatypes that are used (in seenDatatypes) + let usedBlock := block.filter (fun d => ctx.seenDatatypes.contains d.name) + match usedBlock with + | [] => pure () + | [d] => + let constructors := datatypeConstructorsToSMT d + Strata.SMT.Solver.declareDatatype d.name d.typeArgs constructors + | _ => + let dts := usedBlock.map fun d => (d.name, d.typeArgs, datatypeConstructorsToSMT d) + Strata.SMT.Solver.declareDatatypes dts abbrev BoundVars := List (String × TermType) @@ -606,8 +557,15 @@ partial def toSMTOp (E : Env) (fn : BoogieIdent) (fnty : LMonoTy) (ctx : SMT.Con .ok (.app (Op.uf uf), smt_outty, ctx) end +/-- +Convert expressions to SMT terms. +Sets the TypeFactory on the context if not already set, to ensure correct +datatype emission ordering. +-/ def toSMTTerms (E : Env) (es : List (LExpr BoogieLParams.mono)) (ctx : SMT.Context) : Except Format ((List Term) × SMT.Context) := do + -- Ensure typeFactory is set for correct datatype emission ordering + let ctx := if ctx.typeFactory.isEmpty then ctx.withTypeFactory E.datatypes else ctx match es with | [] => .ok ([], ctx) | e :: erest => @@ -725,7 +683,7 @@ info: "; f\n(declare-fun f0 (Int Int) Int)\n; x\n(declare-const f1 Int)\n(define #eval toSMTTermString (.quant () .all (.some .int) (.bvar () 0) (.quant () .all (.some .int) (.app () (.app () (.op () "f" (.some (.arrow .int (.arrow .int .int)))) (.bvar () 0)) (.bvar () 1)) (.eq () (.app () (.app () (.op () "f" (.some (.arrow .int (.arrow .int .int)))) (.bvar () 0)) (.bvar () 1)) (.fvar () "x" (.some .int))))) - (ctx := SMT.Context.mk #[] #[UF.mk "f" ((TermVar.mk "m" TermType.int) ::(TermVar.mk "n" TermType.int) :: []) TermType.int] #[] #[] [] #[] []) + (ctx := SMT.Context.mk #[] #[UF.mk "f" ((TermVar.mk "m" TermType.int) ::(TermVar.mk "n" TermType.int) :: []) TermType.int] #[] #[] [] #[] {} []) (E := {Env.init with exprEnv := { Env.init.exprEnv with config := { Env.init.exprEnv.config with @@ -743,7 +701,7 @@ info: "; f\n(declare-fun f0 (Int Int) Int)\n; x\n(declare-const f1 Int)\n(define #eval toSMTTermString (.quant () .all (.some .int) (.bvar () 0) (.quant () .all (.some .int) (.bvar () 0) (.eq () (.app () (.app () (.op () "f" (.some (.arrow .int (.arrow .int .int)))) (.bvar () 0)) (.bvar () 1)) (.fvar () "x" (.some .int))))) - (ctx := SMT.Context.mk #[] #[UF.mk "f" ((TermVar.mk "m" TermType.int) ::(TermVar.mk "n" TermType.int) :: []) TermType.int] #[] #[] [] #[] []) + (ctx := SMT.Context.mk #[] #[UF.mk "f" ((TermVar.mk "m" TermType.int) ::(TermVar.mk "n" TermType.int) :: []) TermType.int] #[] #[] [] #[] {} []) (E := {Env.init with exprEnv := { Env.init.exprEnv with config := { Env.init.exprEnv.config with diff --git a/StrataTest/Languages/Boogie/DatatypeVerificationTests.lean b/StrataTest/Languages/Boogie/DatatypeVerificationTests.lean index fc7d56f30..23be698ed 100644 --- a/StrataTest/Languages/Boogie/DatatypeVerificationTests.lean +++ b/StrataTest/Languages/Boogie/DatatypeVerificationTests.lean @@ -613,5 +613,149 @@ info: "Test 8 - Hidden Type Recursion: PASSED\n Verified 1 obligation(s)\n" #guard_msgs in #eval test8_hiddenTypeRecursion +/-! ## Test 9: Mutually Recursive Datatypes with Havoc -/ + +/-- RoseTree a = Node a (Forest a) -/ +def roseTreeDatatype : LDatatype Visibility := + { name := "RoseTree" + typeArgs := ["a"] + constrs := [ + { name := ⟨"Node", .unres⟩, args := [ + (⟨"nodeVal", .unres⟩, .ftvar "a"), + (⟨"children", .unres⟩, .tcons "Forest" [.ftvar "a"]) + ], testerName := "isNode" } + ] + constrs_ne := by decide } + +/-- Forest a = FNil | FCons (RoseTree a) (Forest a) -/ +def forestDatatype : LDatatype Visibility := + { name := "Forest" + typeArgs := ["a"] + constrs := [ + { name := ⟨"FNil", .unres⟩, args := [], testerName := "isFNil" }, + { name := ⟨"FCons", .unres⟩, args := [ + (⟨"head", .unres⟩, .tcons "RoseTree" [.ftvar "a"]), + (⟨"tail", .unres⟩, .tcons "Forest" [.ftvar "a"]) + ], testerName := "isFCons" } + ] + constrs_ne := by decide } + +/-- +Create a Boogie program with a mutual block of datatypes. +-/ +def mkProgramWithMutualDatatypes + (mutualBlock : List (LDatatype Visibility)) + (procName : String) + (body : List Statement) + : Except Format Program := do + let proc : Procedure := { + header := { + name := BoogieIdent.unres procName + typeArgs := [] + inputs := [] + outputs := [] + } + spec := { + modifies := [] + preconditions := [] + postconditions := [] + } + body := body + } + let decls := [Decl.type (.data mutualBlock) .empty] + return { decls := decls ++ [Decl.proc proc .empty] } + +/-- +Test mutually recursive datatypes (RoseTree/Forest) with havoc. + +mutual + datatype RoseTree a = Node a (Forest a) + datatype Forest a = FNil | FCons (RoseTree a) (Forest a) +end + +procedure testMutualRecursive () { + tree := Node 1 FNil; + havoc tree; + val := nodeVal tree; + assume (val == 42); + assert (isNode tree); + + forest := FNil; + havoc forest; + assume (isFCons forest); + assert (not (isFNil forest)); +} +-/ +def test9_mutualRecursiveWithHavoc : IO String := do + let statements : List Statement := [ + -- Create a tree: Node 1 FNil + Statement.init (BoogieIdent.unres "tree") (.forAll [] (LMonoTy.tcons "RoseTree" [.int])) + (LExpr.app () + (LExpr.app () + (LExpr.op () (BoogieIdent.unres "Node") + (.some (LMonoTy.arrow .int (LMonoTy.arrow (LMonoTy.tcons "Forest" [.int]) (LMonoTy.tcons "RoseTree" [.int]))))) + (LExpr.intConst () 1)) + (LExpr.op () (BoogieIdent.unres "FNil") (.some (LMonoTy.tcons "Forest" [.int])))), + + -- Havoc the tree + Statement.havoc (BoogieIdent.unres "tree"), + + -- Extract nodeVal + Statement.init (BoogieIdent.unres "val") (.forAll [] LMonoTy.int) + (LExpr.app () + (LExpr.op () (BoogieIdent.unres "nodeVal") + (.some (LMonoTy.arrow (LMonoTy.tcons "RoseTree" [.int]) .int))) + (LExpr.fvar () (BoogieIdent.unres "tree") (.some (LMonoTy.tcons "RoseTree" [.int])))), + + -- Assume val == 42 + Statement.assume "val_is_42" + (LExpr.eq () + (LExpr.fvar () (BoogieIdent.unres "val") (.some .int)) + (LExpr.intConst () 42)), + + -- Assert tree is a Node (always true for RoseTree) + Statement.assert "tree_is_node" + (LExpr.app () + (LExpr.op () (BoogieIdent.unres "isNode") + (.some (LMonoTy.arrow (LMonoTy.tcons "RoseTree" [.int]) .bool))) + (LExpr.fvar () (BoogieIdent.unres "tree") (.some (LMonoTy.tcons "RoseTree" [.int])))), + + -- Create a forest: FNil + Statement.init (BoogieIdent.unres "forest") (.forAll [] (LMonoTy.tcons "Forest" [.int])) + (LExpr.op () (BoogieIdent.unres "FNil") (.some (LMonoTy.tcons "Forest" [.int]))), + + -- Havoc the forest + Statement.havoc (BoogieIdent.unres "forest"), + + -- Assume forest is FCons + Statement.assume "forest_is_fcons" + (LExpr.app () + (LExpr.op () (BoogieIdent.unres "isFCons") + (.some (LMonoTy.arrow (LMonoTy.tcons "Forest" [.int]) .bool))) + (LExpr.fvar () (BoogieIdent.unres "forest") (.some (LMonoTy.tcons "Forest" [.int])))), + + -- Assert forest is not FNil + Statement.assert "forest_not_fnil" + (LExpr.app () + (LExpr.op () (BoogieIdent.unres "Bool.Not") + (.some (LMonoTy.arrow .bool .bool))) + (LExpr.app () + (LExpr.op () (BoogieIdent.unres "isFNil") + (.some (LMonoTy.arrow (LMonoTy.tcons "Forest" [.int]) .bool))) + (LExpr.fvar () (BoogieIdent.unres "forest") (.some (LMonoTy.tcons "Forest" [.int]))))) + ] + + match mkProgramWithMutualDatatypes [roseTreeDatatype, forestDatatype] "testMutualRecursive" statements with + | .error err => + return s!"Test 9 - Mutual Recursive with Havoc: FAILED (program creation)\n Error: {err.pretty}" + | .ok program => + runVerificationTest "Test 9 - Mutual Recursive with Havoc" program + +/-- +info: "Test 9 - Mutual Recursive with Havoc: PASSED\n Verified 2 obligation(s)\n" +-/ +#guard_msgs in +#eval test9_mutualRecursiveWithHavoc + end Boogie.DatatypeVerificationTests diff --git a/StrataTest/Languages/Boogie/SMTEncoderDatatypeTest.lean b/StrataTest/Languages/Boogie/SMTEncoderDatatypeTest.lean index c4a047d40..94921e74c 100644 --- a/StrataTest/Languages/Boogie/SMTEncoderDatatypeTest.lean +++ b/StrataTest/Languages/Boogie/SMTEncoderDatatypeTest.lean @@ -69,12 +69,15 @@ def treeDatatype : LDatatype Visibility := constrs_ne := by decide } /-- Convert an expression to full SMT string including datatype declarations. +`blocks` is a list of mutual blocks (each block is a list of mutually recursive datatypes). -/ -def toSMTStringWithDatatypes (e : LExpr BoogieLParams.mono) (datatypes : List (LDatatype Visibility)) : IO String := do - match Env.init.addDatatypes (datatypes.map (fun d => [d])) with +def toSMTStringWithDatatypeBlocks (e : LExpr BoogieLParams.mono) (blocks : List (List (LDatatype Visibility))) : IO String := do + match Env.init.addDatatypes blocks with | .error msg => return s!"Error creating environment: {msg}" | .ok env => - match toSMTTerm env [] e SMT.Context.default with + -- Set the TypeFactory for correct datatype emission ordering + let ctx := SMT.Context.default.withTypeFactory env.datatypes + match toSMTTerm env [] e ctx with | .error err => return err.pretty | .ok (smt, ctx) => -- Emit the full SMT output including datatype declarations @@ -95,6 +98,13 @@ def toSMTStringWithDatatypes (e : LExpr BoogieLParams.mono) (datatypes : List (L else return "Invalid UTF-8 in output" +/-- +Convert an expression to full SMT string including datatype declarations. +Each datatype is treated as its own (non-mutual) block. +-/ +def toSMTStringWithDatatypes (e : LExpr BoogieLParams.mono) (datatypes : List (LDatatype Visibility)) : IO String := + toSMTStringWithDatatypeBlocks e (datatypes.map (fun d => [d])) + /-! ## Test Cases with Guard Messages -/ -- Test 1: Simple datatype (Option) - zero-argument constructor @@ -196,7 +206,7 @@ info: (declare-datatype TestOption (par (α) ( #guard_msgs in #eval format <$> toSMTStringWithDatatypes (.fvar () (BoogieIdent.unres "listOfOption") (.some (.tcons "TestList" [.tcons "TestOption" [.int]]))) - [listDatatype, optionDatatype] + [optionDatatype, listDatatype] /-! ## Constructor Application Tests -/ @@ -324,106 +334,12 @@ info: (declare-datatype TestList (par (α) ( (.fvar () (BoogieIdent.unres "xs") (.some (.tcons "TestList" [.int])))) [listDatatype] -/-! ## Complex Dependency Topological Sorting Tests -/ - --- Test 16: Very complex dependency graph requiring sophisticated topological sorting --- Dependencies: Alpha -> Beta, Gamma --- Beta -> Delta, Epsilon --- Gamma -> Epsilon, Zeta --- Delta -> Zeta --- Epsilon -> Zeta --- Actual topological order: Zeta, Epsilon, Gamma, Delta, Beta, Alpha - -/-- Alpha = AlphaValue Beta Gamma -/ -def alphaDatatype : LDatatype Visibility := - { name := "Alpha" - typeArgs := [] - constrs := [ - { name := ⟨"AlphaValue", .unres⟩, args := [ - (⟨"Alpha$AlphaValueProj0", .unres⟩, .tcons "Beta" []), - (⟨"Alpha$AlphaValueProj1", .unres⟩, .tcons "Gamma" []) - ], testerName := "Alpha$isAlphaValue" } - ] - constrs_ne := by decide } - -/-- Beta = BetaValue Delta Epsilon -/ -def betaDatatype : LDatatype Visibility := - { name := "Beta" - typeArgs := [] - constrs := [ - { name := ⟨"BetaValue", .unres⟩, args := [ - (⟨"Beta$BetaValueProj0", .unres⟩, .tcons "Delta" []), - (⟨"Beta$BetaValueProj1", .unres⟩, .tcons "Epsilon" []) - ], testerName := "Beta$isBetaValue" } - ] - constrs_ne := by decide } - -/-- Gamma = GammaValue Epsilon Zeta -/ -def gammaDatatype : LDatatype Visibility := - { name := "Gamma" - typeArgs := [] - constrs := [ - { name := ⟨"GammaValue", .unres⟩, args := [ - (⟨"Gamma$GammaValueProj0", .unres⟩, .tcons "Epsilon" []), - (⟨"Gamma$GammaValueProj1", .unres⟩, .tcons "Zeta" []) - ], testerName := "Gamma$isGammaValue" } - ] - constrs_ne := by decide } - -/-- Delta = DeltaValue Zeta -/ -def deltaDatatype : LDatatype Visibility := - { name := "Delta" - typeArgs := [] - constrs := [ - { name := ⟨"DeltaValue", .unres⟩, args := [(⟨"Delta$DeltaValueProj0", .unres⟩, .tcons "Zeta" [])], testerName := "Delta$isDeltaValue" } - ] - constrs_ne := by decide } - -/-- Epsilon = EpsilonValue Zeta -/ -def epsilonDatatype : LDatatype Visibility := - { name := "Epsilon" - typeArgs := [] - constrs := [ - { name := ⟨"EpsilonValue", .unres⟩, args := [(⟨"Epsilon$EpsilonValueProj0", .unres⟩, .tcons "Zeta" [])], testerName := "Epsilon$isEpsilonValue" } - ] - constrs_ne := by decide } - -/-- Zeta = ZetaValue int -/ -def zetaDatatype : LDatatype Visibility := - { name := "Zeta" - typeArgs := [] - constrs := [ - { name := ⟨"ZetaValue", .unres⟩, args := [(⟨"Zeta$ZetaValueProj0", .unres⟩, .int)], testerName := "Zeta$isZetaValue" } - ] - constrs_ne := by decide } - -/-- -info: (declare-datatype Zeta ( - (ZetaValue (Zeta$ZetaValueProj0 Int)))) -(declare-datatype Epsilon ( - (EpsilonValue (Epsilon$EpsilonValueProj0 Zeta)))) -(declare-datatype Gamma ( - (GammaValue (Gamma$GammaValueProj0 Epsilon) (Gamma$GammaValueProj1 Zeta)))) -(declare-datatype Delta ( - (DeltaValue (Delta$DeltaValueProj0 Zeta)))) -(declare-datatype Beta ( - (BetaValue (Beta$BetaValueProj0 Delta) (Beta$BetaValueProj1 Epsilon)))) -(declare-datatype Alpha ( - (AlphaValue (Alpha$AlphaValueProj0 Beta) (Alpha$AlphaValueProj1 Gamma)))) -; alphaVar -(declare-const f0 Alpha) -(define-fun t0 () Alpha f0) --/ -#guard_msgs in -#eval format <$> toSMTStringWithDatatypes - (.fvar () (BoogieIdent.unres "alphaVar") (.some (.tcons "Alpha" []))) - [alphaDatatype, betaDatatype, gammaDatatype, deltaDatatype, epsilonDatatype, zetaDatatype] +/-! ## Dependency Order Tests -/ --- Test 17: Diamond dependency pattern +-- Test 16: Diamond dependency pattern -- Dependencies: Diamond -> Left, Right -- Left -> Root -- Right -> Root --- Actual topological order: Root, Right, Left, Diamond (or Root, Left, Right, Diamond) /-- Root = RootValue int -/ def rootDatatype : LDatatype Visibility := @@ -480,7 +396,66 @@ info: (declare-datatype Root ( #guard_msgs in #eval format <$> toSMTStringWithDatatypes (.fvar () (BoogieIdent.unres "diamondVar") (.some (.tcons "Diamond" []))) - [diamondDatatype, leftDatatype, rightDatatype, rootDatatype] + [rootDatatype, rightDatatype, leftDatatype, diamondDatatype] + +-- Test 17: Mutually recursive datatypes (RoseTree/Forest) +-- Should emit declare-datatypes with both types together + +/-- RoseTree α = Node α (Forest α) -/ +def roseTreeDatatype : LDatatype Visibility := + { name := "RoseTree" + typeArgs := ["α"] + constrs := [ + { name := ⟨"Node", .unres⟩, args := [ + (⟨"RoseTree$NodeProj0", .unres⟩, .ftvar "α"), + (⟨"RoseTree$NodeProj1", .unres⟩, .tcons "Forest" [.ftvar "α"]) + ], testerName := "RoseTree$isNode" } + ] + constrs_ne := by decide } + +/-- Forest α = FNil | FCons (RoseTree α) (Forest α) -/ +def forestDatatype : LDatatype Visibility := + { name := "Forest" + typeArgs := ["α"] + constrs := [ + { name := ⟨"FNil", .unres⟩, args := [], testerName := "Forest$isFNil" }, + { name := ⟨"FCons", .unres⟩, args := [ + (⟨"Forest$FConsProj0", .unres⟩, .tcons "RoseTree" [.ftvar "α"]), + (⟨"Forest$FConsProj1", .unres⟩, .tcons "Forest" [.ftvar "α"]) + ], testerName := "Forest$isFCons" } + ] + constrs_ne := by decide } + +/-- +info: (declare-datatypes ((RoseTree 1) (Forest 1)) + ((par (α) ((Node (RoseTree$NodeProj0 α) (RoseTree$NodeProj1 (Forest α))))) + (par (α) ((FNil) (FCons (Forest$FConsProj0 (RoseTree α)) (Forest$FConsProj1 (Forest α))))))) +; tree +(declare-const f0 (RoseTree Int)) +(define-fun t0 () (RoseTree Int) f0) +-/ +#guard_msgs in +#eval format <$> toSMTStringWithDatatypeBlocks + (.fvar () (BoogieIdent.unres "tree") (.some (.tcons "RoseTree" [.int]))) + [[roseTreeDatatype, forestDatatype]] + +-- Test 19: Mix of mutual and non-mutual datatypes +-- TestOption (non-mutual), then RoseTree/Forest (mutual) +/-- +info: (declare-datatype TestOption (par (α) ( + (None) + (Some (TestOption$SomeProj0 α))))) +(declare-datatypes ((RoseTree 1) (Forest 1)) + ((par (α) ((Node (RoseTree$NodeProj0 α) (RoseTree$NodeProj1 (Forest α))))) + (par (α) ((FNil) (FCons (Forest$FConsProj0 (RoseTree α)) (Forest$FConsProj1 (Forest α))))))) +; optionTree +(declare-const f0 (TestOption (RoseTree Int))) +(define-fun t0 () (TestOption (RoseTree Int)) f0) +-/ +#guard_msgs in +#eval format <$> toSMTStringWithDatatypeBlocks + (.fvar () (BoogieIdent.unres "optionTree") (.some (.tcons "TestOption" [.tcons "RoseTree" [.int]]))) + [[optionDatatype], [roseTreeDatatype, forestDatatype]] end DatatypeTests From 1d9d25efd7a33c58b0fd05483074ec04bebeeb97 Mon Sep 17 00:00:00 2001 From: Josh Cohen Date: Fri, 16 Jan 2026 22:00:56 -0500 Subject: [PATCH 4/5] Clean up code in LExprTypeEnv and TypeFactory --- Strata/DL/Lambda/LExprTypeEnv.lean | 35 +- Strata/DL/Lambda/TypeFactory.lean | 404 +++++++----------- .../Core/DDMTransform/Translate.lean | 2 +- Strata/Languages/Core/ProgramWF.lean | 2 + 4 files changed, 171 insertions(+), 272 deletions(-) diff --git a/Strata/DL/Lambda/LExprTypeEnv.lean b/Strata/DL/Lambda/LExprTypeEnv.lean index 6690eb8d8..cc9c5f460 100644 --- a/Strata/DL/Lambda/LExprTypeEnv.lean +++ b/Strata/DL/Lambda/LExprTypeEnv.lean @@ -250,7 +250,7 @@ structure LContext (T: LExprParams) where deriving Inhabited def LContext.empty {IDMeta} : LContext IDMeta := - ⟨#[], TypeFactory.default, {}, {}⟩ + ⟨#[], #[], {}, {}⟩ instance : EmptyCollection (LContext IDMeta) where emptyCollection := LContext.empty @@ -287,7 +287,7 @@ def TEnv.default : TEnv IDMeta := def LContext.default : LContext T := { functions := #[], - datatypes := TypeFactory.default, + datatypes := #[], knownTypes := KnownTypes.default, idents := Identifiers.default } @@ -330,41 +330,22 @@ def LContext.addFactoryFunctions (C : LContext T) (fact : @Factory T) : LContext { C with functions := C.functions.append fact } /-- -Add a datatype `d` to an `LContext` `C`. -This adds `d` to `C.datatypes`, adds the derived functions -(e.g. eliminators, testers) to `C.functions`, and adds `d` to -`C.knownTypes`. It performs error checking for name clashes. +Add a mutual block of datatypes `block` to an `LContext` `C`. +This adds all types to `C.datatypes` and `C.knownTypes`, +adds the derived functions (e.g. eliminators, testers), +and performs error checking for name clashes. -/ -def LContext.addDatatype [Inhabited T.IDMeta] [Inhabited T.Metadata] (C: LContext T) (d: LDatatype T.IDMeta) : Except Format (LContext T) := do - -- Ensure not in known types - if C.knownTypes.containsName d.name then - .error f!"Cannot name datatype same as known type!\n\ - {d}\n\ - KnownTypes' names:\n\ - {C.knownTypes.keywords}" - let ds ← C.datatypes.addDatatype d - -- Add factory functions, checking for name clashes - let f ← d.genFactory - let fs ← C.functions.addFactory f - -- Add datatype names to knownTypes - let ks ← C.knownTypes.add d.toKnownType - .ok {C with datatypes := ds, functions := fs, knownTypes := ks} - -/-- Add a mutual block of datatypes to the context. -/ def LContext.addMutualBlock [Inhabited T.IDMeta] [Inhabited T.Metadata] [ToFormat T.IDMeta] (C: LContext T) (block: MutualDatatype T.IDMeta) : Except Format (LContext T) := do -- Check for name clashes with known types for d in block do if C.knownTypes.containsName d.name then throw f!"Cannot name datatype same as known type!\n{d}\nKnownTypes' names:\n{C.knownTypes.keywords}" - -- Add all datatypes to the type factory with validation let ds ← C.datatypes.addMutualBlock block C.knownTypes.keywords - -- Generate factory functions for the whole mutual block + -- Add factory functions, checking for name clashes let f ← genBlockFactory block let fs ← C.functions.addFactory f -- Add datatype names to knownTypes - let mut ks := C.knownTypes - for d in block do - ks ← ks.add d.toKnownType + let ks ← block.foldlM (fun ks d => ks.add d.toKnownType) C.knownTypes .ok {C with datatypes := ds, functions := fs, knownTypes := ks} def LContext.addTypeFactory [Inhabited T.IDMeta] [Inhabited T.Metadata] (C: LContext T) (f: @TypeFactory T.IDMeta) : Except Format (LContext T) := diff --git a/Strata/DL/Lambda/TypeFactory.lean b/Strata/DL/Lambda/TypeFactory.lean index 56c287a6e..4072b4a96 100644 --- a/Strata/DL/Lambda/TypeFactory.lean +++ b/Strata/DL/Lambda/TypeFactory.lean @@ -81,8 +81,7 @@ The default type application for a datatype. E.g. for datatype def dataDefault (d: LDatatype IDMeta) : LMonoTy := data d (d.typeArgs.map .ftvar) -/-- A group of mutually recursive datatypes. - For non-mutually-recursive datatypes, this is a single-element list. -/ +/-- A group of mutually recursive datatypes. -/ abbrev MutualDatatype (IDMeta : Type) := List (LDatatype IDMeta) --------------------------------------------------------------------- @@ -111,40 +110,17 @@ def checkUniform (c: String) (n: String) (args: List LMonoTy) (t: LMonoTy) : Exc ) () args1 | _ => .ok () - -/-- -Check for strict positivity and uniformity of datatype `d` in type `ty`. The -string `c` appears only for error message information. --/ -def checkStrictPosUnifTy (c: String) (d: LDatatype IDMeta) (ty: LMonoTy) : Except Format Unit := - match ty with - | .arrow t1 t2 => - if tyNameAppearsIn d.name t1 then - .error f!"Error in constructor {c}: Non-strictly positive occurrence of {d.name} in type {ty}" - else checkStrictPosUnifTy c d t2 - | _ => checkUniform c d.name (d.typeArgs.map .ftvar) ty - -/-- -Check for strict positivity and uniformity of a datatype --/ -def checkStrictPosUnif (d: LDatatype IDMeta) : Except Format Unit := - List.foldrM (fun ⟨name, args, _⟩ _ => - List.foldrM (fun ⟨ _, ty ⟩ _ => - checkStrictPosUnifTy name.name d ty - ) () args - ) () d.constrs - /-- Check for strict positivity and uniformity of all datatypes in a mutual block within type `ty`. The string `c` appears only for error message information. -/ -def checkStrictPosUnifTyMutual (c: String) (block: MutualDatatype IDMeta) (ty: LMonoTy) : Except Format Unit := +def checkStrictPosUnifTy (c: String) (block: MutualDatatype IDMeta) (ty: LMonoTy) : Except Format Unit := match ty with | .arrow t1 t2 => -- Check that no datatype in the block appears in the left side of an arrow match block.find? (fun d => tyNameAppearsIn d.name t1) with | some d => .error f!"Error in constructor {c}: Non-strictly positive occurrence of {d.name} in type {ty}" - | none => checkStrictPosUnifTyMutual c block t2 + | none => checkStrictPosUnifTy c block t2 | _ => -- Check uniformity for all datatypes in the block block.foldlM (fun _ d => checkUniform c d.name (d.typeArgs.map .ftvar) ty) () @@ -152,11 +128,11 @@ def checkStrictPosUnifTyMutual (c: String) (block: MutualDatatype IDMeta) (ty: L /-- Check for strict positivity and uniformity across a mutual block of datatypes -/ -def checkStrictPosUnifMutual (block: MutualDatatype IDMeta) : Except Format Unit := +def checkStrictPosUnif (block: MutualDatatype IDMeta) : Except Format Unit := block.foldlM (fun _ d => d.constrs.foldlM (fun _ ⟨name, args, _⟩ => args.foldlM (fun _ ⟨_, ty⟩ => - checkStrictPosUnifTyMutual name.name block ty + checkStrictPosUnifTy name.name block ty ) () ) () ) () @@ -167,19 +143,11 @@ Validate a mutual block: check non-empty and no duplicate names. def validateMutualBlock (block: MutualDatatype IDMeta) : Except Format Unit := do if block.isEmpty then .error f!"Error: Empty mutual block is not allowed" - let names := block.map (·.name) - let duplicates := names.filter (fun n => names.count n > 1) - match duplicates.head? with - | some dup => throw f!"Duplicate datatype name in mutual block: {dup}" - | none => pure () - -/-- -Get all type constructor names referenced in a type. --/ -def getTypeRefs (ty: LMonoTy) : List String := - match ty with - | .tcons n args => n :: args.flatMap getTypeRefs - | _ => [] + let names : Std.HashSet String := ∅ + match (block.foldl (fun o d => + if d.name ∈ names then some d else o) none) with + | some dup => .error f!"Duplicate datataype name in mutual block: {dup}" + | none => .ok () --------------------------------------------------------------------- @@ -222,35 +190,88 @@ def freshTypeArg (l: List TyIdentifier) : TyIdentifier := | t :: _ => t | _ => "" +--------------------------------------------------------------------- + +-- Mutual block eliminator generation + +/- +In the following, we will use 3 examples to demonstrate different parts: +1. `List α = | Nil | Cons α (List α)` should generate an eliminator of +type `list α → β → (α → List α → β → β) → β` with rules +`List$Elim Nil e1 e2 = e1` and +`List$Elim (x :: xs) e1 e2 = e2 x xs (List$Elim xs e1 e2)` +2. `tree = | T (int -> tree)` should generate an eliminator of type +`tree → ((int → tree) → (int → β)) → β with rule +`Tree$Elim (T f) e = e f (fun (x: int) => Tree$Elim (f x) e)` +3. `RoseTree = R (Forest) with Forest = FNil | FCons RoseTree Forest` +should generate two eliminators: +`RoseTree$Elim : RoseTree → (Forest → β → α) → β → (RoseTree → Forest → α → β → β) → α` +`Forest$Elim : Forest → (Forest → β → α) → β → (RoseTree → Forest → α → β → β) → β` +with behavior: +`RoseTree$Elim (R x) r1 f1 f2 = r1 x (Forest$Elim x r1 f1 f2)` +`Forest$Elim FNil r1 f1 f2 = f1` +`Forest$Elim (Fcons r f) r1 f1 f2 = f2 r f (Rose$Elim r r1 f1 f2) (Forest$Elim f r1 f1 f2)` +-/ + /-- Construct a recursive type argument for the eliminator. Specifically, determine if a type `ty` contains a strictly positive, uniform -occurrence of `t`, if so, replace this occurence with `retTy`. +occurrence of a datatype in `block`, if so, +replace this occurence with the corresponding variable from `retTyVars`. -For example, given `ty` (int -> (int -> List α)), datatype List, and `retTy` β, -gives (int -> (int -> β)) +Examples: +Single datatype: given `ty` (int -> (int -> List α)), datatype List, +and `retTys` [β], gives (int -> (int -> β)) +Mutually recursive type: given `ty` (int -> (int -> RoseTree)) and +`retTys` `[β₁, β₂]`, gives (int -> (int -> β₁)) -/ -def genRecTy (t: LDatatype IDMeta) (retTy: LMonoTy) (ty: LMonoTy) : Option LMonoTy := - if ty == dataDefault t then .some retTy else +def genRecTy (block : MutualDatatype IDMeta) (retTyVars : List TyIdentifier) (ty : LMonoTy) : Option LMonoTy := + match block.findIdx? (fun d => ty == dataDefault d) with + | some idx => retTyVars[idx]? |>.map LMonoTy.ftvar + | none => + match ty with + | .arrow t1 t2 => (genRecTy block retTyVars t2).map (fun r => .arrow t1 r) + | _ => none + +/-- Find which datatype in a mutual block a recursive type refers to. -/ +def findRecTy (block : MutualDatatype IDMeta) (ty : LMonoTy) : Option (LDatatype IDMeta) := match ty with - | .arrow t1 t2 => (genRecTy t retTy t2).map (fun r => .arrow t1 r) - | _ => .none + | .arrow _ t2 => findRecTy block t2 + | _ => block.find? (fun d => ty == dataDefault d) + +/-- Check if a type is recursive within a mutual block. -/ +def isRecTy (block : MutualDatatype IDMeta) (ty : LMonoTy) : Bool := + (findRecTy block ty).isSome + -def isRecTy (t: LDatatype IDMeta) (ty: LMonoTy) : Bool := - (genRecTy t .int ty).isSome /-- Generate types for eliminator arguments. The types are functions taking in 1) each argument of constructor `c` of -datatype `d` and 2) recursive results for each recursive argument of `c` and -returning an element of type `outputType`. +each datatype in the block `block` (in order) and +2) recursive results for each recursive argument of `c` +It returns an element of type `retTyVars[dtIdx]`, where `dtIdx` is the +index of this constructor in the mutual block -For example, the eliminator type argument for `cons` is α → List α → β → β +Examples: +1. For `Cons`, the eliminator type argument is `α → List α → β → β` +2. For `FCons`, the eliminator type argument is `RoseTree → Forest → α → β → β` -/ -def elimTy (outputType : LMonoTy) (t: LDatatype IDMeta) (c: LConstr IDMeta): LMonoTy := +def elimTy (block : MutualDatatype IDMeta) (retTyVars : List TyIdentifier) + (dtIdx : Nat) (c : LConstr IDMeta) : LMonoTy := + let outputType := retTyVars[dtIdx]? |>.map LMonoTy.ftvar |>.getD (.ftvar "") match c.args with | [] => outputType - | _ :: _ => LMonoTy.mkArrow' outputType (c.args.map Prod.snd ++ (c.args.map (fun x => (genRecTy t outputType x.2).toList)).flatten) + | _ :: _ => + let recTypes := c.args.filterMap fun (_, ty) => genRecTy block retTyVars ty + LMonoTy.mkArrow' outputType (c.args.map Prod.snd ++ recTypes) + +/-- Compute global constructor index within a mutual block. + E.g. if we have a mutual type with types A, B, C, with constructors + A1, A2, B1, B2, B3, C1, the global index of B3 is 4. -/ +def blockConstrIdx (block : MutualDatatype IDMeta) (dtIdx : Nat) (constrIdx : Nat) : Nat := + let prevCount := (block.take dtIdx).foldl (fun acc d => acc + d.constrs.length) 0 + prevCount + constrIdx /-- Simulates pattern matching on operator o. @@ -261,24 +282,37 @@ def LExpr.matchOp {T: LExprParams} [BEq T.Identifier] (e: LExpr T.mono) (o: T.Id | _ => .none /-- -Determine which constructor, if any, a datatype instance belongs to and get the -arguments. Also gives the index in the constructor list as well as the -recursive arguments (somewhat redundantly) - -For example, expression `cons x l` gives constructor `cons`, index `1` (cons is -the second constructor), arguments `[x, l]`, and recursive argument -`[(l, List α)]` +Determine the constructor, index, and arguments for a constructor application. +E.g. `Cons x l` gives `Cons`, constructor index `1`, and `[x, l]`. -/ -def datatypeGetConstr {T: LExprParams} [BEq T.Identifier] (d: LDatatype T.IDMeta) (x: LExpr T.mono) : Option (LConstr T.IDMeta × Nat × List (LExpr T.mono) × List (LExpr T.mono × LMonoTy)) := +def datatypeGetConstr {T: LExprParams} [BEq T.Identifier] (d: LDatatype T.IDMeta) (x: LExpr T.mono) : Option (LConstr T.IDMeta × Nat × List (LExpr T.mono)) := List.foldr (fun (c, i) acc => match x.matchOp c.name with | .some args => - -- Get the elements of args corresponding to recursive calls, in order - let recs := (List.zip args (c.args.map Prod.snd)).filter (fun (_, ty) => isRecTy d ty) - - .some (c, i, args, recs) + .some (c, i, args) | .none => acc) .none (List.zip d.constrs (List.range d.constrs.length)) + +/-- +Determine which datatype and constructor, if any, a datatype instance +belongs to and get the arguments. Also gives the index in the block and +constructor list as well as the recursive arguments + +For example, expression `Cons x l` gives constructor `Cons`, datatype +index `0`, constructor index `1` (cons is the second constructor), +arguments `[x, l]`, and recursive argument `[(l, List α)]` +-/ +def matchConstr {T: LExprParams} [BEq T.Identifier] (block : MutualDatatype T.IDMeta) (x : LExpr T.mono) + : Option (Nat × Nat × LConstr T.IDMeta × List (LExpr T.mono) × List (LExpr T.mono × LMonoTy)) := + List.zip block (List.range block.length) |>.findSome? fun (d, dtIdx) => + (datatypeGetConstr d x).map fun (c, constrIdx, args) => + let recs := (List.zip args (c.args.map Prod.snd)).filter ( + fun (_, ty) => isRecTy block ty) + (dtIdx, constrIdx, c, args, recs) + +def elimFuncName (d: LDatatype IDMeta) : Identifier IDMeta := + d.name ++ "$Elim" + /-- Determines which category a recursive type argument falls in: either `d (typeArgs)` or `τ₁ → ... → τₙ → d(typeArgs)`. @@ -304,14 +338,16 @@ Invariant: `recTy` must either have the form `d(typeArgs)` or `τ₁ → ... → (typeArgs)`. This is enforced by `dataTypeGetConstr` -/ -def elimRecCall {T: LExprParams} (d: LDatatype T.IDMeta) (recArg: LExpr T.mono) (recTy: LMonoTy) (elimArgs: List (LExpr T.mono)) (m: T.Metadata) (elimName : Identifier T.IDMeta) : LExpr T.mono := - match recTyStructure d recTy with - | .inl _ => -- Generate eliminator call directly - (LExpr.op m elimName .none).mkApp m (recArg :: elimArgs) - | .inr funArgs => - /- Construct lambda, first arg of eliminator is recArg applied to lambda - arguments -/ - LExpr.absMulti m funArgs ((LExpr.op m elimName .none).mkApp m (recArg.mkApp m (getBVars m funArgs.length) :: elimArgs)) +def elimRecCall {T: LExprParams} [Inhabited T.IDMeta] (block : MutualDatatype T.IDMeta) (recArg : LExpr T.mono) + (recTy : LMonoTy) (elimArgs : List (LExpr T.mono)) (m : T.Metadata) : + Option (LExpr T.mono) := + (findRecTy block recTy).map fun d => + let elimName := elimFuncName d + match recTyStructure d recTy with + | .inl _ => -- Generate eliminator call directly + (LExpr.op m elimName .none).mkApp m (recArg :: elimArgs) + | .inr funArgs => + LExpr.absMulti m funArgs ((LExpr.op m elimName .none).mkApp m (recArg.mkApp m (getBVars m funArgs.length) :: elimArgs)) /-- Generate eliminator concrete evaluator. Idea: match on 1st argument (e.g. @@ -325,107 +361,22 @@ Examples: `List$Elim (x :: xs) e1 e2 = e2 x xs (List$Elim xs e1 e2)` 2. For `tree = | T (int -> tree)`, the generated function is: `Tree$Elim (T f) e = e f (fun (x: int) => Tree$Elim (f x) e)` - --/ -def elimConcreteEval {T: LExprParams} [BEq T.Identifier] (d: LDatatype T.IDMeta) (m: T.Metadata) (elimName : Identifier T.IDMeta) : - T.Metadata → List (LExpr T.mono) → Option (LExpr T.mono) := - fun _ args => - match args with - | x :: xs => - match datatypeGetConstr d x with - | .some (_, i, a, recs) => - match xs[i]? with - | .some f => f.mkApp m (a ++ recs.map (fun (r, rty) => elimRecCall d r rty xs m elimName)) - | .none => .none - | .none => .none - | _ => .none - -def elimFuncName (d: LDatatype IDMeta) : Identifier IDMeta := - d.name ++ "$Elim" - -/-- -The `LFunc` corresponding to the eliminator for datatype `d`, called e.g. -`List$Elim` for type `List`. --/ -def elimFunc [Inhabited T.IDMeta] [BEq T.Identifier] (d: LDatatype T.IDMeta) (m: T.Metadata) : LFunc T := - let outTyId := freshTypeArg d.typeArgs - { name := elimFuncName d, typeArgs := outTyId :: d.typeArgs, inputs := List.zip (genArgNames (d.constrs.length + 1)) (dataDefault d :: d.constrs.map (elimTy (.ftvar outTyId) d)), output := .ftvar outTyId, concreteEval := elimConcreteEval d m (elimFuncName d)} - ---------------------------------------------------------------------- - --- Mutual block eliminator generation - -/-- Find which datatype in a mutual block a recursive type refers to. -/ -def findRecTyInBlock (block : MutualDatatype IDMeta) (ty : LMonoTy) : Option (LDatatype IDMeta) := - block.find? (fun d => isRecTy d ty) - -/-- -Generate recursive type for mutual blocks. Replace references to datatypes -in the block with the appropriate return type variable. --/ -def genRecTyMutual (block : MutualDatatype IDMeta) (retTyVars : List TyIdentifier) (ty : LMonoTy) : Option LMonoTy := - match block.findIdx? (fun d => ty == dataDefault d) with - | some idx => retTyVars[idx]? |>.map LMonoTy.ftvar - | none => - match ty with - | .arrow t1 t2 => (genRecTyMutual block retTyVars t2).map (fun r => .arrow t1 r) - | _ => none - -/-- -Generate eliminator case type for a constructor in a mutual block. -`dtIdx` is the index of the datatype this constructor belongs to. +3. For `RoseTree = R (Forest) with Forest = FNil | FCons RoseTree Forest`, +the generated functions are: +`RoseTree$Elim (R x) r1 f1 f2 = r1 x (Forest$Elim x r1 f1 f2)` +`Forest$Elim FNil r1 f1 f2 = f1` +`Forest$Elim (Fcons r f) r1 f1 f2 = f2 r f (Rose$Elim r r1 f1 f2) (Forest$Elim f r1 f1 f2)` -/ -def elimTyMutual (block : MutualDatatype IDMeta) (retTyVars : List TyIdentifier) - (dtIdx : Nat) (c : LConstr IDMeta) : LMonoTy := - let outputType := retTyVars[dtIdx]? |>.map LMonoTy.ftvar |>.getD (.ftvar "") - match c.args with - | [] => outputType - | _ :: _ => - let argTypes := c.args.map Prod.snd - let recTypes := c.args.filterMap fun (_, ty) => genRecTyMutual block retTyVars ty - LMonoTy.mkArrow' outputType (argTypes ++ recTypes) - -/-- Compute global constructor index within a mutual block. -/ -def globalConstrIdx (block : MutualDatatype IDMeta) (dtIdx : Nat) (constrIdx : Nat) : Nat := - let prevCount := (block.take dtIdx).foldl (fun acc d => acc + d.constrs.length) 0 - prevCount + constrIdx - -/-- Check if a type is recursive within a mutual block. -/ -def isRecTyInBlock (block : MutualDatatype IDMeta) (ty : LMonoTy) : Bool := - block.any (fun d => isRecTy d ty) - -/-- Find which datatype and constructor a value belongs to in a mutual block. - Unlike datatypeGetConstr, this identifies recursive args across the whole block. -/ -def matchConstrInBlock {T: LExprParams} [BEq T.Identifier] (block : MutualDatatype T.IDMeta) (x : LExpr T.mono) - : Option (Nat × Nat × LConstr T.IDMeta × List (LExpr T.mono) × List (LExpr T.mono × LMonoTy)) := - List.zip block (List.range block.length) |>.findSome? fun (d, dtIdx) => - (datatypeGetConstr d x).map fun (c, constrIdx, args, _) => - -- Recompute recs using the whole block, not just d - let recs := (List.zip args (c.args.map Prod.snd)).filter (fun (_, ty) => isRecTyInBlock block ty) - (dtIdx, constrIdx, c, args, recs) - -/-- Construct recursive eliminator call for mutual blocks. -/ -def elimRecCallMutual {T: LExprParams} [Inhabited T.IDMeta] (block : MutualDatatype T.IDMeta) (recArg : LExpr T.mono) - (recTy : LMonoTy) (elimArgs : List (LExpr T.mono)) (m : T.Metadata) : Option (LExpr T.mono) := - (findRecTyInBlock block recTy).map fun d => - let elimName := elimFuncName d - match recTyStructure d recTy with - | .inl _ => (LExpr.op m elimName .none).mkApp m (recArg :: elimArgs) - | .inr funArgs => - LExpr.absMulti m funArgs ((LExpr.op m elimName .none).mkApp m (recArg.mkApp m (getBVars m funArgs.length) :: elimArgs)) - -/-- Generate eliminator concrete evaluator for mutual blocks. -/ -def elimConcreteEvalMutual {T: LExprParams} [BEq T.Identifier] [Inhabited T.IDMeta] (block : MutualDatatype T.IDMeta) +def elimConcreteEval {T: LExprParams} [BEq T.Identifier] [Inhabited T.IDMeta] (block : MutualDatatype T.IDMeta) (m : T.Metadata) : T.Metadata → List (LExpr T.mono) → Option (LExpr T.mono) := fun _ args => match args with | x :: xs => - match matchConstrInBlock block x with + match matchConstr block x with | some (dtIdx, constrIdx, _, a, recs) => - let gIdx := globalConstrIdx block dtIdx constrIdx + let gIdx := blockConstrIdx block dtIdx constrIdx xs[gIdx]?.bind fun f => - let recCalls := recs.filterMap (fun (r, rty) => elimRecCallMutual block r rty xs m) - some (f.mkApp m (a ++ recCalls)) + some (f.mkApp m (a ++ recs.filterMap (fun (r, rty) => elimRecCall block r rty xs m))) | none => none | _ => none @@ -433,47 +384,39 @@ def elimConcreteEvalMutual {T: LExprParams} [BEq T.Identifier] [Inhabited T.IDMe Generate eliminators for all datatypes in a mutual block. Each datatype gets its own eliminator, but they share case function arguments for all constructors across the block. +For example, +`RoseTree$Elim : RoseTree → (Forest → β → α) → β → (RoseTree → Forest → α → β → β) → α` +`Forest$Elim : Forest → (Forest → β → α) → β → (RoseTree → Forest → α → β → β) → β` -/ -def elimFuncsMutual [Inhabited T.IDMeta] [BEq T.Identifier] (block : MutualDatatype T.IDMeta) (m : T.Metadata) +def elimFuncs [Inhabited T.IDMeta] [BEq T.Identifier] (block : MutualDatatype T.IDMeta) (m : T.Metadata) : List (LFunc T) := - if block.isEmpty then [] else - let allTypeArgs := block.flatMap (·.typeArgs) - let retTyVars := freshTypeArgs block.length allTypeArgs - let allConstrs : List (Nat × LConstr T.IDMeta) := - List.zip block (List.range block.length) |>.flatMap fun (d, dtIdx) => d.constrs.map (dtIdx, ·) - let caseTypes := allConstrs.map fun (dtIdx, c) => elimTyMutual block retTyVars dtIdx c - List.zip block (List.range block.length) |>.map fun (d, dtIdx) => - let outputTyVar := retTyVars[dtIdx]?.getD "" + if h: block.isEmpty then [] else + have hlen : 0 < List.length block := by + have := @List.isEmpty_iff_length_eq_zero _ block; grind + let typeArgs := block[0].typeArgs -- OK because all must have same typevars + let retTyVars := freshTypeArgs block.length typeArgs + let allConstrs := + block.mapIdx (fun dtIdx d => d.constrs.map (dtIdx, ·)) |>.flatten + let caseTypes := allConstrs.map fun (dtIdx, c) => elimTy block retTyVars dtIdx c + (block.zip retTyVars).map fun (d, outputTyVar) => { name := elimFuncName d typeArgs := retTyVars ++ d.typeArgs inputs := List.zip (genArgNames (allConstrs.length + 1)) (dataDefault d :: caseTypes) output := .ftvar outputTyVar - concreteEval := elimConcreteEvalMutual block m } + concreteEval := elimConcreteEval block m } --------------------------------------------------------------------- -- Generating testers and destructors -/-- -Generate tester body (see `testerFunc`). The body assigns each argument of the -eliminator (fun _ ... _ => b), where b is true for the constructor's index and -false otherwise. This requires knowledge of the number of arguments for each -argument to the eliminator.-/ -def testerFuncBody {T : LExprParams} [Inhabited T.IDMeta] (d: LDatatype T.IDMeta) (c: LConstr T.IDMeta) (input: LExpr T.mono) (m: T.Metadata) : LExpr T.mono := - -- Number of arguments is number of constr args + number of recursive args - let numargs (c: LConstr T.IDMeta) := c.args.length + ((c.args.map Prod.snd).filter (isRecTy d)).length - let args := List.map (fun c1 => LExpr.absMultiInfer m (numargs c1) (.boolConst m (c.name.name == c1.name.name))) d.constrs - .mkApp m (.op m (elimFuncName d) .none) (input :: args) - /-- Generate tester body for mutual blocks. For mutual eliminators, we need case functions for ALL constructors across the block, not just the constructors of one datatype. -/ -def testerFuncBodyMutual {T : LExprParams} [Inhabited T.IDMeta] (block: MutualDatatype T.IDMeta) +def testerFuncBody {T : LExprParams} [Inhabited T.IDMeta] (block: MutualDatatype T.IDMeta) (d: LDatatype T.IDMeta) (c: LConstr T.IDMeta) (input: LExpr T.mono) (m: T.Metadata) : LExpr T.mono := -- Number of arguments for a constructor in a mutual block - let numargs (c: LConstr T.IDMeta) := c.args.length + ((c.args.map Prod.snd).filter (isRecTyInBlock block)).length - -- Generate case functions for ALL constructors in the block + let numargs (c: LConstr T.IDMeta) := c.args.length + ((c.args.map Prod.snd).filter (isRecTy block)).length let args := block.flatMap (fun d' => d'.constrs.map (fun c1 => LExpr.absMultiInfer m (numargs c1) (.boolConst m (c.name.name == c1.name.name)))) .mkApp m (.op m (elimFuncName d) .none) (input :: args) @@ -481,31 +424,14 @@ def testerFuncBodyMutual {T : LExprParams} [Inhabited T.IDMeta] (block: MutualDa /-- Generate tester function for a constructor in a mutual block. -/ -def testerFuncMutual {T} [Inhabited T.IDMeta] (block: MutualDatatype T.IDMeta) +def testerFunc {T} [Inhabited T.IDMeta] (block: MutualDatatype T.IDMeta) (d: LDatatype T.IDMeta) (c: LConstr T.IDMeta) (m: T.Metadata) : LFunc T := let arg := genArgName {name := c.testerName, typeArgs := d.typeArgs, inputs := [(arg, dataDefault d)], output := .bool, - body := testerFuncBodyMutual block d c (.fvar m arg .none) m, - attr := #["inline_if_val"] - } - -/-- -Generate tester function for a constructor (e.g. `List$isCons` and -`List$isNil`). The semantics of the testers are given via a body, -and they are defined in terms of eliminators. For example: -`List$isNil l := List$Elim l true (fun _ _ _ => false)` -`List$isCons l := List$Elim l false (fun _ _ _ => true)` --/ -def testerFunc {T} [Inhabited T.IDMeta] (d: LDatatype T.IDMeta) (c: LConstr T.IDMeta) (m: T.Metadata) : LFunc T := - let arg := genArgName - {name := c.testerName, - typeArgs := d.typeArgs, - inputs := [(arg, dataDefault d)], - output := .bool, - body := testerFuncBody d c (.fvar m arg .none) m, + body := testerFuncBody block d c (.fvar m arg .none) m, attr := #["inline_if_val"] } @@ -518,7 +444,7 @@ def destructorConcreteEval {T: LExprParams} [BEq T.Identifier] (d: LDatatype T.I fun args => match args with | [x] => - (datatypeGetConstr d x).bind (fun (c1, _, a, _) => + (datatypeGetConstr d x).bind (fun (c1, _, a) => if c1.name.name == c.name.name then a[idx]? else none) | _ => none @@ -577,15 +503,20 @@ def TypeFactory.allTypeNames (t : @TypeFactory IDMeta) : List String := t.allDatatypes.map (·.name) /-- -Validate that all type references in a mutual block refer to either: -1. Known primitive types (passed as parameter) -2. Types already in the TypeFactory -3. Types within the same mutual block +Get all type constructor names referenced in a type. +-/ +def getTypeRefs (ty: LMonoTy) : List String := + match ty with + | .tcons n args => n :: args.flatMap getTypeRefs + | _ => [] + +/-- +Ensures all type occuring a constructor are only primitive types, +types defined previously, or types in the same mutual block. -/ def TypeFactory.validateTypeReferences (t : @TypeFactory IDMeta) (block : MutualDatatype IDMeta) (knownTypes : List String) : Except Format Unit := do - let blockNames := block.map (·.name) - let existingNames := t.allTypeNames - let validNames := knownTypes ++ existingNames ++ blockNames + let validNames : Std.HashSet String := + Std.HashSet.ofList (knownTypes ++ t.allTypeNames ++ block.map (·.name)) for d in block do for c in d.constrs do for (_, ty) in c.args do @@ -593,10 +524,13 @@ def TypeFactory.validateTypeReferences (t : @TypeFactory IDMeta) (block : Mutual if !validNames.contains ref then throw f!"Error in datatype {d.name}, constructor {c.name.name}: Undefined type '{ref}'" -/-- Add a mutual block to the TypeFactory, with validation. -/ +/-- Add a mutual block to the TypeFactory, checking for duplicates, + inconsistent types, and positivity. -/ def TypeFactory.addMutualBlock (t : @TypeFactory IDMeta) (block : MutualDatatype IDMeta) (knownTypes : List String := []) : Except Format (@TypeFactory IDMeta) := do - -- Validate block structure + -- Check for name clashes within block validateMutualBlock block + -- Check for positivity and uniformity + checkStrictPosUnif block -- Check for duplicate names with existing types for d in block do match t.getType d.name with @@ -605,26 +539,10 @@ def TypeFactory.addMutualBlock (t : @TypeFactory IDMeta) (block : MutualDatatype Existing Type: {d'}\n\ New Type:{d}" | none => pure () - -- Validate type references + -- Check for consistent type dependencies t.validateTypeReferences block knownTypes .ok (t.push block) -/-- Add a single datatype (as a single-element block). -/ -def TypeFactory.addDatatype (t : @TypeFactory IDMeta) (d : LDatatype IDMeta) : Except Format (@TypeFactory IDMeta) := - t.addMutualBlock [d] - -/-- -Generates the Factory (containing the eliminator, constructors, testers, -and destructors) for a single datatype. --/ -def LDatatype.genFactory {T: LExprParams} [inst: Inhabited T.Metadata] [Inhabited T.IDMeta] [ToFormat T.IDMeta] [BEq T.Identifier] (d: LDatatype T.IDMeta): Except Format (@Lambda.Factory T) := do - _ ← checkStrictPosUnif d - Factory.default.addFactory ( - elimFunc d inst.default :: - d.constrs.map (fun c => constrFunc c d) ++ - d.constrs.map (fun c => testerFunc d c inst.default) ++ - (d.constrs.map (fun c => destructorFuncs d c)).flatten).toArray - /-- Constructs maps of generated functions for datatype `d`: map of constructors, testers, and destructors in order. Each maps names to @@ -646,11 +564,9 @@ for a mutual block of datatypes. def genBlockFactory {T: LExprParams} [inst: Inhabited T.Metadata] [Inhabited T.IDMeta] [ToFormat T.IDMeta] [BEq T.Identifier] (block : MutualDatatype T.IDMeta) : Except Format (@Lambda.Factory T) := do if block.isEmpty then return Factory.default - -- Check strict positivity and uniformity across the whole block - checkStrictPosUnifMutual block - let elims := elimFuncsMutual block inst.default + let elims := elimFuncs block inst.default let constrs := block.flatMap (fun d => d.constrs.map (fun c => constrFunc c d)) - let testers := block.flatMap (fun d => d.constrs.map (fun c => testerFuncMutual block d c inst.default)) + let testers := block.flatMap (fun d => d.constrs.map (fun c => testerFunc block d c inst.default)) let destrs := block.flatMap (fun d => d.constrs.flatMap (fun c => destructorFuncs d c)) Factory.default.addFactory (elims ++ constrs ++ testers ++ destrs).toArray diff --git a/Strata/Languages/Core/DDMTransform/Translate.lean b/Strata/Languages/Core/DDMTransform/Translate.lean index 2af42284d..8a160623f 100644 --- a/Strata/Languages/Core/DDMTransform/Translate.lean +++ b/Strata/Languages/Core/DDMTransform/Translate.lean @@ -1379,7 +1379,7 @@ def translateDatatype (p : Program) (bindings : TransBindings) (op : Operation) -- Generate factory from LDatatype and convert to Core.Decl -- (used only for bindings.freeVars, not for allDecls) - let factory ← match ldatatype.genFactory (T := CoreLParams) with + let factory ← match genBlockFactory [ldatatype] (T := CoreLParams) with | .ok f => pure f | .error e => TransM.error s!"Failed to generate datatype factory: {e}" let funcDecls : List Core.Decl := factory.toList.map fun func => diff --git a/Strata/Languages/Core/ProgramWF.lean b/Strata/Languages/Core/ProgramWF.lean index fe32d4144..897be24f1 100644 --- a/Strata/Languages/Core/ProgramWF.lean +++ b/Strata/Languages/Core/ProgramWF.lean @@ -280,11 +280,13 @@ theorem addKnownTypeWithErrorIdents {C: Expression.TyContext}: C.addKnownTypeWit case error => intros _; contradiction case ok k'=> simp[Except.bind]; intros T'; subst T'; rfl +/- theorem addDatatypeIdents {C: Expression.TyContext}: C.addDatatype d = .ok C' → C.idents = C'.idents := by unfold LContext.addDatatype; simp only[bind, Except.bind, pure, Except.pure]; intros Hok repeat (split at Hok <;> try contradiction) cases Hok <;> rfl +-/ /-- If a program typechecks successfully, then every identifier in the list of From 778b4e84c23a48aadfe469e94df6a6e9a2543335 Mon Sep 17 00:00:00 2001 From: Josh Cohen Date: Fri, 16 Jan 2026 22:09:52 -0500 Subject: [PATCH 5/5] Minor cleaning up code --- Strata/DL/SMT/Solver.lean | 2 -- Strata/Languages/Core/DDMTransform/Translate.lean | 2 +- Strata/Languages/Core/ProgramWF.lean | 4 ++-- Strata/Languages/Core/SMTEncoder.lean | 12 ++---------- Strata/Languages/Core/TypeDecl.lean | 5 +---- Strata/Languages/Core/Verifier.lean | 4 ++-- StrataTest/Languages/Core/ExprEvalTest.lean | 2 +- 7 files changed, 9 insertions(+), 22 deletions(-) diff --git a/Strata/DL/SMT/Solver.lean b/Strata/DL/SMT/Solver.lean index 8aedb7518..ca6f3b0b4 100644 --- a/Strata/DL/SMT/Solver.lean +++ b/Strata/DL/SMT/Solver.lean @@ -134,10 +134,8 @@ def declareDatatype (id : String) (params : List String) (constructors : List St /-- Declare multiple mutually recursive datatypes. Each element is (name, params, constructors). -/ def declareDatatypes (dts : List (String × List String × List String)) : SolverM Unit := do if dts.isEmpty then return - -- Build sort declarations: ((name arity) ...) let sortDecls := dts.map fun (name, params, _) => s!"({name} {params.length})" let sortDeclStr := String.intercalate " " sortDecls - -- Build datatype bodies let bodies := dts.map fun (_, params, constrs) => let cInline := String.intercalate " " constrs if params.isEmpty then s!"({cInline})" diff --git a/Strata/Languages/Core/DDMTransform/Translate.lean b/Strata/Languages/Core/DDMTransform/Translate.lean index 8a160623f..34485c7ac 100644 --- a/Strata/Languages/Core/DDMTransform/Translate.lean +++ b/Strata/Languages/Core/DDMTransform/Translate.lean @@ -253,7 +253,7 @@ partial def translateLMonoTy (bindings : TransBindings) (arg : Arg) : pure ty | .type (.data (ldatatype :: _)) _md => -- Datatype Declaration - -- TODO: For mutual blocks, need to find the specific datatype by name + -- TODO: Handle mutual blocks, need to find the specific datatype by name let args := ldatatype.typeArgs.map LMonoTy.ftvar pure (.tcons ldatatype.name args) | .type (.data []) _md => diff --git a/Strata/Languages/Core/ProgramWF.lean b/Strata/Languages/Core/ProgramWF.lean index 897be24f1..2a09967fd 100644 --- a/Strata/Languages/Core/ProgramWF.lean +++ b/Strata/Languages/Core/ProgramWF.lean @@ -295,7 +295,7 @@ program decls is not in the original `LContext` theorem Program.typeCheckFunctionDisjoint : Program.typeCheck.go p C T decls acc = .ok (d', T') → (∀ x, x ∈ Program.getNames.go decls → ¬ C.idents.contains x) := by -- TODO: This proof needs to be updated to handle mutual datatypes (multiple names per decl) sorry - /- Original proof: + /- induction decls generalizing acc p d' T' T C with | nil => simp[Program.getNames.go] | cons r rs IH => @@ -343,7 +343,7 @@ unique. theorem Program.typeCheckFunctionNoDup : Program.typeCheck.go p C T decls acc = .ok (d', T') → (Program.getNames.go decls).Nodup := by -- TODO: This proof needs to be updated to handle mutual datatypes (multiple names per decl) sorry - /- Original proof: + /- induction decls generalizing acc p C T with | nil => simp[Program.getNames.go] | cons r rs IH => diff --git a/Strata/Languages/Core/SMTEncoder.lean b/Strata/Languages/Core/SMTEncoder.lean index 7957acbc3..11c02f53b 100644 --- a/Strata/Languages/Core/SMTEncoder.lean +++ b/Strata/Languages/Core/SMTEncoder.lean @@ -35,9 +35,8 @@ structure SMT.Context where ifs : Array SMT.IF := #[] axms : Array Term := #[] tySubst: Map String TermType := [] - /-- Stores the TypeFactory for datatype ordering during emission. - This is redundant with Env.datatypes but needed here to preserve - the correct dependency order when emitting declare-datatypes. -/ + /-- Stores the TypeFactory purely for ordering datatype declarations + correctly (TypeFactory in topological order) -/ typeFactory : @Lambda.TypeFactory CoreLParams.IDMeta := #[] seenDatatypes : Std.HashSet String := {} datatypeFuns : Map String (Op.DatatypeFuncs × LConstr CoreLParams.IDMeta) := Map.empty @@ -117,7 +116,6 @@ Single-element blocks use declare-datatype, multi-element blocks use declare-dat -/ def SMT.Context.emitDatatypes (ctx : SMT.Context) : Strata.SMT.SolverM Unit := do for block in ctx.typeFactory.toList do - -- Filter to only datatypes that are used (in seenDatatypes) let usedBlock := block.filter (fun d => ctx.seenDatatypes.contains d.name) match usedBlock with | [] => pure () @@ -557,14 +555,8 @@ partial def toSMTOp (E : Env) (fn : CoreIdent) (fnty : LMonoTy) (ctx : SMT.Conte .ok (.app (Op.uf uf), smt_outty, ctx) end -/-- -Convert expressions to SMT terms. -Sets the TypeFactory on the context if not already set, to ensure correct -datatype emission ordering. --/ def toSMTTerms (E : Env) (es : List (LExpr CoreLParams.mono)) (ctx : SMT.Context) : Except Format ((List Term) × SMT.Context) := do - -- Ensure typeFactory is set for correct datatype emission ordering let ctx := if ctx.typeFactory.isEmpty then ctx.withTypeFactory E.datatypes else ctx match es with | [] => .ok ([], ctx) diff --git a/Strata/Languages/Core/TypeDecl.lean b/Strata/Languages/Core/TypeDecl.lean index e82226898..5a45f8286 100644 --- a/Strata/Languages/Core/TypeDecl.lean +++ b/Strata/Languages/Core/TypeDecl.lean @@ -83,9 +83,6 @@ def TypeSynonym.toRHSLTy (t : TypeSynonym) : LTy := /-! # Strata Core Type Declarations -/ -/-- A Boogie type declaration. The `data` variant stores a mutual block - (a non-empty list of mutually recursive datatypes). For non-mutually-recursive - datatypes, this is a single-element list. -/ inductive TypeDecl where | con : TypeConstructor → TypeDecl | syn : TypeSynonym → TypeDecl @@ -101,7 +98,7 @@ instance : ToFormat TypeDecl where | .data [td] => f!"{td}" | .data tds => f!"mutual {Std.Format.joinSep (tds.map format) Format.line} end" -/-- Get all names from a TypeDecl. For mutual blocks, returns all datatype names. -/ +/-- Get all names from a TypeDecl. -/ def TypeDecl.names (d : TypeDecl) : List Expression.Ident := match d with | .con tc => [tc.name] diff --git a/Strata/Languages/Core/Verifier.lean b/Strata/Languages/Core/Verifier.lean index 4c2f7ede7..84038fdcf 100644 --- a/Strata/Languages/Core/Verifier.lean +++ b/Strata/Languages/Core/Verifier.lean @@ -158,8 +158,8 @@ def solverResult (vars : List (IdentT LMonoTy Visibility)) (output: IO.Process.O def getSolverPrelude : String → SolverM Unit | "z3" => do - -- These options are set by the standard Core implementation and are - -- generally good for the Core dialect, too, though we may want to + -- These options are set by the standard Boogie implementation and are + -- generally good for the Boogie dialect, too, though we may want to -- have more fine-grained criteria for when to use them. Solver.setOption "smt.mbqi" "false" Solver.setOption "auto_config" "false" diff --git a/StrataTest/Languages/Core/ExprEvalTest.lean b/StrataTest/Languages/Core/ExprEvalTest.lean index 3dcc4f464..d0d73c937 100644 --- a/StrataTest/Languages/Core/ExprEvalTest.lean +++ b/StrataTest/Languages/Core/ExprEvalTest.lean @@ -189,7 +189,7 @@ open Lambda.LTy.Syntax -- This may take a while (~ 5min) --- #eval (checkFactoryOps false) +#eval (checkFactoryOps false) open Plausible TestGen