diff --git a/Strata/DL/Lambda/LExprTypeEnv.lean b/Strata/DL/Lambda/LExprTypeEnv.lean index 3b0ad8d3a..cc9c5f460 100644 --- a/Strata/DL/Lambda/LExprTypeEnv.lean +++ b/Strata/DL/Lambda/LExprTypeEnv.lean @@ -218,9 +218,6 @@ deriving Inhabited def LDatatype.toKnownType (d: LDatatype IDMeta) : KnownType := { name := d.name, metadata := d.typeArgs.length} -def TypeFactory.toKnownTypes (t: @TypeFactory IDMeta) : KnownTypes := - makeKnownTypes (t.foldr (fun t l => t.toKnownType :: l) []) - /-- A type environment `TEnv` contains - genEnv: `TGenEnv` to track the generator state as well as the typing context @@ -333,28 +330,26 @@ def LContext.addFactoryFunctions (C : LContext T) (fact : @Factory T) : LContext { C with functions := C.functions.append fact } /-- -Add a datatype `d` to an `LContext` `C`. -This adds `d` to `C.datatypes`, adds the derived functions -(e.g. eliminators, testers) to `C.functions`, and adds `d` to -`C.knownTypes`. It performs error checking for name clashes. +Add a mutual block of datatypes `block` to an `LContext` `C`. +This adds all types to `C.datatypes` and `C.knownTypes`, +adds the derived functions (e.g. eliminators, testers), +and performs error checking for name clashes. -/ -def LContext.addDatatype [Inhabited T.IDMeta] [Inhabited T.Metadata] (C: LContext T) (d: LDatatype T.IDMeta) : Except Format (LContext T) := do - -- Ensure not in known types - if C.knownTypes.containsName d.name then - .error f!"Cannot name datatype same as known type!\n\ - {d}\n\ - KnownTypes' names:\n\ - {C.knownTypes.keywords}" - let ds ← C.datatypes.addDatatype d +def LContext.addMutualBlock [Inhabited T.IDMeta] [Inhabited T.Metadata] [ToFormat T.IDMeta] (C: LContext T) (block: MutualDatatype T.IDMeta) : Except Format (LContext T) := do + -- Check for name clashes with known types + for d in block do + if C.knownTypes.containsName d.name then + throw f!"Cannot name datatype same as known type!\n{d}\nKnownTypes' names:\n{C.knownTypes.keywords}" + let ds ← C.datatypes.addMutualBlock block C.knownTypes.keywords -- Add factory functions, checking for name clashes - let f ← d.genFactory + let f ← genBlockFactory block let fs ← C.functions.addFactory f -- Add datatype names to knownTypes - let ks ← C.knownTypes.add d.toKnownType + let ks ← block.foldlM (fun ks d => ks.add d.toKnownType) C.knownTypes .ok {C with datatypes := ds, functions := fs, knownTypes := ks} def LContext.addTypeFactory [Inhabited T.IDMeta] [Inhabited T.Metadata] (C: LContext T) (f: @TypeFactory T.IDMeta) : Except Format (LContext T) := - Array.foldlM (fun C d => C.addDatatype d) C f + f.foldlM (fun C block => C.addMutualBlock block) C /-- Replace the global substitution in `T.state.subst` with `S`. diff --git a/Strata/DL/Lambda/TypeFactory.lean b/Strata/DL/Lambda/TypeFactory.lean index e20cb53c5..4072b4a96 100644 --- a/Strata/DL/Lambda/TypeFactory.lean +++ b/Strata/DL/Lambda/TypeFactory.lean @@ -81,6 +81,9 @@ The default type application for a datatype. E.g. for datatype def dataDefault (d: LDatatype IDMeta) : LMonoTy := data d (d.typeArgs.map .ftvar) +/-- A group of mutually recursive datatypes. -/ +abbrev MutualDatatype (IDMeta : Type) := List (LDatatype IDMeta) + --------------------------------------------------------------------- -- Typechecking @@ -107,28 +110,44 @@ def checkUniform (c: String) (n: String) (args: List LMonoTy) (t: LMonoTy) : Exc ) () args1 | _ => .ok () - /-- -Check for strict positivity and uniformity of datatype `d` in type `ty`. The -string `c` appears only for error message information. +Check for strict positivity and uniformity of all datatypes in a mutual block +within type `ty`. The string `c` appears only for error message information. -/ -def checkStrictPosUnifTy (c: String) (d: LDatatype IDMeta) (ty: LMonoTy) : Except Format Unit := +def checkStrictPosUnifTy (c: String) (block: MutualDatatype IDMeta) (ty: LMonoTy) : Except Format Unit := match ty with | .arrow t1 t2 => - if tyNameAppearsIn d.name t1 then - .error f!"Error in constructor {c}: Non-strictly positive occurrence of {d.name} in type {ty}" - else checkStrictPosUnifTy c d t2 - | _ => checkUniform c d.name (d.typeArgs.map .ftvar) ty + -- Check that no datatype in the block appears in the left side of an arrow + match block.find? (fun d => tyNameAppearsIn d.name t1) with + | some d => .error f!"Error in constructor {c}: Non-strictly positive occurrence of {d.name} in type {ty}" + | none => checkStrictPosUnifTy c block t2 + | _ => + -- Check uniformity for all datatypes in the block + block.foldlM (fun _ d => checkUniform c d.name (d.typeArgs.map .ftvar) ty) () /-- -Check for strict positivity and uniformity of a datatype +Check for strict positivity and uniformity across a mutual block of datatypes -/ -def checkStrictPosUnif (d: LDatatype IDMeta) : Except Format Unit := - List.foldrM (fun ⟨name, args, _⟩ _ => - List.foldrM (fun ⟨ _, ty ⟩ _ => - checkStrictPosUnifTy name.name d ty - ) () args - ) () d.constrs +def checkStrictPosUnif (block: MutualDatatype IDMeta) : Except Format Unit := + block.foldlM (fun _ d => + d.constrs.foldlM (fun _ ⟨name, args, _⟩ => + args.foldlM (fun _ ⟨_, ty⟩ => + checkStrictPosUnifTy name.name block ty + ) () + ) () + ) () + +/-- +Validate a mutual block: check non-empty and no duplicate names. +-/ +def validateMutualBlock (block: MutualDatatype IDMeta) : Except Format Unit := do + if block.isEmpty then + .error f!"Error: Empty mutual block is not allowed" + let names : Std.HashSet String := ∅ + match (block.foldl (fun o d => + if d.name ∈ names then some d else o) none) with + | some dup => .error f!"Duplicate datataype name in mutual block: {dup}" + | none => .ok () --------------------------------------------------------------------- @@ -171,35 +190,88 @@ def freshTypeArg (l: List TyIdentifier) : TyIdentifier := | t :: _ => t | _ => "" +--------------------------------------------------------------------- + +-- Mutual block eliminator generation + +/- +In the following, we will use 3 examples to demonstrate different parts: +1. `List α = | Nil | Cons α (List α)` should generate an eliminator of +type `list α → β → (α → List α → β → β) → β` with rules +`List$Elim Nil e1 e2 = e1` and +`List$Elim (x :: xs) e1 e2 = e2 x xs (List$Elim xs e1 e2)` +2. `tree = | T (int -> tree)` should generate an eliminator of type +`tree → ((int → tree) → (int → β)) → β with rule +`Tree$Elim (T f) e = e f (fun (x: int) => Tree$Elim (f x) e)` +3. `RoseTree = R (Forest) with Forest = FNil | FCons RoseTree Forest` +should generate two eliminators: +`RoseTree$Elim : RoseTree → (Forest → β → α) → β → (RoseTree → Forest → α → β → β) → α` +`Forest$Elim : Forest → (Forest → β → α) → β → (RoseTree → Forest → α → β → β) → β` +with behavior: +`RoseTree$Elim (R x) r1 f1 f2 = r1 x (Forest$Elim x r1 f1 f2)` +`Forest$Elim FNil r1 f1 f2 = f1` +`Forest$Elim (Fcons r f) r1 f1 f2 = f2 r f (Rose$Elim r r1 f1 f2) (Forest$Elim f r1 f1 f2)` +-/ + /-- Construct a recursive type argument for the eliminator. Specifically, determine if a type `ty` contains a strictly positive, uniform -occurrence of `t`, if so, replace this occurence with `retTy`. +occurrence of a datatype in `block`, if so, +replace this occurence with the corresponding variable from `retTyVars`. -For example, given `ty` (int -> (int -> List α)), datatype List, and `retTy` β, -gives (int -> (int -> β)) +Examples: +Single datatype: given `ty` (int -> (int -> List α)), datatype List, +and `retTys` [β], gives (int -> (int -> β)) +Mutually recursive type: given `ty` (int -> (int -> RoseTree)) and +`retTys` `[β₁, β₂]`, gives (int -> (int -> β₁)) -/ -def genRecTy (t: LDatatype IDMeta) (retTy: LMonoTy) (ty: LMonoTy) : Option LMonoTy := - if ty == dataDefault t then .some retTy else +def genRecTy (block : MutualDatatype IDMeta) (retTyVars : List TyIdentifier) (ty : LMonoTy) : Option LMonoTy := + match block.findIdx? (fun d => ty == dataDefault d) with + | some idx => retTyVars[idx]? |>.map LMonoTy.ftvar + | none => + match ty with + | .arrow t1 t2 => (genRecTy block retTyVars t2).map (fun r => .arrow t1 r) + | _ => none + +/-- Find which datatype in a mutual block a recursive type refers to. -/ +def findRecTy (block : MutualDatatype IDMeta) (ty : LMonoTy) : Option (LDatatype IDMeta) := match ty with - | .arrow t1 t2 => (genRecTy t retTy t2).map (fun r => .arrow t1 r) - | _ => .none + | .arrow _ t2 => findRecTy block t2 + | _ => block.find? (fun d => ty == dataDefault d) + +/-- Check if a type is recursive within a mutual block. -/ +def isRecTy (block : MutualDatatype IDMeta) (ty : LMonoTy) : Bool := + (findRecTy block ty).isSome + -def isRecTy (t: LDatatype IDMeta) (ty: LMonoTy) : Bool := - (genRecTy t .int ty).isSome /-- Generate types for eliminator arguments. The types are functions taking in 1) each argument of constructor `c` of -datatype `d` and 2) recursive results for each recursive argument of `c` and -returning an element of type `outputType`. +each datatype in the block `block` (in order) and +2) recursive results for each recursive argument of `c` +It returns an element of type `retTyVars[dtIdx]`, where `dtIdx` is the +index of this constructor in the mutual block -For example, the eliminator type argument for `cons` is α → List α → β → β +Examples: +1. For `Cons`, the eliminator type argument is `α → List α → β → β` +2. For `FCons`, the eliminator type argument is `RoseTree → Forest → α → β → β` -/ -def elimTy (outputType : LMonoTy) (t: LDatatype IDMeta) (c: LConstr IDMeta): LMonoTy := +def elimTy (block : MutualDatatype IDMeta) (retTyVars : List TyIdentifier) + (dtIdx : Nat) (c : LConstr IDMeta) : LMonoTy := + let outputType := retTyVars[dtIdx]? |>.map LMonoTy.ftvar |>.getD (.ftvar "") match c.args with | [] => outputType - | _ :: _ => LMonoTy.mkArrow' outputType (c.args.map Prod.snd ++ (c.args.map (fun x => (genRecTy t outputType x.2).toList)).flatten) + | _ :: _ => + let recTypes := c.args.filterMap fun (_, ty) => genRecTy block retTyVars ty + LMonoTy.mkArrow' outputType (c.args.map Prod.snd ++ recTypes) + +/-- Compute global constructor index within a mutual block. + E.g. if we have a mutual type with types A, B, C, with constructors + A1, A2, B1, B2, B3, C1, the global index of B3 is 4. -/ +def blockConstrIdx (block : MutualDatatype IDMeta) (dtIdx : Nat) (constrIdx : Nat) : Nat := + let prevCount := (block.take dtIdx).foldl (fun acc d => acc + d.constrs.length) 0 + prevCount + constrIdx /-- Simulates pattern matching on operator o. @@ -210,24 +282,37 @@ def LExpr.matchOp {T: LExprParams} [BEq T.Identifier] (e: LExpr T.mono) (o: T.Id | _ => .none /-- -Determine which constructor, if any, a datatype instance belongs to and get the -arguments. Also gives the index in the constructor list as well as the -recursive arguments (somewhat redundantly) - -For example, expression `cons x l` gives constructor `cons`, index `1` (cons is -the second constructor), arguments `[x, l]`, and recursive argument -`[(l, List α)]` +Determine the constructor, index, and arguments for a constructor application. +E.g. `Cons x l` gives `Cons`, constructor index `1`, and `[x, l]`. -/ -def datatypeGetConstr {T: LExprParams} [BEq T.Identifier] (d: LDatatype T.IDMeta) (x: LExpr T.mono) : Option (LConstr T.IDMeta × Nat × List (LExpr T.mono) × List (LExpr T.mono × LMonoTy)) := +def datatypeGetConstr {T: LExprParams} [BEq T.Identifier] (d: LDatatype T.IDMeta) (x: LExpr T.mono) : Option (LConstr T.IDMeta × Nat × List (LExpr T.mono)) := List.foldr (fun (c, i) acc => match x.matchOp c.name with | .some args => - -- Get the elements of args corresponding to recursive calls, in order - let recs := (List.zip args (c.args.map Prod.snd)).filter (fun (_, ty) => isRecTy d ty) - - .some (c, i, args, recs) + .some (c, i, args) | .none => acc) .none (List.zip d.constrs (List.range d.constrs.length)) + +/-- +Determine which datatype and constructor, if any, a datatype instance +belongs to and get the arguments. Also gives the index in the block and +constructor list as well as the recursive arguments + +For example, expression `Cons x l` gives constructor `Cons`, datatype +index `0`, constructor index `1` (cons is the second constructor), +arguments `[x, l]`, and recursive argument `[(l, List α)]` +-/ +def matchConstr {T: LExprParams} [BEq T.Identifier] (block : MutualDatatype T.IDMeta) (x : LExpr T.mono) + : Option (Nat × Nat × LConstr T.IDMeta × List (LExpr T.mono) × List (LExpr T.mono × LMonoTy)) := + List.zip block (List.range block.length) |>.findSome? fun (d, dtIdx) => + (datatypeGetConstr d x).map fun (c, constrIdx, args) => + let recs := (List.zip args (c.args.map Prod.snd)).filter ( + fun (_, ty) => isRecTy block ty) + (dtIdx, constrIdx, c, args, recs) + +def elimFuncName (d: LDatatype IDMeta) : Identifier IDMeta := + d.name ++ "$Elim" + /-- Determines which category a recursive type argument falls in: either `d (typeArgs)` or `τ₁ → ... → τₙ → d(typeArgs)`. @@ -253,14 +338,16 @@ Invariant: `recTy` must either have the form `d(typeArgs)` or `τ₁ → ... → (typeArgs)`. This is enforced by `dataTypeGetConstr` -/ -def elimRecCall {T: LExprParams} (d: LDatatype T.IDMeta) (recArg: LExpr T.mono) (recTy: LMonoTy) (elimArgs: List (LExpr T.mono)) (m: T.Metadata) (elimName : Identifier T.IDMeta) : LExpr T.mono := - match recTyStructure d recTy with - | .inl _ => -- Generate eliminator call directly - (LExpr.op m elimName .none).mkApp m (recArg :: elimArgs) - | .inr funArgs => - /- Construct lambda, first arg of eliminator is recArg applied to lambda - arguments -/ - LExpr.absMulti m funArgs ((LExpr.op m elimName .none).mkApp m (recArg.mkApp m (getBVars m funArgs.length) :: elimArgs)) +def elimRecCall {T: LExprParams} [Inhabited T.IDMeta] (block : MutualDatatype T.IDMeta) (recArg : LExpr T.mono) + (recTy : LMonoTy) (elimArgs : List (LExpr T.mono)) (m : T.Metadata) : + Option (LExpr T.mono) := + (findRecTy block recTy).map fun d => + let elimName := elimFuncName d + match recTyStructure d recTy with + | .inl _ => -- Generate eliminator call directly + (LExpr.op m elimName .none).mkApp m (recArg :: elimArgs) + | .inr funArgs => + LExpr.absMulti m funArgs ((LExpr.op m elimName .none).mkApp m (recArg.mkApp m (getBVars m funArgs.length) :: elimArgs)) /-- Generate eliminator concrete evaluator. Idea: match on 1st argument (e.g. @@ -274,61 +361,77 @@ Examples: `List$Elim (x :: xs) e1 e2 = e2 x xs (List$Elim xs e1 e2)` 2. For `tree = | T (int -> tree)`, the generated function is: `Tree$Elim (T f) e = e f (fun (x: int) => Tree$Elim (f x) e)` - +3. For `RoseTree = R (Forest) with Forest = FNil | FCons RoseTree Forest`, +the generated functions are: +`RoseTree$Elim (R x) r1 f1 f2 = r1 x (Forest$Elim x r1 f1 f2)` +`Forest$Elim FNil r1 f1 f2 = f1` +`Forest$Elim (Fcons r f) r1 f1 f2 = f2 r f (Rose$Elim r r1 f1 f2) (Forest$Elim f r1 f1 f2)` -/ -def elimConcreteEval {T: LExprParams} [BEq T.Identifier] (d: LDatatype T.IDMeta) (m: T.Metadata) (elimName : Identifier T.IDMeta) : - T.Metadata → List (LExpr T.mono) → Option (LExpr T.mono) := +def elimConcreteEval {T: LExprParams} [BEq T.Identifier] [Inhabited T.IDMeta] (block : MutualDatatype T.IDMeta) + (m : T.Metadata) : T.Metadata → List (LExpr T.mono) → Option (LExpr T.mono) := fun _ args => match args with | x :: xs => - match datatypeGetConstr d x with - | .some (_, i, a, recs) => - match xs[i]? with - | .some f => f.mkApp m (a ++ recs.map (fun (r, rty) => elimRecCall d r rty xs m elimName)) - | .none => .none - | .none => .none - | _ => .none - -def elimFuncName (d: LDatatype IDMeta) : Identifier IDMeta := - d.name ++ "$Elim" + match matchConstr block x with + | some (dtIdx, constrIdx, _, a, recs) => + let gIdx := blockConstrIdx block dtIdx constrIdx + xs[gIdx]?.bind fun f => + some (f.mkApp m (a ++ recs.filterMap (fun (r, rty) => elimRecCall block r rty xs m))) + | none => none + | _ => none /-- -The `LFunc` corresponding to the eliminator for datatype `d`, called e.g. -`List$Elim` for type `List`. +Generate eliminators for all datatypes in a mutual block. +Each datatype gets its own eliminator, but they share case function arguments +for all constructors across the block. +For example, +`RoseTree$Elim : RoseTree → (Forest → β → α) → β → (RoseTree → Forest → α → β → β) → α` +`Forest$Elim : Forest → (Forest → β → α) → β → (RoseTree → Forest → α → β → β) → β` -/ -def elimFunc [Inhabited T.IDMeta] [BEq T.Identifier] (d: LDatatype T.IDMeta) (m: T.Metadata) : LFunc T := - let outTyId := freshTypeArg d.typeArgs - { name := elimFuncName d, typeArgs := outTyId :: d.typeArgs, inputs := List.zip (genArgNames (d.constrs.length + 1)) (dataDefault d :: d.constrs.map (elimTy (.ftvar outTyId) d)), output := .ftvar outTyId, concreteEval := elimConcreteEval d m (elimFuncName d)} +def elimFuncs [Inhabited T.IDMeta] [BEq T.Identifier] (block : MutualDatatype T.IDMeta) (m : T.Metadata) + : List (LFunc T) := + if h: block.isEmpty then [] else + have hlen : 0 < List.length block := by + have := @List.isEmpty_iff_length_eq_zero _ block; grind + let typeArgs := block[0].typeArgs -- OK because all must have same typevars + let retTyVars := freshTypeArgs block.length typeArgs + let allConstrs := + block.mapIdx (fun dtIdx d => d.constrs.map (dtIdx, ·)) |>.flatten + let caseTypes := allConstrs.map fun (dtIdx, c) => elimTy block retTyVars dtIdx c + (block.zip retTyVars).map fun (d, outputTyVar) => + { name := elimFuncName d + typeArgs := retTyVars ++ d.typeArgs + inputs := List.zip (genArgNames (allConstrs.length + 1)) (dataDefault d :: caseTypes) + output := .ftvar outputTyVar + concreteEval := elimConcreteEval block m } --------------------------------------------------------------------- -- Generating testers and destructors /-- -Generate tester body (see `testerFunc`). The body assigns each argument of the -eliminator (fun _ ... _ => b), where b is true for the constructor's index and -false otherwise. This requires knowledge of the number of arguments for each -argument to the eliminator.-/ -def testerFuncBody {T : LExprParams} [Inhabited T.IDMeta] (d: LDatatype T.IDMeta) (c: LConstr T.IDMeta) (input: LExpr T.mono) (m: T.Metadata) : LExpr T.mono := - -- Number of arguments is number of constr args + number of recursive args - let numargs (c: LConstr T.IDMeta) := c.args.length + ((c.args.map Prod.snd).filter (isRecTy d)).length - let args := List.map (fun c1 => LExpr.absMultiInfer m (numargs c1) (.boolConst m (c.name.name == c1.name.name))) d.constrs +Generate tester body for mutual blocks. For mutual eliminators, we need case functions +for ALL constructors across the block, not just the constructors of one datatype. +-/ +def testerFuncBody {T : LExprParams} [Inhabited T.IDMeta] (block: MutualDatatype T.IDMeta) + (d: LDatatype T.IDMeta) (c: LConstr T.IDMeta) (input: LExpr T.mono) (m: T.Metadata) : LExpr T.mono := + -- Number of arguments for a constructor in a mutual block + let numargs (c: LConstr T.IDMeta) := c.args.length + ((c.args.map Prod.snd).filter (isRecTy block)).length + let args := block.flatMap (fun d' => d'.constrs.map (fun c1 => + LExpr.absMultiInfer m (numargs c1) (.boolConst m (c.name.name == c1.name.name)))) .mkApp m (.op m (elimFuncName d) .none) (input :: args) /-- -Generate tester function for a constructor (e.g. `List$isCons` and -`List$isNil`). The semantics of the testers are given via a body, -and they are defined in terms of eliminators. For example: -`List$isNil l := List$Elim l true (fun _ _ _ => false)` -`List$isCons l := List$Elim l false (fun _ _ _ => true)` +Generate tester function for a constructor in a mutual block. -/ -def testerFunc {T} [Inhabited T.IDMeta] (d: LDatatype T.IDMeta) (c: LConstr T.IDMeta) (m: T.Metadata) : LFunc T := +def testerFunc {T} [Inhabited T.IDMeta] (block: MutualDatatype T.IDMeta) + (d: LDatatype T.IDMeta) (c: LConstr T.IDMeta) (m: T.Metadata) : LFunc T := let arg := genArgName {name := c.testerName, typeArgs := d.typeArgs, inputs := [(arg, dataDefault d)], output := .bool, - body := testerFuncBody d c (.fvar m arg .none) m, + body := testerFuncBody block d c (.fvar m arg .none) m, attr := #["inline_if_val"] } @@ -341,7 +444,7 @@ def destructorConcreteEval {T: LExprParams} [BEq T.Identifier] (d: LDatatype T.I fun args => match args with | [x] => - (datatypeGetConstr d x).bind (fun (c1, _, a, _) => + (datatypeGetConstr d x).bind (fun (c1, _, a) => if c1.name.name == c.name.name then a[idx]? else none) | _ => none @@ -368,27 +471,77 @@ def destructorFuncs {T} [BEq T.Identifier] [Inhabited T.IDMeta] (d: LDatatype T -- Type Factories -def TypeFactory := Array (LDatatype IDMeta) +/-- A TypeFactory stores datatypes grouped by mutual recursion. -/ +def TypeFactory := Array (MutualDatatype IDMeta) instance: ToFormat (@TypeFactory IDMeta) where - format f := Std.Format.joinSep f.toList f!"{Format.line}" + format f := + let formatBlock (block : MutualDatatype IDMeta) : Format := + match block with + | [d] => format d + | ds => f!"mutual {Std.Format.joinSep (ds.map format) Format.line} end" + Std.Format.joinSep (f.toList.map formatBlock) f!"{Format.line}" + +instance [Repr IDMeta] : Repr (@TypeFactory IDMeta) where + reprPrec f n := reprPrec f.toList n instance : Inhabited (@TypeFactory IDMeta) where default := #[] def TypeFactory.default : @TypeFactory IDMeta := #[] +/-- Get all datatypes in the TypeFactory as a flat list. -/ +def TypeFactory.allDatatypes (t : @TypeFactory IDMeta) : List (LDatatype IDMeta) := + t.toList.flatten + +/-- Find a datatype by name in the TypeFactory. -/ +def TypeFactory.getType (F : @TypeFactory IDMeta) (name : String) : Option (LDatatype IDMeta) := + F.allDatatypes.find? (fun d => d.name == name) + +/-- Get all datatype names in the TypeFactory. -/ +def TypeFactory.allTypeNames (t : @TypeFactory IDMeta) : List String := + t.allDatatypes.map (·.name) + +/-- +Get all type constructor names referenced in a type. +-/ +def getTypeRefs (ty: LMonoTy) : List String := + match ty with + | .tcons n args => n :: args.flatMap getTypeRefs + | _ => [] + /-- -Generates the Factory (containing the eliminator, constructors, testers, -and destructors) for a single datatype. +Ensures all type occuring a constructor are only primitive types, +types defined previously, or types in the same mutual block. -/ -def LDatatype.genFactory {T: LExprParams} [inst: Inhabited T.Metadata] [Inhabited T.IDMeta] [ToFormat T.IDMeta] [BEq T.Identifier] (d: LDatatype T.IDMeta): Except Format (@Lambda.Factory T) := do - _ ← checkStrictPosUnif d - Factory.default.addFactory ( - elimFunc d inst.default :: - d.constrs.map (fun c => constrFunc c d) ++ - d.constrs.map (fun c => testerFunc d c inst.default) ++ - (d.constrs.map (fun c => destructorFuncs d c)).flatten).toArray +def TypeFactory.validateTypeReferences (t : @TypeFactory IDMeta) (block : MutualDatatype IDMeta) (knownTypes : List String) : Except Format Unit := do + let validNames : Std.HashSet String := + Std.HashSet.ofList (knownTypes ++ t.allTypeNames ++ block.map (·.name)) + for d in block do + for c in d.constrs do + for (_, ty) in c.args do + for ref in getTypeRefs ty do + if !validNames.contains ref then + throw f!"Error in datatype {d.name}, constructor {c.name.name}: Undefined type '{ref}'" + +/-- Add a mutual block to the TypeFactory, checking for duplicates, + inconsistent types, and positivity. -/ +def TypeFactory.addMutualBlock (t : @TypeFactory IDMeta) (block : MutualDatatype IDMeta) (knownTypes : List String := []) : Except Format (@TypeFactory IDMeta) := do + -- Check for name clashes within block + validateMutualBlock block + -- Check for positivity and uniformity + checkStrictPosUnif block + -- Check for duplicate names with existing types + for d in block do + match t.getType d.name with + | some d' => throw f!"A datatype of name {d.name} already exists! \ + Redefinitions are not allowed.\n\ + Existing Type: {d'}\n\ + New Type:{d}" + | none => pure () + -- Check for consistent type dependencies + t.validateTypeReferences block knownTypes + .ok (t.push block) /-- Constructs maps of generated functions for datatype `d`: map of @@ -405,28 +558,26 @@ def LDatatype.genFunctionMaps {T: LExprParams} [Inhabited T.IDMeta] [BEq T.Ident (destructorFuncs d c).map (fun f => (f.name.name, (d, c))))).flatten) /-- -Generates the Factory (containing all constructor and eliminator functions) for the given `TypeFactory` +Generates the Factory (containing eliminators, constructors, testers, and destructors) +for a mutual block of datatypes. -/ -def TypeFactory.genFactory {T: LExprParams} [inst: Inhabited T.Metadata] [Inhabited T.IDMeta] [ToFormat T.IDMeta] [BEq T.Identifier] (t: @TypeFactory T.IDMeta) : Except Format (@Lambda.Factory T) := - t.foldlM (fun f d => do - let f' ← d.genFactory - f.addFactory f') Factory.default - -def TypeFactory.getType (F : @TypeFactory IDMeta) (name : String) : Option (LDatatype IDMeta) := - F.find? (fun d => d.name == name) +def genBlockFactory {T: LExprParams} [inst: Inhabited T.Metadata] [Inhabited T.IDMeta] [ToFormat T.IDMeta] [BEq T.Identifier] + (block : MutualDatatype T.IDMeta) : Except Format (@Lambda.Factory T) := do + if block.isEmpty then return Factory.default + let elims := elimFuncs block inst.default + let constrs := block.flatMap (fun d => d.constrs.map (fun c => constrFunc c d)) + let testers := block.flatMap (fun d => d.constrs.map (fun c => testerFunc block d c inst.default)) + let destrs := block.flatMap (fun d => d.constrs.flatMap (fun c => destructorFuncs d c)) + Factory.default.addFactory (elims ++ constrs ++ testers ++ destrs).toArray /-- -Add an `LDatatype` to an existing `TypeFactory`, checking that no -types are duplicated. +Generates the Factory (containing all constructor and eliminator functions) for the given `TypeFactory`. -/ -def TypeFactory.addDatatype (t: @TypeFactory IDMeta) (d: LDatatype IDMeta) : Except Format (@TypeFactory IDMeta) := - -- Check that type is not redeclared - match t.getType d.name with - | none => .ok (t.push d) - | some d' => .error f!"A datatype of name {d.name} already exists! \ - Redefinitions are not allowed.\n\ - Existing Type: {d'}\n\ - New Type:{d}" +def TypeFactory.genFactory {T: LExprParams} [inst: Inhabited T.Metadata] [Inhabited T.IDMeta] [ToFormat T.IDMeta] [BEq T.Identifier] (t: @TypeFactory T.IDMeta) : Except Format (@Lambda.Factory T) := + t.foldlM (fun f block => do + let f' ← genBlockFactory block + f.addFactory f' + ) Factory.default --------------------------------------------------------------------- diff --git a/Strata/DL/SMT/Solver.lean b/Strata/DL/SMT/Solver.lean index be75bb67d..ca6f3b0b4 100644 --- a/Strata/DL/SMT/Solver.lean +++ b/Strata/DL/SMT/Solver.lean @@ -131,6 +131,20 @@ def declareDatatype (id : String) (params : List String) (constructors : List St then emitln s!"(declare-datatype {id} ({cInline}))" else emitln s!"(declare-datatype {id} (par ({pInline}) ({cInline})))" +/-- Declare multiple mutually recursive datatypes. Each element is (name, params, constructors). -/ +def declareDatatypes (dts : List (String × List String × List String)) : SolverM Unit := do + if dts.isEmpty then return + let sortDecls := dts.map fun (name, params, _) => s!"({name} {params.length})" + let sortDeclStr := String.intercalate " " sortDecls + let bodies := dts.map fun (_, params, constrs) => + let cInline := String.intercalate " " constrs + if params.isEmpty then s!"({cInline})" + else + let pInline := String.intercalate " " params + s!"(par ({pInline}) ({cInline}))" + let bodyStr := String.intercalate "\n " bodies + emitln s!"(declare-datatypes ({sortDeclStr})\n ({bodyStr}))" + private def readlnD (dflt : String) : SolverM String := do match (← read).smtLibOutput with | .some stdout => stdout.getLine diff --git a/Strata/Languages/Core/DDMTransform/Translate.lean b/Strata/Languages/Core/DDMTransform/Translate.lean index f5deb5f99..18fac1c22 100644 --- a/Strata/Languages/Core/DDMTransform/Translate.lean +++ b/Strata/Languages/Core/DDMTransform/Translate.lean @@ -251,10 +251,13 @@ partial def translateLMonoTy (bindings : TransBindings) (arg : Arg) : | .type (.syn syn) _md => let ty := syn.toLHSLMonoTy pure ty - | .type (.data ldatatype) => + | .type (.data (ldatatype :: _)) _md => -- Datatype Declaration + -- TODO: Handle mutual blocks, need to find the specific datatype by name let args := ldatatype.typeArgs.map LMonoTy.ftvar pure (.tcons ldatatype.name args) + | .type (.data []) _md => + TransM.error "Empty mutual datatype block" | _ => TransM.error s!"translateLMonoTy not yet implemented for this declaration: \ @@ -1347,7 +1350,7 @@ def translateDatatype (p : Program) (bindings : TransBindings) (op : Operation) typeArgs := typeArgs constrs := [{ name := datatypeName, args := [], testerName := "" }] constrs_ne := by simp } - let placeholderDecl := Core.Decl.type (.data placeholderLDatatype) + let placeholderDecl := Core.Decl.type (.data [placeholderLDatatype]) let bindingsWithPlaceholder := { bindings with freeVars := bindings.freeVars.push placeholderDecl } -- Extract constructor information (possibly recursive) @@ -1376,14 +1379,14 @@ def translateDatatype (p : Program) (bindings : TransBindings) (op : Operation) -- Generate factory from LDatatype and convert to Core.Decl -- (used only for bindings.freeVars, not for allDecls) - let factory ← match ldatatype.genFactory (T := CoreLParams) with + let factory ← match genBlockFactory [ldatatype] (T := CoreLParams) with | .ok f => pure f | .error e => TransM.error s!"Failed to generate datatype factory: {e}" let funcDecls : List Core.Decl := factory.toList.map fun func => Core.Decl.func func -- Only includes typeDecl, factory functions generated later - let typeDecl := Core.Decl.type (.data ldatatype) + let typeDecl := Core.Decl.type (.data [ldatatype]) let allDecls := [typeDecl] /- diff --git a/Strata/Languages/Core/Env.lean b/Strata/Languages/Core/Env.lean index 2ecb1694d..6d6d77db5 100644 --- a/Strata/Languages/Core/Env.lean +++ b/Strata/Languages/Core/Env.lean @@ -318,10 +318,13 @@ def Env.merge (cond : Expression.Expr) (E1 E2 : Env) : Env := else Env.performMerge cond E1 E2 (by simp_all) (by simp_all) -def Env.addDatatypes (E: Env) (datatypes: List (Lambda.LDatatype Visibility)) : Except Format Env := do - let f ← Lambda.TypeFactory.genFactory (T:=CoreLParams) (datatypes.toArray) +def Env.addMutualDatatype (E: Env) (block: Lambda.MutualDatatype Visibility) : Except Format Env := do + let f ← Lambda.genBlockFactory (T:=CoreLParams) block let env ← E.addFactory f - return { env with datatypes := datatypes.toArray } + return { env with datatypes := E.datatypes.push block } + +def Env.addDatatypes (E: Env) (blocks: List (Lambda.MutualDatatype Visibility)) : Except Format Env := + blocks.foldlM Env.addMutualDatatype E end Core diff --git a/Strata/Languages/Core/Program.lean b/Strata/Languages/Core/Program.lean index 67fe33ddd..8fbd1c39d 100644 --- a/Strata/Languages/Core/Program.lean +++ b/Strata/Languages/Core/Program.lean @@ -79,6 +79,12 @@ def Decl.name (d : Decl) : Expression.Ident := | .proc p _ => p.header.name | .func f _ => f.name +/-- Get all names from a declaration. For mutual datatypes, returns all datatype names. -/ +def Decl.names (d : Decl) : List Expression.Ident := + match d with + | .type t _ => t.names + | _ => [d.name] + def Decl.getVar? (d : Decl) : Option (Expression.Ident × Expression.Ty × Expression.Expr) := match d with @@ -207,7 +213,7 @@ def Program.getInit? (P: Program) (x : Expression.Ident) : Option Expression.Exp def Program.getNames (P: Program) : List Expression.Ident := go P.decls - where go decls := decls.map Decl.name + where go decls := decls.flatMap Decl.names def Program.Type.find? (P : Program) (x : Expression.Ident) : Option TypeDecl := match P.find? .type x with diff --git a/Strata/Languages/Core/ProgramType.lean b/Strata/Languages/Core/ProgramType.lean index c28b88b27..de2c0febe 100644 --- a/Strata/Languages/Core/ProgramType.lean +++ b/Strata/Languages/Core/ProgramType.lean @@ -34,9 +34,11 @@ def typeCheck (C: Core.Expression.TyContext) (Env : Core.Expression.TyEnv) (prog | decl :: drest => do let sourceLoc := Imperative.MetaData.formatFileRangeD decl.metadata (includeEnd? := true) let errorWithSourceLoc := fun e => if sourceLoc.isEmpty then e else f!"{sourceLoc} {e}" - let C := {C with idents := (← C.idents.addWithError decl.name - f!"{sourceLoc} Error in {decl.kind} {decl.name}: \ - a declaration of this name already exists.")} + -- Add all names from the declaration (multiple for mutual datatypes) + let C ← decl.names.foldlM (fun C name => do + let idents ← C.idents.addWithError name + f!"{sourceLoc} Error in {decl.kind} {name}: a declaration of this name already exists." + pure {C with idents}) C let (decl', C, Env) ← match decl with @@ -61,8 +63,8 @@ def typeCheck (C: Core.Expression.TyContext) (Env : Core.Expression.TyEnv) (prog | .syn ts => let Env ← TEnv.addTypeAlias { typeArgs := ts.typeArgs, name := ts.name, type := ts.type } C Env .ok (.type td, C, Env) - | .data d => - let C ← C.addDatatype d + | .data block => + let C ← C.addMutualBlock block .ok (.type td, C, Env) catch e => .error (errorWithSourceLoc e) diff --git a/Strata/Languages/Core/ProgramWF.lean b/Strata/Languages/Core/ProgramWF.lean index 10379b42f..2a09967fd 100644 --- a/Strata/Languages/Core/ProgramWF.lean +++ b/Strata/Languages/Core/ProgramWF.lean @@ -280,17 +280,22 @@ theorem addKnownTypeWithErrorIdents {C: Expression.TyContext}: C.addKnownTypeWit case error => intros _; contradiction case ok k'=> simp[Except.bind]; intros T'; subst T'; rfl +/- theorem addDatatypeIdents {C: Expression.TyContext}: C.addDatatype d = .ok C' → C.idents = C'.idents := by unfold LContext.addDatatype; simp only[bind, Except.bind, pure, Except.pure]; intros Hok repeat (split at Hok <;> try contradiction) cases Hok <;> rfl +-/ /-- If a program typechecks successfully, then every identifier in the list of program decls is not in the original `LContext` -/ theorem Program.typeCheckFunctionDisjoint : Program.typeCheck.go p C T decls acc = .ok (d', T') → (∀ x, x ∈ Program.getNames.go decls → ¬ C.idents.contains x) := by + -- TODO: This proof needs to be updated to handle mutual datatypes (multiple names per decl) + sorry + /- induction decls generalizing acc p d' T' T C with | nil => simp[Program.getNames.go] | cons r rs IH => @@ -329,12 +334,16 @@ theorem Program.typeCheckFunctionDisjoint : Program.typeCheck.go p C T decls acc rename_i hmatch2; split at hmatch2 <;> try grind simp only [LContext.addFactoryFunction] at hmatch2; grind done + -/ /-- If a program typechecks succesfully, all identifiers defined in the program are unique. -/ theorem Program.typeCheckFunctionNoDup : Program.typeCheck.go p C T decls acc = .ok (d', T') → (Program.getNames.go decls).Nodup := by + -- TODO: This proof needs to be updated to handle mutual datatypes (multiple names per decl) + sorry + /- induction decls generalizing acc p C T with | nil => simp[Program.getNames.go] | cons r rs IH => @@ -376,6 +385,7 @@ theorem Program.typeCheckFunctionNoDup : Program.typeCheck.go p C T decls acc = rename_i hmatch2; split at hmatch2 <;> try grind simp only [LContext.addFactoryFunction] at hmatch2; grind done + -/ /-- The main lemma stating that a program 'p' that passes type checking is well formed diff --git a/Strata/Languages/Core/SMTEncoder.lean b/Strata/Languages/Core/SMTEncoder.lean index 7819858bb..11c02f53b 100644 --- a/Strata/Languages/Core/SMTEncoder.lean +++ b/Strata/Languages/Core/SMTEncoder.lean @@ -35,9 +35,12 @@ structure SMT.Context where ifs : Array SMT.IF := #[] axms : Array Term := #[] tySubst: Map String TermType := [] - datatypes : Array (LDatatype CoreLParams.IDMeta) := #[] + /-- Stores the TypeFactory purely for ordering datatype declarations + correctly (TypeFactory in topological order) -/ + typeFactory : @Lambda.TypeFactory CoreLParams.IDMeta := #[] + seenDatatypes : Std.HashSet String := {} datatypeFuns : Map String (Op.DatatypeFuncs × LConstr CoreLParams.IDMeta) := Map.empty -deriving Repr, DecidableEq, Inhabited +deriving Repr, Inhabited def SMT.Context.default : SMT.Context := {} @@ -65,7 +68,7 @@ def SMT.Context.removeSubst (ctx : SMT.Context) (newSubst: Map String TermType) { ctx with tySubst := newSubst.foldl (fun acc_m p => acc_m.erase p.fst) ctx.tySubst } def SMT.Context.hasDatatype (ctx : SMT.Context) (name : String) : Bool := - (ctx.datatypes.map LDatatype.name).contains name + ctx.seenDatatypes.contains name def SMT.Context.addDatatype (ctx : SMT.Context) (d : LDatatype CoreLParams.IDMeta) : SMT.Context := if ctx.hasDatatype d.name then ctx @@ -74,7 +77,10 @@ def SMT.Context.addDatatype (ctx : SMT.Context) (d : LDatatype CoreLParams.IDMet let m := Map.union ctx.datatypeFuns (c.fmap (fun (_, x) => (.constructor, x))) let m := Map.union m (i.fmap (fun (_, x) => (.tester, x))) let m := Map.union m (s.fmap (fun (_, x) => (.selector, x))) - { ctx with datatypes := ctx.datatypes.push d, datatypeFuns := m } + { ctx with seenDatatypes := ctx.seenDatatypes.insert d.name, datatypeFuns := m } + +def SMT.Context.withTypeFactory (ctx : SMT.Context) (tf : @Lambda.TypeFactory CoreLParams.IDMeta) : SMT.Context := + { ctx with typeFactory := tf } /-- Helper function to convert LMonoTy to SMT string representation. @@ -93,89 +99,32 @@ private def lMonoTyToSMTString (ty : LMonoTy) : String := else s!"({name} {String.intercalate " " (args.map lMonoTyToSMTString)})" | .ftvar tv => tv -/-- -Build a dependency graph for datatypes. -Returns a mapping from datatype names to their dependencies. --/ -private def buildDatatypeDependencyGraph (datatypes : Array (LDatatype CoreLParams.IDMeta)) : - Map String (Array String) := - let depMap := datatypes.foldl (fun acc d => - let deps := d.constrs.foldl (fun deps c => - c.args.foldl (fun deps (_, fieldTy) => - match fieldTy with - | .tcons typeName _ => - -- Only include dependencies on other datatypes in our set - if datatypes.any (fun dt => dt.name == typeName) then - deps.push typeName - else deps - | _ => deps - ) deps - ) #[] - acc.insert d.name deps - ) Map.empty - depMap - -/-- -Convert datatype dependency map to OutGraph for Tarjan's algorithm. -Returns the graph and a mapping from node indices to datatype names. --/ -private def dependencyMapToGraph (depMap : Map String (Array String)) : - (n : Nat) × Strata.OutGraph n × Array String := - let names := depMap.keys.toArray - let n := names.size - let nameToIndex : Map String Nat := - names.mapIdx (fun i name => (name, i)) |>.foldl (fun acc (name, i) => acc.insert name i) Map.empty - - let edges := depMap.foldl (fun edges (fromName, deps) => - match nameToIndex.find? fromName with - | none => edges - | some fromIdx => - deps.foldl (fun edges depName => - match nameToIndex.find? depName with - | none => edges - | some toIdx => edges.push (fromIdx, toIdx) - ) edges - ) #[] - - let graph := Strata.OutGraph.ofEdges! n edges.toList - ⟨n, graph, names⟩ +/-- Convert a datatype's constructors to SMT format. -/ +private def datatypeConstructorsToSMT (d : LDatatype CoreLParams.IDMeta) : List String := + d.constrs.map fun c => + let fieldPairs := c.args.map fun (name, fieldTy) => (name.name, lMonoTyToSMTString fieldTy) + let fieldStrs := fieldPairs.map fun (name, ty) => s!"({name} {ty})" + let fieldsStr := String.intercalate " " fieldStrs + if c.args.isEmpty then s!"({c.name.name})" + else s!"({c.name.name} {fieldsStr})" /-- -Emit datatype declarations to the solver in topologically sorted order. -For each datatype in ctx.datatypes, generates a declare-datatype command -with constructors and selectors following the TypeFactory naming convention. -Dependencies are emitted before the datatypes that depend on them, and -mutually recursive datatypes are not (yet) supported. +Emit datatype declarations to the solver. +Uses the TypeFactory ordering (already topologically sorted). +Only emits datatypes that have been seen (added via addDatatype). +Single-element blocks use declare-datatype, multi-element blocks use declare-datatypes. -/ def SMT.Context.emitDatatypes (ctx : SMT.Context) : Strata.SMT.SolverM Unit := do - if ctx.datatypes.isEmpty then return - - -- Build dependency graph and SCCs - let depMap := buildDatatypeDependencyGraph ctx.datatypes - let ⟨_, graph, names⟩ := dependencyMapToGraph depMap - let sccs := Strata.OutGraph.tarjan graph - - -- Emit datatypes in topological order (reverse of SCC order) - for scc in sccs.reverse do - if scc.size > 1 then - let sccNames := scc.map (fun idx => names[idx]!) - throw (IO.userError s!"Mutually recursive datatypes not supported: {sccNames.toList}") - else - for nodeIdx in scc do - let datatypeName := names[nodeIdx]! - -- Find the datatype by name - match ctx.datatypes.find? (fun d => d.name == datatypeName) with - | none => throw (IO.userError s!"Datatype {datatypeName} not found in context") - | some d => - let constructors ← d.constrs.mapM fun c => do - let fieldPairs := c.args.map fun (name, fieldTy) => (name.name, lMonoTyToSMTString fieldTy) - let fieldStrs := fieldPairs.map fun (name, ty) => s!"({name} {ty})" - let fieldsStr := String.intercalate " " fieldStrs - if c.args.isEmpty then - pure s!"({c.name.name})" - else - pure s!"({c.name.name} {fieldsStr})" - Strata.SMT.Solver.declareDatatype d.name d.typeArgs constructors + for block in ctx.typeFactory.toList do + let usedBlock := block.filter (fun d => ctx.seenDatatypes.contains d.name) + match usedBlock with + | [] => pure () + | [d] => + let constructors := datatypeConstructorsToSMT d + Strata.SMT.Solver.declareDatatype d.name d.typeArgs constructors + | _ => + let dts := usedBlock.map fun d => (d.name, d.typeArgs, datatypeConstructorsToSMT d) + Strata.SMT.Solver.declareDatatypes dts abbrev BoundVars := List (String × TermType) @@ -608,6 +557,7 @@ end def toSMTTerms (E : Env) (es : List (LExpr CoreLParams.mono)) (ctx : SMT.Context) : Except Format ((List Term) × SMT.Context) := do + let ctx := if ctx.typeFactory.isEmpty then ctx.withTypeFactory E.datatypes else ctx match es with | [] => .ok ([], ctx) | e :: erest => @@ -725,7 +675,7 @@ info: "; f\n(declare-fun f0 (Int Int) Int)\n; x\n(declare-const f1 Int)\n(define #eval toSMTTermString (.quant () .all (.some .int) (.bvar () 0) (.quant () .all (.some .int) (.app () (.app () (.op () "f" (.some (.arrow .int (.arrow .int .int)))) (.bvar () 0)) (.bvar () 1)) (.eq () (.app () (.app () (.op () "f" (.some (.arrow .int (.arrow .int .int)))) (.bvar () 0)) (.bvar () 1)) (.fvar () "x" (.some .int))))) - (ctx := SMT.Context.mk #[] #[UF.mk "f" ((TermVar.mk "m" TermType.int) ::(TermVar.mk "n" TermType.int) :: []) TermType.int] #[] #[] [] #[] []) + (ctx := SMT.Context.mk #[] #[UF.mk "f" ((TermVar.mk "m" TermType.int) ::(TermVar.mk "n" TermType.int) :: []) TermType.int] #[] #[] [] #[] {} []) (E := {Env.init with exprEnv := { Env.init.exprEnv with config := { Env.init.exprEnv.config with @@ -743,7 +693,7 @@ info: "; f\n(declare-fun f0 (Int Int) Int)\n; x\n(declare-const f1 Int)\n(define #eval toSMTTermString (.quant () .all (.some .int) (.bvar () 0) (.quant () .all (.some .int) (.bvar () 0) (.eq () (.app () (.app () (.op () "f" (.some (.arrow .int (.arrow .int .int)))) (.bvar () 0)) (.bvar () 1)) (.fvar () "x" (.some .int))))) - (ctx := SMT.Context.mk #[] #[UF.mk "f" ((TermVar.mk "m" TermType.int) ::(TermVar.mk "n" TermType.int) :: []) TermType.int] #[] #[] [] #[] []) + (ctx := SMT.Context.mk #[] #[UF.mk "f" ((TermVar.mk "m" TermType.int) ::(TermVar.mk "n" TermType.int) :: []) TermType.int] #[] #[] [] #[] {} []) (E := {Env.init with exprEnv := { Env.init.exprEnv with config := { Env.init.exprEnv.config with diff --git a/Strata/Languages/Core/TypeDecl.lean b/Strata/Languages/Core/TypeDecl.lean index 4777445bb..5a45f8286 100644 --- a/Strata/Languages/Core/TypeDecl.lean +++ b/Strata/Languages/Core/TypeDecl.lean @@ -86,7 +86,7 @@ def TypeSynonym.toRHSLTy (t : TypeSynonym) : LTy := inductive TypeDecl where | con : TypeConstructor → TypeDecl | syn : TypeSynonym → TypeDecl - | data : LDatatype Visibility → TypeDecl + | data : List (LDatatype Visibility) → TypeDecl deriving Repr instance : ToFormat TypeDecl where @@ -94,12 +94,23 @@ instance : ToFormat TypeDecl where match d with | .con tc => f!"{tc}" | .syn ts => f!"{ts}" - | .data td => f!"{td}" + | .data [] => f!"" + | .data [td] => f!"{td}" + | .data tds => f!"mutual {Std.Format.joinSep (tds.map format) Format.line} end" +/-- Get all names from a TypeDecl. -/ +def TypeDecl.names (d : TypeDecl) : List Expression.Ident := + match d with + | .con tc => [tc.name] + | .syn ts => [ts.name] + | .data tds => tds.map (·.name) + +/-- Get the primary name of a TypeDecl (first name for mutual blocks). -/ def TypeDecl.name (d : TypeDecl) : Expression.Ident := match d with | .con tc => tc.name | .syn ts => ts.name - | .data td => td.name + | .data [] => "" + | .data (td :: _) => td.name --------------------------------------------------------------------- diff --git a/StrataTest/DL/Lambda/TypeFactoryTests.lean b/StrataTest/DL/Lambda/TypeFactoryTests.lean index 0a4eba152..351b7d164 100644 --- a/StrataTest/DL/Lambda/TypeFactoryTests.lean +++ b/StrataTest/DL/Lambda/TypeFactoryTests.lean @@ -57,7 +57,7 @@ info: #3 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[weekTy] (Factory.default : @Factory TestParams) ((LExpr.op () ("Day$Elim" : TestParams.Identifier) .none).mkApp () (.op () ("W" : TestParams.Identifier) (.some (.tcons "Day" [])) :: (List.range 7).map (intConst () ∘ Int.ofNat))) + typeCheckAndPartialEval #[[weekTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("Day$Elim" : TestParams.Identifier) .none).mkApp () (.op () ("W" : TestParams.Identifier) (.some (.tcons "Day" [])) :: (List.range 7).map (intConst () ∘ Int.ofNat))) /-- @@ -69,7 +69,7 @@ info: #true -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[weekTy] (Factory.default : @Factory TestParams) + typeCheckAndPartialEval #[[weekTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("Day$isW" : TestParams.Identifier) .none).mkApp () [.op () ("W" : TestParams.Identifier) (.some (.tcons "Day" []))]) /-- @@ -81,7 +81,7 @@ info: #false -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[weekTy] (Factory.default : @Factory TestParams) + typeCheckAndPartialEval #[[weekTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("Day$isW" : TestParams.Identifier) .none).mkApp () [.op () ("M" : TestParams.Identifier) (.some (.tcons "Day" []))]) @@ -116,7 +116,7 @@ info: #3 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[tupTy] Factory.default (fst (prod (intConst () 3) (strConst () "a"))) + typeCheckAndPartialEval #[[tupTy]] Factory.default (fst (prod (intConst () 3) (strConst () "a"))) /-- info: Annotated expression: @@ -127,7 +127,7 @@ info: #a -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[tupTy] Factory.default (snd (prod (intConst () 3) (strConst () "a"))) + typeCheckAndPartialEval #[[tupTy]] Factory.default (snd (prod (intConst () 3) (strConst () "a"))) /-- @@ -139,7 +139,7 @@ info: #1 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[tupTy] Factory.default (fst (snd (prod (strConst () "a") (prod (intConst () 1) (strConst () "b"))))) + typeCheckAndPartialEval #[[tupTy]] Factory.default (fst (snd (prod (strConst () "a") (prod (intConst () 1) (strConst () "b"))))) -- Test 3: Polymorphic Lists @@ -171,7 +171,7 @@ info: #1 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (Factory.default : @Factory TestParams) ((LExpr.op () ("List$Elim" : TestParams.Identifier) .none).mkApp () [nil, (intConst () 1), .abs () .none (.abs () .none (.abs () .none (intConst () 1)))]) + typeCheckAndPartialEval #[[listTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("List$Elim" : TestParams.Identifier) .none).mkApp () [nil, (intConst () 1), .abs () .none (.abs () .none (.abs () .none (intConst () 1)))]) -- Test: elim(cons 1 nil, 0, fun x y => x) -> (fun x y => x) 1 nil @@ -185,7 +185,7 @@ info: #2 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (Factory.default : @Factory TestParams) ((LExpr.op () ("List$Elim" : TestParams.Identifier) .none).mkApp () [listExpr [intConst () 2], intConst () 0, .abs () .none (.abs () .none (.abs () .none (bvar () 2)))]) + typeCheckAndPartialEval #[[listTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("List$Elim" : TestParams.Identifier) .none).mkApp () [listExpr [intConst () 2], intConst () 0, .abs () .none (.abs () .none (.abs () .none (bvar () 2)))]) -- Test testers (isNil and isCons) @@ -197,7 +197,7 @@ info: #true -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (Factory.default : @Factory TestParams) + typeCheckAndPartialEval #[[listTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("isNil" : TestParams.Identifier) .none).mkApp () [nil]) /-- info: Annotated expression: @@ -208,7 +208,7 @@ info: #false -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (Factory.default : @Factory TestParams) + typeCheckAndPartialEval #[[listTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("isNil" : TestParams.Identifier) .none).mkApp () [cons (intConst () 1) nil]) /-- info: Annotated expression: @@ -219,7 +219,7 @@ info: #false -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (Factory.default : @Factory TestParams) + typeCheckAndPartialEval #[[listTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("isCons" : TestParams.Identifier) .none).mkApp () [nil]) /-- info: Annotated expression: @@ -230,7 +230,7 @@ info: #true -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (Factory.default : @Factory TestParams) + typeCheckAndPartialEval #[[listTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("isCons" : TestParams.Identifier) .none).mkApp () [cons (intConst () 1) nil]) -- But a non-value should NOT reduce @@ -247,7 +247,7 @@ info: ((~isCons : (arrow (List int) bool)) (~l : (List int))) #guard_msgs in #eval format $ do let f ← ((Factory.default : @Factory TestParams).addFactoryFunc ex_list) - (typeCheckAndPartialEval (T:=TestParams) #[listTy] f + (typeCheckAndPartialEval (T:=TestParams) #[[listTy]] f ((LExpr.op () ("isCons" : TestParams.Identifier) (some (LMonoTy.arrow (.tcons "List" [.int]) .bool))).mkApp () [.op () "l" .none])) -- Test destructors @@ -261,7 +261,7 @@ info: #1 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (Factory.default : @Factory TestParams) + typeCheckAndPartialEval #[[listTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("hd" : TestParams.Identifier) .none).mkApp () [cons (intConst () 1) nil]) /-- @@ -272,7 +272,7 @@ info: (((~Cons : (arrow int (arrow (List int) (List int)))) #2) (~Nil : (List in -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (Factory.default : @Factory TestParams) + typeCheckAndPartialEval #[[listTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("tl" : TestParams.Identifier) .none).mkApp () [cons (intConst () 1) (cons (intConst () 2) nil)]) -- Destructor does not evaluate on a different constructor @@ -284,7 +284,7 @@ info: Annotated expression: ((~tl : (arrow (List $__ty1) (List $__ty1))) (~Nil : info: ((~tl : (arrow (List $__ty1) (List $__ty1))) (~Nil : (List $__ty1)))-/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (Factory.default : @Factory TestParams) + typeCheckAndPartialEval #[[listTy]] (Factory.default : @Factory TestParams) ((LExpr.op () ("tl" : TestParams.Identifier) .none).mkApp () [nil]) @@ -308,7 +308,7 @@ info: #7 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy, tupTy] (IntBoolFactory : @Factory TestParams) + typeCheckAndPartialEval #[[listTy], [tupTy]] (IntBoolFactory : @Factory TestParams) ((LExpr.op () ("List$Elim" : TestParams.Identifier) .none).mkApp () [listExpr [(prod (intConst () 3) (strConst () "a")), (prod (intConst () 4) (strConst () "b"))], intConst () 0, @@ -332,7 +332,7 @@ info: #3 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (IntBoolFactory : @Factory TestParams) (length (listExpr [strConst () "a", strConst () "b", strConst () "c"])) + typeCheckAndPartialEval #[[listTy]] (IntBoolFactory : @Factory TestParams) (length (listExpr [strConst () "a", strConst () "b", strConst () "c"])) /-- info: Annotated expression: @@ -343,7 +343,7 @@ info: #15 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (IntBoolFactory : @Factory TestParams) (length (listExpr ((List.range 15).map (intConst () ∘ Int.ofNat)))) + typeCheckAndPartialEval #[[listTy]] (IntBoolFactory : @Factory TestParams) (length (listExpr ((List.range 15).map (intConst () ∘ Int.ofNat)))) /- Append is trickier since it takes in two arguments, so the eliminator returns @@ -367,7 +367,7 @@ info: (((~Cons : (arrow int (arrow (List int) (List int)))) #2) (((~Cons : (arro -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy] (IntBoolFactory : @Factory TestParams) (append list1 list2) + typeCheckAndPartialEval #[[listTy]] (IntBoolFactory : @Factory TestParams) (append list1 list2) -- 2. Preorder traversal of binary tree @@ -420,7 +420,7 @@ info: (((~Cons : (arrow int (arrow (List int) (List int)))) #1) (((~Cons : (arro -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[listTy, binTreeTy] (IntBoolFactory : @Factory TestParams) (toList tree1) + typeCheckAndPartialEval #[[listTy], [binTreeTy]] (IntBoolFactory : @Factory TestParams) (toList tree1) -- 3. Infinite-ary trees namespace Tree @@ -463,7 +463,7 @@ info: #3 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[treeTy] (IntBoolFactory : @Factory TestParams) (height 0 tree1) + typeCheckAndPartialEval #[[treeTy]] (IntBoolFactory : @Factory TestParams) (height 0 tree1) /--info: Annotated expression: ((((~tree$Elim : (arrow (tree int) (arrow (arrow int int) (arrow (arrow (arrow int (tree int)) (arrow (arrow int int) int)) int)))) ((~Node : (arrow (arrow int (tree int)) (tree int))) (λ ((~Node : (arrow (arrow int (tree int)) (tree int))) (λ (if ((((~Int.Add : (arrow int (arrow int int))) %1) %0) == #0) then ((~Node : (arrow (arrow int (tree int)) (tree int))) (λ ((~Leaf : (arrow int (tree int))) #3))) else ((~Leaf : (arrow int (tree int))) #4))))))) (λ #0)) (λ (λ (((~Int.Add : (arrow int (arrow int int))) #1) (%0 #1))))) @@ -473,7 +473,7 @@ info: #2 -/ #guard_msgs in #eval format $ - typeCheckAndPartialEval #[treeTy] (IntBoolFactory : @Factory TestParams) (height 1 tree1) + typeCheckAndPartialEval #[[treeTy]] (IntBoolFactory : @Factory TestParams) (height 1 tree1) end Tree @@ -490,7 +490,7 @@ def badTy1 : LDatatype Unit := {name := "Bad", typeArgs := [], constrs := [badCo /-- info: Error in constructor C: Non-strictly positive occurrence of Bad in type (arrow Bad Bad) -/ #guard_msgs in -#eval format $ typeCheckAndPartialEval #[badTy1] (IntBoolFactory : @Factory TestParams) (intConst () 0) +#eval format $ typeCheckAndPartialEval #[[badTy1]] (IntBoolFactory : @Factory TestParams) (intConst () 0) /- 2.Non-strictly positive type @@ -502,7 +502,7 @@ def badTy2 : LDatatype Unit := {name := "Bad", typeArgs := ["a"], constrs := [ba /-- info: Error in constructor C: Non-strictly positive occurrence of Bad in type (arrow (arrow (Bad a) int) int)-/ #guard_msgs in -#eval format $ typeCheckAndPartialEval #[badTy2] (IntBoolFactory : @Factory TestParams) (intConst () 0) +#eval format $ typeCheckAndPartialEval #[[badTy2]] (IntBoolFactory : @Factory TestParams) (intConst () 0) /- 3. Non-strictly positive type 2 @@ -514,7 +514,7 @@ def badTy3 : LDatatype Unit := {name := "Bad", typeArgs := ["a"], constrs := [ba /--info: Error in constructor C: Non-strictly positive occurrence of Bad in type (arrow (Bad a) int)-/ #guard_msgs in -#eval format $ typeCheckAndPartialEval #[badTy3] (IntBoolFactory : @Factory TestParams) (intConst () 0) +#eval format $ typeCheckAndPartialEval #[[badTy3]] (IntBoolFactory : @Factory TestParams) (intConst () 0) /- 4. Strictly positive type @@ -531,7 +531,7 @@ def goodTy1 : LDatatype Unit := {name := "Good", typeArgs := ["a"], constrs := [ info: #0 -/ #guard_msgs in -#eval format $ typeCheckAndPartialEval #[goodTy1] (IntBoolFactory : @Factory TestParams) (intConst () 0) +#eval format $ typeCheckAndPartialEval #[[goodTy1]] (IntBoolFactory : @Factory TestParams) (intConst () 0) /- 5. Non-uniform type @@ -542,7 +542,7 @@ def nonUnifTy1 : LDatatype Unit := {name := "Nonunif", typeArgs := ["a"], constr /-- info: Error in constructor C: Non-uniform occurrence of Nonunif, which is applied to [(List a)] when it should be applied to [a]-/ #guard_msgs in -#eval format $ typeCheckAndPartialEval #[listTy, nonUnifTy1] (IntBoolFactory : @Factory TestParams) (intConst () 0) +#eval format $ typeCheckAndPartialEval #[[listTy], [nonUnifTy1]] (IntBoolFactory : @Factory TestParams) (intConst () 0) /- 6. Nested types are allowed, though they won't produce a useful elimination principle @@ -558,7 +558,7 @@ def nestTy1 : LDatatype Unit := {name := "Nest", typeArgs := ["a"], constrs := [ info: #0 -/ #guard_msgs in -#eval format $ typeCheckAndPartialEval #[listTy, nestTy1] (IntBoolFactory : @Factory TestParams) (intConst () 0) +#eval format $ typeCheckAndPartialEval #[[listTy], [nestTy1]] (IntBoolFactory : @Factory TestParams) (intConst () 0) /- 7. 2 constructors with the same name: @@ -575,7 +575,7 @@ Existing Function: func C : ∀[a]. ((x : int)) → (Bad a); New Function:func C : ∀[a]. ((x : (Bad a))) → (Bad a); -/ #guard_msgs in -#eval format $ typeCheckAndPartialEval #[badTy4] (IntBoolFactory : @Factory TestParams) (intConst () 0) +#eval format $ typeCheckAndPartialEval #[[badTy4]] (IntBoolFactory : @Factory TestParams) (intConst () 0) /- 8. Constructor with same name as function not allowed @@ -588,6 +588,211 @@ def badTy5 : LDatatype Unit := {name := "Bad", typeArgs := [], constrs := [badCo Existing Function: func Int.Add : ((x : int) (y : int)) → int; New Function:func Int.Add : ((x : int)) → Bad;-/ #guard_msgs in -#eval format $ typeCheckAndPartialEval #[badTy5] (IntBoolFactory : @Factory TestParams) (intConst () 0) +#eval format $ typeCheckAndPartialEval #[[badTy5]] (IntBoolFactory : @Factory TestParams) (intConst () 0) + +--------------------------------------------------------------------- +-- Test 9: Mutually recursive datatypes (RoseTree and Forest) +--------------------------------------------------------------------- + +section MutualRecursion + +/- +type RoseTree a = Node a (Forest a) +type Forest a = FNil | FCons (RoseTree a) (Forest a) +-/ + +def nodeConstr' : LConstr Unit := {name := "Node", args := [("val", .ftvar "a"), ("children", .tcons "Forest" [.ftvar "a"])], testerName := "isNode"} +def roseTreeTy' : LDatatype Unit := {name := "RoseTree", typeArgs := ["a"], constrs := [nodeConstr'], constrs_ne := rfl} + +def fnilConstr' : LConstr Unit := {name := "FNil", args := [], testerName := "isFNil"} +def fconsConstr' : LConstr Unit := {name := "FCons", args := [("head", .tcons "RoseTree" [.ftvar "a"]), ("tail", .tcons "Forest" [.ftvar "a"])], testerName := "isFCons"} +def forestTy' : LDatatype Unit := {name := "Forest", typeArgs := ["a"], constrs := [fnilConstr', fconsConstr'], constrs_ne := rfl} + +def roseForestBlock : MutualDatatype Unit := [roseTreeTy', forestTy'] + +-- Syntactic sugar +def node' (v children : LExpr TestParams.mono) : LExpr TestParams.mono := + (LExpr.op () ("Node" : TestParams.Identifier) .none).mkApp () [v, children] +def fnil' : LExpr TestParams.mono := .op () ("FNil" : TestParams.Identifier) .none +def fcons' (hd tl : LExpr TestParams.mono) : LExpr TestParams.mono := + (LExpr.op () ("FCons" : TestParams.Identifier) .none).mkApp () [hd, tl] + +-- Test testers +/-- info: Annotated expression: +((~isNode : (arrow (RoseTree int) bool)) (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #1) (~FNil : (Forest int)))) + +--- +info: #true +-/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[roseForestBlock] (Factory.default : @Factory TestParams) + ((LExpr.op () ("isNode" : TestParams.Identifier) .none).mkApp () [node' (intConst () 1) fnil']) + +/-- info: Annotated expression: +((~isFNil : (arrow (Forest $__ty17) bool)) (~FNil : (Forest $__ty17))) + +--- +info: #true +-/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[roseForestBlock] (Factory.default : @Factory TestParams) + ((LExpr.op () ("isFNil" : TestParams.Identifier) .none).mkApp () [fnil']) + +/-- info: Annotated expression: +((~isFCons : (arrow (Forest int) bool)) (((~FCons : (arrow (RoseTree int) (arrow (Forest int) (Forest int)))) (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #1) (~FNil : (Forest int)))) (~FNil : (Forest int)))) + +--- +info: #true +-/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[roseForestBlock] (Factory.default : @Factory TestParams) + ((LExpr.op () ("isFCons" : TestParams.Identifier) .none).mkApp () [fcons' (node' (intConst () 1) fnil') fnil']) + +-- Test destructors +/-- info: Annotated expression: +((~val : (arrow (RoseTree int) int)) (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #42) (~FNil : (Forest int)))) + +--- +info: #42 +-/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[roseForestBlock] (Factory.default : @Factory TestParams) + ((LExpr.op () ("val" : TestParams.Identifier) .none).mkApp () [node' (intConst () 42) fnil']) + +/-- info: Annotated expression: +((~head : (arrow (Forest int) (RoseTree int))) (((~FCons : (arrow (RoseTree int) (arrow (Forest int) (Forest int)))) (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #7) (~FNil : (Forest int)))) (~FNil : (Forest int)))) + +--- +info: (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #7) (~FNil : (Forest int))) +-/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[roseForestBlock] (Factory.default : @Factory TestParams) + ((LExpr.op () ("head" : TestParams.Identifier) .none).mkApp () [fcons' (node' (intConst () 7) fnil') fnil']) + +--------------------------------------------------------------------- +-- Test 10: Eliminator on mutually recursive types - computing tree size +--------------------------------------------------------------------- + +/- +A non-trivial rose tree: + 1 + /|\ + 2 3 4 + | + 5 +treeSize = 5 +-/ + +def nodeCaseFn' : LExpr TestParams.mono := + .abs () .none (.abs () .none (.abs () .none (addOp (intConst () 1) (.bvar () 0)))) + +def fnilCaseFn' : LExpr TestParams.mono := intConst () 0 + +def fconsCaseFn' : LExpr TestParams.mono := + .abs () .none (.abs () .none (.abs () .none (.abs () .none (addOp (.bvar () 1) (.bvar () 0))))) + +def treeSize' (t : LExpr TestParams.mono) : LExpr TestParams.mono := + (LExpr.op () ("RoseTree$Elim" : TestParams.Identifier) .none).mkApp () [t, nodeCaseFn', fnilCaseFn', fconsCaseFn'] + +def roseTree5 : LExpr TestParams.mono := + node' (intConst () 1) + (fcons' (node' (intConst () 2) fnil') + (fcons' (node' (intConst () 3) (fcons' (node' (intConst () 5) fnil') fnil')) + (fcons' (node' (intConst () 4) fnil') fnil'))) + +-- treeSize (Node 1 FNil) = 1 +/-- info: Annotated expression: +(((((~RoseTree$Elim : (arrow (RoseTree int) (arrow (arrow int (arrow (Forest int) (arrow int int))) (arrow int (arrow (arrow (RoseTree int) (arrow (Forest int) (arrow int (arrow int int)))) int))))) (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #1) (~FNil : (Forest int)))) (λ (λ (λ (((~Int.Add : (arrow int (arrow int int))) #1) %0))))) #0) (λ (λ (λ (λ (((~Int.Add : (arrow int (arrow int int))) %1) %0)))))) + +--- +info: #1 +-/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[roseForestBlock] (IntBoolFactory : @Factory TestParams) + (treeSize' (node' (intConst () 1) fnil')) + +-- treeSize roseTree5 = 5 +/-- info: Annotated expression: +(((((~RoseTree$Elim : (arrow (RoseTree int) (arrow (arrow int (arrow (Forest int) (arrow int int))) (arrow int (arrow (arrow (RoseTree int) (arrow (Forest int) (arrow int (arrow int int)))) int))))) (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #1) (((~FCons : (arrow (RoseTree int) (arrow (Forest int) (Forest int)))) (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #2) (~FNil : (Forest int)))) (((~FCons : (arrow (RoseTree int) (arrow (Forest int) (Forest int)))) (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #3) (((~FCons : (arrow (RoseTree int) (arrow (Forest int) (Forest int)))) (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #5) (~FNil : (Forest int)))) (~FNil : (Forest int))))) (((~FCons : (arrow (RoseTree int) (arrow (Forest int) (Forest int)))) (((~Node : (arrow int (arrow (Forest int) (RoseTree int)))) #4) (~FNil : (Forest int)))) (~FNil : (Forest int))))))) (λ (λ (λ (((~Int.Add : (arrow int (arrow int int))) #1) %0))))) #0) (λ (λ (λ (λ (((~Int.Add : (arrow int (arrow int int))) %1) %0)))))) + +--- +info: #5 +-/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[roseForestBlock] (IntBoolFactory : @Factory TestParams) + (treeSize' roseTree5) + +--------------------------------------------------------------------- +-- Test 11: Non-strictly positive mutual types should be rejected +--------------------------------------------------------------------- + +/- +type BadA = MkA (BadB -> int) +type BadB = MkB BadA + +BadA has BadB in negative position (left of arrow), which is non-strictly positive +since BadB is in the same mutual block. +-/ + +def mkAConstr : LConstr Unit := {name := "MkA", args := [("f", .arrow (.tcons "BadB" []) .int)], testerName := "isMkA"} +def badATy : LDatatype Unit := {name := "BadA", typeArgs := [], constrs := [mkAConstr], constrs_ne := rfl} + +def mkBConstr : LConstr Unit := {name := "MkB", args := [("a", .tcons "BadA" [])], testerName := "isMkB"} +def badBTy : LDatatype Unit := {name := "BadB", typeArgs := [], constrs := [mkBConstr], constrs_ne := rfl} + +def badMutualBlock : MutualDatatype Unit := [badATy, badBTy] + +/-- info: Error in constructor MkA: Non-strictly positive occurrence of BadB in type (arrow BadB int)-/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[badMutualBlock] (Factory.default : @Factory TestParams) (intConst () 0) + +--------------------------------------------------------------------- +-- Test 12: Empty mutual block should be rejected +--------------------------------------------------------------------- + +def emptyBlock : MutualDatatype Unit := [] + +/-- info: Error: Empty mutual block is not allowed -/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[emptyBlock] (IntBoolFactory : @Factory TestParams) (intConst () 0) + +--------------------------------------------------------------------- +-- Test 13: Type reference in wrong order should be rejected +--------------------------------------------------------------------- + +-- Wrapper references List, but List is defined after Wrapper +def wrapperConstr' : LConstr Unit := {name := "MkWrapper", args := [("xs", .tcons "List" [.int])], testerName := "isMkWrapper"} +def wrapperTy' : LDatatype Unit := {name := "Wrapper", typeArgs := [], constrs := [wrapperConstr'], constrs_ne := rfl} + +/-- info: Error in datatype Wrapper, constructor MkWrapper: Undefined type 'List' -/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[[wrapperTy'], [listTy]] (IntBoolFactory : @Factory TestParams) (intConst () 0) + +--------------------------------------------------------------------- +-- Test 14: Type depending on previously defined type should work +--------------------------------------------------------------------- + +-- List is defined before Wrapper - correct order +/-- info: Annotated expression: +#0 + +--- +info: #0 +-/ +#guard_msgs in +#eval format $ + typeCheckAndPartialEval #[[listTy], [wrapperTy']] (IntBoolFactory : @Factory TestParams) (intConst () 0) + +end MutualRecursion end Lambda diff --git a/StrataTest/Languages/Core/DatatypeVerificationTests.lean b/StrataTest/Languages/Core/DatatypeVerificationTests.lean index b404b77df..6618a791b 100644 --- a/StrataTest/Languages/Core/DatatypeVerificationTests.lean +++ b/StrataTest/Languages/Core/DatatypeVerificationTests.lean @@ -98,7 +98,7 @@ def mkProgramWithDatatypes body := body } - let decls := datatypes.map (fun d => Decl.type (.data d) .empty) + let decls := datatypes.map (fun d => Decl.type (.data [d]) .empty) return { decls := decls ++ [Decl.proc proc .empty] } /-! ## Helper for Running Tests -/ @@ -614,5 +614,149 @@ info: "Test 8 - Hidden Type Recursion: PASSED\n Verified 1 obligation(s)\n" #guard_msgs in #eval test8_hiddenTypeRecursion +/-! ## Test 9: Mutually Recursive Datatypes with Havoc -/ + +/-- RoseTree a = Node a (Forest a) -/ +def roseTreeDatatype : LDatatype Visibility := + { name := "RoseTree" + typeArgs := ["a"] + constrs := [ + { name := ⟨"Node", .unres⟩, args := [ + (⟨"nodeVal", .unres⟩, .ftvar "a"), + (⟨"children", .unres⟩, .tcons "Forest" [.ftvar "a"]) + ], testerName := "isNode" } + ] + constrs_ne := by decide } + +/-- Forest a = FNil | FCons (RoseTree a) (Forest a) -/ +def forestDatatype : LDatatype Visibility := + { name := "Forest" + typeArgs := ["a"] + constrs := [ + { name := ⟨"FNil", .unres⟩, args := [], testerName := "isFNil" }, + { name := ⟨"FCons", .unres⟩, args := [ + (⟨"head", .unres⟩, .tcons "RoseTree" [.ftvar "a"]), + (⟨"tail", .unres⟩, .tcons "Forest" [.ftvar "a"]) + ], testerName := "isFCons" } + ] + constrs_ne := by decide } + +/-- +Create a Core program with a mutual block of datatypes. +-/ +def mkProgramWithMutualDatatypes + (mutualBlock : List (LDatatype Visibility)) + (procName : String) + (body : List Statement) + : Except Format Program := do + let proc : Procedure := { + header := { + name := CoreIdent.unres procName + typeArgs := [] + inputs := [] + outputs := [] + } + spec := { + modifies := [] + preconditions := [] + postconditions := [] + } + body := body + } + let decls := [Decl.type (.data mutualBlock) .empty] + return { decls := decls ++ [Decl.proc proc .empty] } + +/-- +Test mutually recursive datatypes (RoseTree/Forest) with havoc. + +mutual + datatype RoseTree a = Node a (Forest a) + datatype Forest a = FNil | FCons (RoseTree a) (Forest a) +end + +procedure testMutualRecursive () { + tree := Node 1 FNil; + havoc tree; + val := nodeVal tree; + assume (val == 42); + assert (isNode tree); + + forest := FNil; + havoc forest; + assume (isFCons forest); + assert (not (isFNil forest)); +} +-/ +def test9_mutualRecursiveWithHavoc : IO String := do + let statements : List Statement := [ + -- Create a tree: Node 1 FNil + Statement.init (CoreIdent.unres "tree") (.forAll [] (LMonoTy.tcons "RoseTree" [.int])) + (LExpr.app () + (LExpr.app () + (LExpr.op () (CoreIdent.unres "Node") + (.some (LMonoTy.arrow .int (LMonoTy.arrow (LMonoTy.tcons "Forest" [.int]) (LMonoTy.tcons "RoseTree" [.int]))))) + (LExpr.intConst () 1)) + (LExpr.op () (CoreIdent.unres "FNil") (.some (LMonoTy.tcons "Forest" [.int])))), + + -- Havoc the tree + Statement.havoc (CoreIdent.unres "tree"), + + -- Extract nodeVal + Statement.init (CoreIdent.unres "val") (.forAll [] LMonoTy.int) + (LExpr.app () + (LExpr.op () (CoreIdent.unres "nodeVal") + (.some (LMonoTy.arrow (LMonoTy.tcons "RoseTree" [.int]) .int))) + (LExpr.fvar () (CoreIdent.unres "tree") (.some (LMonoTy.tcons "RoseTree" [.int])))), + + -- Assume val == 42 + Statement.assume "val_is_42" + (LExpr.eq () + (LExpr.fvar () (CoreIdent.unres "val") (.some .int)) + (LExpr.intConst () 42)), + + -- Assert tree is a Node (always true for RoseTree) + Statement.assert "tree_is_node" + (LExpr.app () + (LExpr.op () (CoreIdent.unres "isNode") + (.some (LMonoTy.arrow (LMonoTy.tcons "RoseTree" [.int]) .bool))) + (LExpr.fvar () (CoreIdent.unres "tree") (.some (LMonoTy.tcons "RoseTree" [.int])))), + + -- Create a forest: FNil + Statement.init (CoreIdent.unres "forest") (.forAll [] (LMonoTy.tcons "Forest" [.int])) + (LExpr.op () (CoreIdent.unres "FNil") (.some (LMonoTy.tcons "Forest" [.int]))), + + -- Havoc the forest + Statement.havoc (CoreIdent.unres "forest"), + + -- Assume forest is FCons + Statement.assume "forest_is_fcons" + (LExpr.app () + (LExpr.op () (CoreIdent.unres "isFCons") + (.some (LMonoTy.arrow (LMonoTy.tcons "Forest" [.int]) .bool))) + (LExpr.fvar () (CoreIdent.unres "forest") (.some (LMonoTy.tcons "Forest" [.int])))), + + -- Assert forest is not FNil + Statement.assert "forest_not_fnil" + (LExpr.app () + (LExpr.op () (CoreIdent.unres "Bool.Not") + (.some (LMonoTy.arrow .bool .bool))) + (LExpr.app () + (LExpr.op () (CoreIdent.unres "isFNil") + (.some (LMonoTy.arrow (LMonoTy.tcons "Forest" [.int]) .bool))) + (LExpr.fvar () (CoreIdent.unres "forest") (.some (LMonoTy.tcons "Forest" [.int]))))) + ] + + match mkProgramWithMutualDatatypes [roseTreeDatatype, forestDatatype] "testMutualRecursive" statements with + | .error err => + return s!"Test 9 - Mutual Recursive with Havoc: FAILED (program creation)\n Error: {err.pretty}" + | .ok program => + runVerificationTest "Test 9 - Mutual Recursive with Havoc" program + +/-- +info: "Test 9 - Mutual Recursive with Havoc: PASSED\n Verified 2 obligation(s)\n" +-/ +#guard_msgs in +#eval test9_mutualRecursiveWithHavoc + end Core.DatatypeVerificationTests diff --git a/StrataTest/Languages/Core/SMTEncoderDatatypeTest.lean b/StrataTest/Languages/Core/SMTEncoderDatatypeTest.lean index 95fdb1c8f..1068aa068 100644 --- a/StrataTest/Languages/Core/SMTEncoderDatatypeTest.lean +++ b/StrataTest/Languages/Core/SMTEncoderDatatypeTest.lean @@ -69,12 +69,15 @@ def treeDatatype : LDatatype Visibility := constrs_ne := by decide } /-- Convert an expression to full SMT string including datatype declarations. +`blocks` is a list of mutual blocks (each block is a list of mutually recursive datatypes). -/ -def toSMTStringWithDatatypes (e : LExpr CoreLParams.mono) (datatypes : List (LDatatype Visibility)) : IO String := do - match Env.init.addDatatypes datatypes with +def toSMTStringWithDatatypeBlocks (e : LExpr CoreLParams.mono) (blocks : List (List (LDatatype Visibility))) : IO String := do + match Env.init.addDatatypes blocks with | .error msg => return s!"Error creating environment: {msg}" | .ok env => - match toSMTTerm env [] e SMT.Context.default with + -- Set the TypeFactory for correct datatype emission ordering + let ctx := SMT.Context.default.withTypeFactory env.datatypes + match toSMTTerm env [] e ctx with | .error err => return err.pretty | .ok (smt, ctx) => -- Emit the full SMT output including datatype declarations @@ -95,6 +98,13 @@ def toSMTStringWithDatatypes (e : LExpr CoreLParams.mono) (datatypes : List (LDa else return "Invalid UTF-8 in output" +/-- +Convert an expression to full SMT string including datatype declarations. +Each datatype is treated as its own (non-mutual) block. +-/ +def toSMTStringWithDatatypes (e : LExpr CoreLParams.mono) (datatypes : List (LDatatype Visibility)) : IO String := + toSMTStringWithDatatypeBlocks e (datatypes.map (fun d => [d])) + /-! ## Test Cases with Guard Messages -/ -- Test 1: Simple datatype (Option) - zero-argument constructor @@ -196,7 +206,7 @@ info: (declare-datatype TestOption (par (α) ( #guard_msgs in #eval format <$> toSMTStringWithDatatypes (.fvar () (CoreIdent.unres "listOfOption") (.some (.tcons "TestList" [.tcons "TestOption" [.int]]))) - [listDatatype, optionDatatype] + [optionDatatype, listDatatype] /-! ## Constructor Application Tests -/ @@ -324,106 +334,12 @@ info: (declare-datatype TestList (par (α) ( (.fvar () (CoreIdent.unres "xs") (.some (.tcons "TestList" [.int])))) [listDatatype] -/-! ## Complex Dependency Topological Sorting Tests -/ - --- Test 16: Very complex dependency graph requiring sophisticated topological sorting --- Dependencies: Alpha -> Beta, Gamma --- Beta -> Delta, Epsilon --- Gamma -> Epsilon, Zeta --- Delta -> Zeta --- Epsilon -> Zeta --- Actual topological order: Zeta, Epsilon, Gamma, Delta, Beta, Alpha - -/-- Alpha = AlphaValue Beta Gamma -/ -def alphaDatatype : LDatatype Visibility := - { name := "Alpha" - typeArgs := [] - constrs := [ - { name := ⟨"AlphaValue", .unres⟩, args := [ - (⟨"Alpha$AlphaValueProj0", .unres⟩, .tcons "Beta" []), - (⟨"Alpha$AlphaValueProj1", .unres⟩, .tcons "Gamma" []) - ], testerName := "Alpha$isAlphaValue" } - ] - constrs_ne := by decide } - -/-- Beta = BetaValue Delta Epsilon -/ -def betaDatatype : LDatatype Visibility := - { name := "Beta" - typeArgs := [] - constrs := [ - { name := ⟨"BetaValue", .unres⟩, args := [ - (⟨"Beta$BetaValueProj0", .unres⟩, .tcons "Delta" []), - (⟨"Beta$BetaValueProj1", .unres⟩, .tcons "Epsilon" []) - ], testerName := "Beta$isBetaValue" } - ] - constrs_ne := by decide } - -/-- Gamma = GammaValue Epsilon Zeta -/ -def gammaDatatype : LDatatype Visibility := - { name := "Gamma" - typeArgs := [] - constrs := [ - { name := ⟨"GammaValue", .unres⟩, args := [ - (⟨"Gamma$GammaValueProj0", .unres⟩, .tcons "Epsilon" []), - (⟨"Gamma$GammaValueProj1", .unres⟩, .tcons "Zeta" []) - ], testerName := "Gamma$isGammaValue" } - ] - constrs_ne := by decide } - -/-- Delta = DeltaValue Zeta -/ -def deltaDatatype : LDatatype Visibility := - { name := "Delta" - typeArgs := [] - constrs := [ - { name := ⟨"DeltaValue", .unres⟩, args := [(⟨"Delta$DeltaValueProj0", .unres⟩, .tcons "Zeta" [])], testerName := "Delta$isDeltaValue" } - ] - constrs_ne := by decide } - -/-- Epsilon = EpsilonValue Zeta -/ -def epsilonDatatype : LDatatype Visibility := - { name := "Epsilon" - typeArgs := [] - constrs := [ - { name := ⟨"EpsilonValue", .unres⟩, args := [(⟨"Epsilon$EpsilonValueProj0", .unres⟩, .tcons "Zeta" [])], testerName := "Epsilon$isEpsilonValue" } - ] - constrs_ne := by decide } - -/-- Zeta = ZetaValue int -/ -def zetaDatatype : LDatatype Visibility := - { name := "Zeta" - typeArgs := [] - constrs := [ - { name := ⟨"ZetaValue", .unres⟩, args := [(⟨"Zeta$ZetaValueProj0", .unres⟩, .int)], testerName := "Zeta$isZetaValue" } - ] - constrs_ne := by decide } - -/-- -info: (declare-datatype Zeta ( - (ZetaValue (Zeta$ZetaValueProj0 Int)))) -(declare-datatype Epsilon ( - (EpsilonValue (Epsilon$EpsilonValueProj0 Zeta)))) -(declare-datatype Gamma ( - (GammaValue (Gamma$GammaValueProj0 Epsilon) (Gamma$GammaValueProj1 Zeta)))) -(declare-datatype Delta ( - (DeltaValue (Delta$DeltaValueProj0 Zeta)))) -(declare-datatype Beta ( - (BetaValue (Beta$BetaValueProj0 Delta) (Beta$BetaValueProj1 Epsilon)))) -(declare-datatype Alpha ( - (AlphaValue (Alpha$AlphaValueProj0 Beta) (Alpha$AlphaValueProj1 Gamma)))) -; alphaVar -(declare-const f0 Alpha) -(define-fun t0 () Alpha f0) --/ -#guard_msgs in -#eval format <$> toSMTStringWithDatatypes - (.fvar () (CoreIdent.unres "alphaVar") (.some (.tcons "Alpha" []))) - [alphaDatatype, betaDatatype, gammaDatatype, deltaDatatype, epsilonDatatype, zetaDatatype] +/-! ## Dependency Order Tests -/ --- Test 17: Diamond dependency pattern +-- Test 16: Diamond dependency pattern -- Dependencies: Diamond -> Left, Right -- Left -> Root -- Right -> Root --- Actual topological order: Root, Right, Left, Diamond (or Root, Left, Right, Diamond) /-- Root = RootValue int -/ def rootDatatype : LDatatype Visibility := @@ -480,7 +396,66 @@ info: (declare-datatype Root ( #guard_msgs in #eval format <$> toSMTStringWithDatatypes (.fvar () (CoreIdent.unres "diamondVar") (.some (.tcons "Diamond" []))) - [diamondDatatype, leftDatatype, rightDatatype, rootDatatype] + [rootDatatype, rightDatatype, leftDatatype, diamondDatatype] + +-- Test 17: Mutually recursive datatypes (RoseTree/Forest) +-- Should emit declare-datatypes with both types together + +/-- RoseTree α = Node α (Forest α) -/ +def roseTreeDatatype : LDatatype Visibility := + { name := "RoseTree" + typeArgs := ["α"] + constrs := [ + { name := ⟨"Node", .unres⟩, args := [ + (⟨"RoseTree$NodeProj0", .unres⟩, .ftvar "α"), + (⟨"RoseTree$NodeProj1", .unres⟩, .tcons "Forest" [.ftvar "α"]) + ], testerName := "RoseTree$isNode" } + ] + constrs_ne := by decide } + +/-- Forest α = FNil | FCons (RoseTree α) (Forest α) -/ +def forestDatatype : LDatatype Visibility := + { name := "Forest" + typeArgs := ["α"] + constrs := [ + { name := ⟨"FNil", .unres⟩, args := [], testerName := "Forest$isFNil" }, + { name := ⟨"FCons", .unres⟩, args := [ + (⟨"Forest$FConsProj0", .unres⟩, .tcons "RoseTree" [.ftvar "α"]), + (⟨"Forest$FConsProj1", .unres⟩, .tcons "Forest" [.ftvar "α"]) + ], testerName := "Forest$isFCons" } + ] + constrs_ne := by decide } + +/-- +info: (declare-datatypes ((RoseTree 1) (Forest 1)) + ((par (α) ((Node (RoseTree$NodeProj0 α) (RoseTree$NodeProj1 (Forest α))))) + (par (α) ((FNil) (FCons (Forest$FConsProj0 (RoseTree α)) (Forest$FConsProj1 (Forest α))))))) +; tree +(declare-const f0 (RoseTree Int)) +(define-fun t0 () (RoseTree Int) f0) +-/ +#guard_msgs in +#eval format <$> toSMTStringWithDatatypeBlocks + (.fvar () (CoreIdent.unres "tree") (.some (.tcons "RoseTree" [.int]))) + [[roseTreeDatatype, forestDatatype]] + +-- Test 19: Mix of mutual and non-mutual datatypes +-- TestOption (non-mutual), then RoseTree/Forest (mutual) +/-- +info: (declare-datatype TestOption (par (α) ( + (None) + (Some (TestOption$SomeProj0 α))))) +(declare-datatypes ((RoseTree 1) (Forest 1)) + ((par (α) ((Node (RoseTree$NodeProj0 α) (RoseTree$NodeProj1 (Forest α))))) + (par (α) ((FNil) (FCons (Forest$FConsProj0 (RoseTree α)) (Forest$FConsProj1 (Forest α))))))) +; optionTree +(declare-const f0 (TestOption (RoseTree Int))) +(define-fun t0 () (TestOption (RoseTree Int)) f0) +-/ +#guard_msgs in +#eval format <$> toSMTStringWithDatatypeBlocks + (.fvar () (CoreIdent.unres "optionTree") (.some (.tcons "TestOption" [.tcons "RoseTree" [.int]]))) + [[optionDatatype], [roseTreeDatatype, forestDatatype]] end DatatypeTests