diff --git a/hkmc2/shared/src/main/scala/hkmc2/Config.scala b/hkmc2/shared/src/main/scala/hkmc2/Config.scala index ce12c08f42..bbbb1d736b 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/Config.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/Config.scala @@ -63,9 +63,9 @@ object Config: case class EffectHandlers( debug: Bool, stackSafety: Opt[StackSafety], - // Whether we check `Instantiate` nodes for effects, currently no effect can be raised in a constructor. + // Whether we check `Instantiate` nodes for effects. Currently, effects cannot be raised in constructors. checkInstantiateEffect: Bool = false, - // A debug option that allow codegen to continue even if a unlifted definition is encountered. + // A debug option that allows codegen to continue even if an unlifted definition is encountered. softLifterError: Bool = false ) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala index fc6fa76683..2549b2641d 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala @@ -8,6 +8,7 @@ import utils.* import hkmc2.codegen.* import hkmc2.semantics.* import hkmc2.Message.* +import hkmc2.ScopeData.* import hkmc2.semantics.Elaborator.State import hkmc2.syntax.Tree import hkmc2.codegen.llir.FreshInt @@ -15,8 +16,10 @@ import hkmc2.codegen.llir.FreshInt import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.Map as MutMap import scala.collection.mutable.Set as MutSet +import scala.collection.mutable.ListBuffer object Lifter: + /** * Describes the free variables of a function that have been accessed by its nested definitions. * @param vars The free variables that are accessed by nested classes/functions. @@ -50,7 +53,7 @@ object Lifter: case class AccessInfo( accessed: Set[Local], mutated: Set[Local], - refdDefns: Set[BlockMemberSymbol] + refdDefns: Set[ScopedInfo] ): def ++(that: AccessInfo) = AccessInfo( accessed ++ that.accessed, @@ -67,19 +70,9 @@ object Lifter: mutated.intersect(locals), refdDefns ) - def withoutBms(locals: Set[BlockMemberSymbol]) = AccessInfo( - accessed, - mutated, - refdDefns -- locals - ) - def intersectBms(locals: Set[BlockMemberSymbol]) = AccessInfo( - accessed, - mutated, - refdDefns.intersect(locals) - ) def addAccess(l: Local) = copy(accessed = accessed + l) def addMutated(l: Local) = copy(accessed = accessed + l, mutated = mutated + l) - def addRefdDefn(l: BlockMemberSymbol) = copy(refdDefns = refdDefns + l) + def addRefdScopedObj(l: ScopedInfo) = copy(refdDefns = refdDefns + l) object AccessInfo: val empty = AccessInfo(Set.empty, Set.empty, Set.empty) @@ -97,10 +90,6 @@ object Lifter: case _ => Set.empty - def getVarsBlk(b: Block): Set[Local] = - b.definedVars.collect: - case s: LocalVarSymbol => s - object RefOfBms: def unapply(p: Path): Opt[(BlockMemberSymbol, Opt[DefinitionSymbol[?]])] = p match case Value.Ref(l: BlockMemberSymbol, disamb) => S((l, disamb)) @@ -119,346 +108,106 @@ object Lifter: case c: ClsLikeDefn => (c.companion.isDefined) || (c.k is syntax.Obj) // TODO: refine handling of companions case _ => false - /** * Lifts classes and functions to the top-level. Also automatically rewrites lambdas. * Assumes the input block does not have any `HandleBlock`s. */ -class Lifter()(using State, Raise, Config): +class Lifter(topLevelBlk: Block, handlerPaths: HandlerPaths)(using State, Raise, Config): import Lifter.* - - /** - * The context of the class lifter. One can create an empty context using `LifterCtx.empty`. - * - * @param defns A map from all BlockMemberSymbols to their definitions. - * @param defnsCur All definitions that are nested in the current top level definition. - * @param nestedDefns Definitions which are nested in a given definition (shallow). - * @param usedLocals Describes the locals belonging to each function that are accessed/mutated by nested definitions. - * @param accessInfo Which previously defined variables/definitions could be accessed/modified by a particular definition, - * possibly through calls to other functions or by constructing a class. - * @param ignoredDefns The definitions which must not be lifted. - * @param inScopeDefns Definitions which are in scope to another definition (excluding itself and its nested definitions). - * @param modObjLocals A map from the modules and objects to the local to which it is instantiated after lifting. - * @param localCaptureSyms The symbols in a capture corresponding to a particular local. - * The `VarSymbol` is the parameter in the capture class. - * We used to also store along with it a `BlockMemberSymbol`, the field in the class, but it wasn't used. - * @param prevFnLocals Locals belonging to function definitions that have already been traversed - * @param prevClsDefns Class definitions that have already been traversed, excluding modules - * @param inScopeISyms Inner symbols that are currently in scope (and therefore don't need to be rewritten). - * @param curModules Modules that that we are currently nested in (cleared if we are lifted out) - * @param capturePaths The path to access a particular function's capture in the local scope - * @param bmsReqdInfo The (mutable) captures and (immutable) local variables each function requires - * @param ignoredBmsPaths The path to access a particular BlockMemberSymbol (for definitions which could not be lifted) - * @param localPaths The path to access a particular local (possibly belonging to a previous function) in the current scope - * @param iSymPaths The path to access a particular `innerSymbol` (possibly belonging to a previous class) in the current scope - * @param replacedDefns Ignored (unlifted) definitions that have been rewritten and need to be replaced at the definition site. - * @param firstClsFns Nested functions which are used as first-class functions. - * @param companionMap Map from companion object symbols to the corresponding regular class symbol. - */ - case class LifterCtx private ( - val defns: Map[BlockMemberSymbol, Defn] = Map.empty, - val defnsCur: Set[BlockMemberSymbol] = Set.empty, - val nestedDefns: Map[BlockMemberSymbol, List[Defn]] = Map.empty, - val usedLocals: UsedLocalsMap = UsedLocalsMap(Map.empty), - val accessInfo: Map[BlockMemberSymbol, AccessInfo] = Map.empty, - val ignoredDefns: Set[BlockMemberSymbol] = Set.empty, - val inScopeDefns: Map[BlockMemberSymbol, Set[BlockMemberSymbol]] = Map.empty, - val modObjLocals: Map[BlockMemberSymbol, Local] = Map.empty, - val localCaptureSyms: Map[Local, VarSymbol] = Map.empty, - val prevFnLocals: FreeVars = FreeVars.empty, - val prevClsDefns: List[ClsLikeDefn] = Nil, - val inScopeISyms: Set[InnerSymbol] = Set.empty, - val curModules: List[ClsLikeDefn] = Nil, - val capturePaths: Map[BlockMemberSymbol, LocalPath] = Map.empty, - val bmsReqdInfo: Map[BlockMemberSymbol, LiftedInfo] = Map.empty, // required captures - val ignoredBmsPaths: Map[BlockMemberSymbol, LocalPath] = Map.empty, - val localPaths: Map[Local, LocalPath] = Map.empty, - val isymPaths: Map[InnerSymbol, LocalPath] = Map.empty, - val replacedDefns: Map[BlockMemberSymbol, Defn] = Map.empty, - val firstClsFns: Set[BlockMemberSymbol] = Set.empty, - val companionMap: Map[InnerSymbol, InnerSymbol] = Map.empty, - ): - // gets the function to which a local belongs - def lookup(l: Local) = usedLocals.lookup(l) - - def getCapturePath(b: BlockMemberSymbol) = capturePaths.get(b) - def getLocalClosPath(l: Local) = lookup(l).flatMap(capturePaths.get(_)) - def getLocalCaptureSym(l: Local) = localCaptureSyms.get(l) - def getLocalPath(l: Local) = localPaths.get(l) - def resolveIsymPath(l: InnerSymbol) = getIsymPath(companionMap.getOrElse(l, l)) - def getIsymPath(l: InnerSymbol) = isymPaths.get(l) - def getIgnoredBmsPath(b: BlockMemberSymbol) = ignoredBmsPaths.get(b) - def ignored(b: BlockMemberSymbol) = ignoredDefns.contains(b) - def isModOrObj(b: BlockMemberSymbol) = modObjLocals.contains(b) - def getAccesses(sym: BlockMemberSymbol) = accessInfo(sym) - def isRelevant(sym: BlockMemberSymbol) = defnsCur.contains(sym) - - def addIgnored(defns: Set[BlockMemberSymbol]) = copy(ignoredDefns = ignoredDefns ++ defns) - def withModObjLocals(mp: Map[BlockMemberSymbol, Local]) = copy(modObjLocals = modObjLocals ++ mp) - def withDefns(mp: Map[BlockMemberSymbol, Defn]) = copy(defns = mp) - def withDefnsCur(defns: Set[BlockMemberSymbol]) = copy(defnsCur = defns) - def withNestedDefns(mp: Map[BlockMemberSymbol, List[Defn]]) = copy(nestedDefns = mp) - def withAccesses(mp: Map[BlockMemberSymbol, AccessInfo]) = copy(accessInfo = mp) - def withInScopes(mp: Map[BlockMemberSymbol, Set[BlockMemberSymbol]]) = copy(inScopeDefns = mp) - def withFirstClsFns(fns: Set[BlockMemberSymbol]) = copy(firstClsFns = fns) - def withCompanionMap(mp: Map[InnerSymbol, InnerSymbol]) = copy(companionMap = mp) - def addFnLocals(f: FreeVars) = copy(prevFnLocals = prevFnLocals ++ f) - def addClsDefn(c: ClsLikeDefn) = copy(prevClsDefns = c :: prevClsDefns) - def addLocalCaptureSyms(m: Map[Local, VarSymbol]) = copy(localCaptureSyms = localCaptureSyms ++ m) - def getBmsReqdInfo(sym: BlockMemberSymbol) = bmsReqdInfo.get(sym) - def replCapturePaths(paths: Map[BlockMemberSymbol, LocalPath]) = copy(capturePaths = paths) - def addCapturePath(src: BlockMemberSymbol, path: LocalPath) = copy(capturePaths = capturePaths + (src -> path)) - def addBmsReqdInfo(mp: Map[BlockMemberSymbol, LiftedInfo]) = copy(bmsReqdInfo = bmsReqdInfo ++ mp) - def replLocalPaths(m: Map[Local, LocalPath]) = copy(localPaths = m) - def replIgnoredBmsPaths(m: Map[BlockMemberSymbol, LocalPath]) = copy(ignoredBmsPaths = m) - def replIsymPaths(m: Map[InnerSymbol, LocalPath]) = copy(isymPaths = m) - def addLocalPaths(m: Map[Local, LocalPath]) = copy(localPaths = localPaths ++ m) - def addLocalPath(target: Local, path: LocalPath) = copy(localPaths = localPaths + (target -> path)) - def addIgnoredBmsPaths(m: Map[BlockMemberSymbol, LocalPath]) = copy(ignoredBmsPaths = ignoredBmsPaths ++ m) - def addIsymPath(isym: InnerSymbol, l: LocalPath) = copy(isymPaths = isymPaths + (isym -> l)) - def addIsymPaths(mp: Map[InnerSymbol, LocalPath]) = copy(isymPaths = isymPaths ++ mp) - def addreplacedDefns(mp: Map[BlockMemberSymbol, Defn]) = copy(replacedDefns = replacedDefns ++ mp) - def inModule(defn: ClsLikeDefn) = copy(curModules = defn :: curModules) - def inISym(sym: InnerSymbol) = copy(inScopeISyms = inScopeISyms + sym) - def resetScope = copy(inScopeISyms = Set.empty) - def flushModules = - // called when we are lifted out while in some module, so we need to add the modules' isym paths - copy(curModules = Nil).addIsymPaths(curModules.map(d => d.isym -> LocalPath.Sym(d.sym)).toMap) - object LifterCtx: - def empty = LifterCtx() - def withLocals(u: UsedLocalsMap) = empty.copy(usedLocals = u) - + val handlerSyms: Set[Symbol] = Set(State.nonLocalRet, State.effectSigSymbol) + + extension (l: Local) + def asLocalPath: LocalPath = LocalPath.Sym(l) + def asDefnRef: DefnRef = DefnRef.Sym(l) + enum LocalPath: case Sym(l: Local) - case PubField(isym: DefinitionSymbol[? <: ClassLikeDef] & InnerSymbol, sym: BlockMemberSymbol) + case BmsRef(l: BlockMemberSymbol, d: DefinitionSymbol[?]) + case InCapture(capturePath: Path, field: TermSymbol) - def read = this match + def read(using ctx: LifterCtxNew): Path = this match case Sym(l) => l.asPath - case PubField(isym, sym) => Select(isym.asPath, Tree.Ident(sym.nme))(N) + case BmsRef(l, d) => Value.Ref(l, S(d)) + case InCapture(path, field) => Select(path, field.id)(S(field)) - def asArg = read.asArg + def asArg(using ctx: LifterCtxNew) = read.asArg - def assign(value: Result, rest: Block) = this match + def assign(value: Result, rest: Block)(using ctx: LifterCtxNew): Block = this match case Sym(l) => Assign(l, value, rest) - case PubField(isym, sym) => AssignField(isym.asPath, Tree.Ident(sym.nme), value, rest)(S(sym)) - - def readDisamb(d: Opt[DefinitionSymbol[?]]) = this match - case Sym(l) => Value.Ref(l, d) - case PubField(isym, sym) => Select(isym.asPath, Tree.Ident(sym.nme))(d) - - + case BmsRef(l, d) => lastWords("Tried to assign to a BlockMemberSymbol") + case InCapture(path, field) => AssignField(path, field.id, value, rest)(S(field)) - val ignoredSet = Set(State.runtimeSymbol.asPath.selSN("NonLocalReturn")) - - def isIgnoredPath(p: Path) = ignoredSet.contains(p) + enum DefnRef: + case Sym(l: Local) + case InScope(l: BlockMemberSymbol, d: DefinitionSymbol[?]) + case Field(isym: InnerSymbol, l: BlockMemberSymbol, d: DefinitionSymbol[?]) - /** - * Creates a capture class for a function consisting of its mutable (and possibly immutable) local variables. - * @param f The function to create the capture class for. - * @param ctx The lifter context. Determines which variables will be captured. - * @return The triple (defn, varsMap, varsList), where `defn` is the capture class's definition, - * `varsMap` maps the function's locals to the corresponding `VarSymbol` (for the class parameters), and - * `varsList` specifies the order of these variables in the class's constructor. - */ - def createCaptureCls(f: FunDefn, ctx: LifterCtx) - : (ClsLikeDefn, Map[Local, VarSymbol], List[Local]) = - val nme = f.sym.nme + "$capture" - - val clsSym = ClassSymbol( - Tree.DummyTypeDef(syntax.Cls), - Tree.Ident(nme) - ) - - val FreeVars(_, cap) = ctx.usedLocals(f.sym) - - val fresh = FreshInt() - - val sortedVars = cap.toArray.sortBy(_.uid).map: sym => - val id = fresh.make - val nme = sym.nme + "$capture$" + id - - val ident = new Tree.Ident(nme) - val varSym = VarSymbol(ident) - val fldSym = BlockMemberSymbol(nme, Nil) - val tSym = TermSymbol(syntax.MutVal, S(clsSym), ident) - - val p = Param(FldFlags.empty.copy(isVal = true), varSym, N, Modulefulness.none) - varSym.decl = S(p) // * Currently this is only accessed to create the class' toString method - - val vd = ValDefn( - tSym, - fldSym, - Value.Ref(varSym) - ) - - (sym -> varSym, p, vd) - - val defn = ClsLikeDefn( - None, clsSym, BlockMemberSymbol(nme, Nil), - S(TermSymbol(syntax.Fun, S(clsSym), clsSym.id)), - syntax.Cls, - N, - PlainParamList(sortedVars.iterator.map(_._2).toList) :: Nil, None, Nil, Nil, - Nil, - End(), - sortedVars.iterator.foldLeft[Block](End()): - case (acc, (_, _, vd)) => Define(vd, acc), - N, - N, - ) + def read(using ctx: LifterCtxNew): Path = this match + case Sym(l) => l.asPath + case InScope(l, d) => Value.Ref(l, S(d)) + case Field(isym, l, d) => Select(ctx.symbolsMap(isym).read, Tree.Ident(l.nme))(S(d)) - (defn, sortedVars.iterator.map(_._1).toMap, sortedVars.iterator.map(_._1._1).toList) - - private val innerSymCache: MutMap[Local, Set[Local]] = MutMap.empty + def asArg(using ctx: LifterCtxNew) = read.asArg - /** - * Gets the inner symbols referenced within a class (including those within a member symbol). - * @param c The class from which to get the inner symbols. - * @return The inner symbols reference within a class. - */ - def getInnerSymbols(c: Defn) = - val sym = c match - case f: FunDefn => f.sym - case c: ClsLikeDefn => c.isym - case _ => wat("unreachable", c.sym) - - def create: Set[Local] = c.freeVars.collect: - case s: InnerSymbol => s - case t: TermSymbol if t.owner.isDefined => t.owner.get - - innerSymCache.getOrElseUpdate(sym, create) - - /** - * Determines whether a certain class's `this` needs to be captured by a class being lifted. - * @param captureCls The class in question that is considered for capture. - * @param liftDefn The class being lifted. - * @return Whether the class needs to be captured. - */ - private def needsClsCapture(captureCls: ClsLikeDefn, liftDefn: Defn) = - getInnerSymbols(liftDefn).contains(captureCls.isym) - - /** - * Determines whether a certain function's mutable closure needs to be captured by a definition being lifted. - * @param captureFn The function in question that is considered for capture. - * @param liftDefn The definition being lifted. - * @return Whether the function needs to be captured. - */ - private def needsCapture(captureFn: FunDefn, liftDefn: Defn, ctx: LifterCtx) = - val candVars = liftDefn.freeVars - val captureFnVars = ctx.usedLocals(captureFn.sym).reqCapture.toSet - !candVars.intersect(captureFnVars).isEmpty - - /** - * Gets the immutable local variables of a function that need to be captured by a definition being lifted. - * @param captureFn The function in question whose local variables need to be captured. - * @param liftDefn The definition being lifted. - * @return The local variables that need to be captured. - */ - private def neededImutLocals(captureFn: FunDefn, liftDefn: Defn, ctx: LifterCtx) = - val candVars = liftDefn.freeVars - val captureFnVars = ctx.usedLocals(captureFn.sym) - val mutVars = captureFnVars.reqCapture.toSet - val imutVars = captureFnVars.vars - imutVars.filter: s => - !mutVars.contains(s) && candVars.contains(s) - case class FunSyms[T <: DefinitionSymbol[?]](b: BlockMemberSymbol, d: T): def asPath = Value.Ref(b, S(d)) object FunSyms: def fromFun(b: BlockMemberSymbol, owner: Opt[InnerSymbol] = N) = FunSyms(b, TermSymbol.fromFunBms(b, owner)) - // Info required for lifting a definition. - case class LiftedInfo( - val reqdCaptures: List[BlockMemberSymbol], // The mutable captures a lifted definition must take. - val reqdVars: List[Local], // The (passed by value) variables a lifted definition must take. - val reqdInnerSyms: List[InnerSymbol], // The inner symbols a lifted definition must take. - val reqdBms: List[BlockMemberSymbol], // BMS's belonging to unlifted definitions that this definition references. - val fakeCtorBms: Option[FunSyms[TermSymbol]], // only for classes - val singleCallBms: FunSyms[TermSymbol], // optimization - ) - - case class Lifted[+T <: Defn]( - val liftedDefn: T, - val extraDefns: List[Defn], - ) - - private case class LifterMetadata( - unliftable: Set[BlockMemberSymbol], - modules: List[ClsLikeDefn], - objects: List[ClsLikeDefn], - firstClsFns: Set[BlockMemberSymbol] - ) + type ClsLikeSym = DefinitionSymbol[? <: ClassDef | ModuleOrObjectDef] + type ClsSym = DefinitionSymbol[? <: ClassLikeDef] + type ModuleOrObjSym = DefinitionSymbol[? <: ModuleOrObjectDef] + + case class LifterMetadata( + unliftable: Set[ClsSym | ModuleOrObjSym], + modules: Set[ModuleOrObjSym], + firstClsFns: Set[TermSymbol] + ): + def ++(that: LifterMetadata) = + LifterMetadata(unliftable ++ that.unliftable, modules ++ that.modules, firstClsFns ++ that.firstClsFns) + object LifterMetadata: + def empty = LifterMetadata(Set.empty, Set.empty, Set.empty) // d is a top-level definition // returns (ignored classes, modules, objects) - private def createMetadata(d: Defn, ctx: LifterCtx): LifterMetadata = - var ignored: Set[BlockMemberSymbol] = Set.empty - var firstClsFns: Set[BlockMemberSymbol] = Set.empty - var unliftable: Set[BlockMemberSymbol] = Set.empty - var clsSymToBms: Map[Local, BlockMemberSymbol] = Map.empty - var modules: List[ClsLikeDefn] = Nil - var objects: List[ClsLikeDefn] = Nil - var extendsGraph: Set[(BlockMemberSymbol, BlockMemberSymbol)] = Set.empty - - d match - case c @ ClsLikeDefn(k = syntax.Mod) => modules ::= c - case c @ ClsLikeDefn(k = syntax.Obj) => objects ::= c - case _ => () - - // search for modules - new BlockTraverser: - applyDefn(d) - override def applyDefn(defn: Defn): Unit = - if defn === d then - super.applyDefn(defn) - else - defn match - case c: ClsLikeDefn => - clsSymToBms += c.isym -> c.sym - - if c.companion.isDefined then // TODO: refine handling of companions - raise(WarningReport( - msg"Modules are not yet lifted." -> N :: Nil, - N, Diagnostic.Source.Compilation - )) - modules ::= c - ignored += c.sym - else if c.k is syntax.Obj then - objects ::= c - case _ => () - super.applyDefn(defn) + private def createMetadata(s: ScopeNode): LifterMetadata = + var ignored: Set[ClsSym | ModuleOrObjSym] = Set.empty + var firstClsFns: Set[TermSymbol] = Set.empty + val nestedScopeNodes: List[ScopeNode] = s.allChildNodes + val nestedScopes: Set[ScopedInfo] = nestedScopeNodes.map(_.obj.toInfo).toSet - s.obj.toInfo + + // hack: ClassLikeSymbol does not extend DefinitionSymbol directly, so we must + // use a map to convert - // search for defns nested within a top-level module, which are unnecessary to lift - def inModuleDefns(d: Defn): Set[BlockMemberSymbol] = - val nested = ctx.nestedDefns(d.sym) - nested.map(_.sym).toSet ++ nested.flatMap: nested => - if modules.contains(nested.sym) then inModuleDefns(nested) else Set.empty + val moduleObjs = nestedScopeNodes.collect: + case s @ ScopeNode(obj = o: ScopedObject.Companion) if !s.isTopLevel => o - val isMod = d match - case c: ClsLikeDefn => c.companion.isDefined // TODO: refine handling of companions - case _ => false + // TODO: refine handling of companions + for m <- moduleObjs do + ignored += m.par.isym + ignored += m.comp.isym + raise(WarningReport( + msg"Modules are not yet lifted." -> m.comp.isym.toLoc :: Nil, + N, Diagnostic.Source.Compilation + )) - val inModTopLevel = if isMod then inModuleDefns(d) else Set.empty - ignored ++= inModTopLevel + val modules: Set[ModuleOrObjSym] = moduleObjs.map(_.comp.isym).toSet + var extendsGraph: Set[(ClsSym, ClsSym)] = Set.empty // search for unliftable classes and build the extends graph - val clsSyms = clsSymToBms.values.toSet new BlockTraverser: - applyDefn(d) + this.applyScopedObject(s.obj) override def applyCase(cse: Case): Unit = cse match - case Case.Cls(cls, path) => - clsSymToBms.get(cls) match - case Some(value) if !ignored.contains(value) => // don't generate a warning if it's already ignored + case Case.Cls(cls: (ClassSymbol | ModuleOrObjectSymbol), _) => + if nestedScopes.contains(cls) && !ignored.contains(cls) && !data.getNode(cls).isModOrTopLevel then // don't generate a warning if it's already ignored raise(WarningReport( - msg"Cannot yet lift class/module `${value.nme}` as it is used in an instance check." -> N :: Nil, + msg"Cannot yet lift class/module `${cls.nme}` as it is used in an instance check." -> N :: Nil, N, Diagnostic.Source.Compilation )) - ignored += value - unliftable += value - case _ => () + ignored += cls case _ => () override def applyResult(r: Result): Unit = r match @@ -485,18 +234,18 @@ class Lifter()(using State, Raise, Config): // If B extends A, then A -> B is an edge parentPath match case None => () - case Some(path) if isIgnoredPath(path) => () - case Some(Select(RefOfBms(s, _), Tree.Ident("class"))) => - if clsSyms.contains(s) then extendsGraph += (s -> defn.sym) - case Some(RefOfBms(s, _)) => - if clsSyms.contains(s) then extendsGraph += (s -> defn.sym) - case _ if !ignored.contains(defn.sym) => + // for now, allow selecting runtime symbols + case Some(Select(qual = Value.Ref(l, _))) if State.runtimeSymbol is l => () + case Some(RefOfBms(_, S(s: ClassSymbol))) if handlerSyms.contains(s) => () + case Some(RefOfBms(s, _)) if handlerSyms.contains(s) => () + case Some(RefOfBms(_, S(s: ClassSymbol))) => + if nestedScopes.contains(s) then extendsGraph += (s -> isym) + case _ if !ignored.contains(isym) => raise(WarningReport( msg"Cannot yet lift definition `${sym.nme}` as it extends an expression." -> N :: Nil, N, Diagnostic.Source.Compilation )) - ignored += defn.sym - unliftable += defn.sym + ignored += isym case _ => () paramsOpt.foreach(applyParamList) auxParams.foreach(applyParamList) @@ -513,25 +262,29 @@ class Lifter()(using State, Raise, Config): case _ => false override def applyValue(v: Value): Unit = v match - case RefOfBms(l, _) if clsSyms.contains(l) && !modOrObj(ctx.defns(l)) => - raise(WarningReport( - msg"Cannot yet lift class `${l.nme}` as it is used as a first-class class." -> N :: Nil, - N, Diagnostic.Source.Compilation - )) - ignored += l - unliftable += l - case RefOfBms(l, _) if ctx.defns.contains(l) && isFun(ctx.defns(l)) => - // naked reference to a function definition - firstClsFns += l + case RefOfBms(_, S(l)) if nestedScopes.contains(l) => data.getNode(l).obj match + case c: ScopedObject.Class if c.isObj => () + case c: (ScopedObject.Class | ScopedObject.ClassCtor) => + if !c.node.get.isModOrTopLevel then + raise(WarningReport( + msg"Cannot yet lift class `${l.nme}` as it is used as a first-class class." -> N :: Nil, + N, Diagnostic.Source.Compilation + )) + val isym = c match + case c: ScopedObject.Class => c.cls.isym + case c: ScopedObject.ClassCtor => c.cls.isym + ignored += isym + case ScopedObject.Func(fun, isMethod) => firstClsFns += fun.dSym + case _ => super.applyValue(v) case _ => super.applyValue(v) // analyze the extends graph val extendsEdges = extendsGraph.groupBy(_._1).map: case (a, bs) => a -> bs.map(_._2) .toMap - var newUnliftable: Set[BlockMemberSymbol] = Set.empty + var newUnliftable: Set[ClsSym] = Set.empty // dfs starting from unliftable classes - def dfs(s: BlockMemberSymbol): Unit = + def dfs(s: ClsSym): Unit = for edges <- extendsEdges.get(s) b <- edges if !newUnliftable.contains(b) && !ignored.contains(b) @@ -542,84 +295,15 @@ class Lifter()(using State, Raise, Config): )) newUnliftable += b dfs(b) - for s <- ignored do + for case s: ClsLikeSym <- ignored do dfs(s) - LifterMetadata(ignored ++ newUnliftable, modules.toList, objects.toList, firstClsFns) - - extension (b: Block) - private def floatOut(ctx: LifterCtx) = - b.extractDefns(preserve = defn => ctx.isModOrObj(defn.sym) || ctx.ignored(defn.sym)) - private def gather(ctx: LifterCtx) = - b.gatherDefns(preserve = defn => ctx.isModOrObj(defn.sym) || ctx.ignored(defn.sym)) - - - def createLiftInfoCont(d: Defn, parentCls: Opt[ClsLikeDefn], ctx: LifterCtx): Map[BlockMemberSymbol, LiftedInfo] = - val AccessInfo(accessed, _, refdDefns) = ctx.getAccesses(d.sym) - - val inScopeRefs = refdDefns.intersect(ctx.inScopeDefns(d.sym)) - - val includedCaptures = ctx.prevFnLocals.reqCapture - .intersect(accessed) - .map(sym => ctx.lookup(sym).get) - .toList.sortBy(_.uid) - - val refMod = inScopeRefs.intersect(ctx.modObjLocals.keySet) - val includedLocals = ((accessed -- ctx.prevFnLocals.reqCapture) ++ refMod).toList.sortBy(_.uid) - val clsCaptures: List[InnerSymbol] = ctx.prevClsDefns.map(_.isym) - val refBms = inScopeRefs.intersect(ctx.ignoredDefns).toList.sortBy(_.uid) - - val isModLocal = d match - case c: ClsLikeDefn if modOrObj(c) && !ctx.ignored(c.sym) => true - case _ => false - - if ctx.ignored(d.sym) || - (includedCaptures.isEmpty && includedLocals.isEmpty && clsCaptures.isEmpty && refBms.isEmpty) then - d match - case f: FunDefn => - createLiftInfoFn(f, ctx) - case c: ClsLikeDefn => - createLiftInfoCls(c, ctx) - case _ => Map.empty - else - val fakeCtorBms = d match - case c: ClsLikeDefn if !isModLocal => S(BlockMemberSymbol(d.sym.nme + "$ctor", Nil)) - case _ => N - - val singleCallBms = BlockMemberSymbol(d.sym.nme + "$", Nil) - - val info = LiftedInfo( - includedCaptures, includedLocals, clsCaptures, - refBms, fakeCtorBms.map(FunSyms.fromFun(_)), FunSyms.fromFun(singleCallBms) - ) - - d match - case f: FunDefn => - createLiftInfoFn(f, ctx) + (d.sym -> info) - case c: ClsLikeDefn => - createLiftInfoCls(c, ctx) + (d.sym -> info) - case _ => Map.empty - - def createLiftInfoFn(f: FunDefn, ctx: LifterCtx): Map[BlockMemberSymbol, LiftedInfo] = - val defns = ctx.nestedDefns(f.sym) - defns.flatMap(createLiftInfoCont(_, N, ctx.addFnLocals(ctx.usedLocals(f.sym)))).toMap - - def createLiftInfoCls(c: ClsLikeDefn, ctx: LifterCtx): Map[BlockMemberSymbol, LiftedInfo] = - val defns = c.preCtor.gather(ctx) ++ c.ctor.gather(ctx) ++ c.companion.fold(Nil)(_.ctor.gather(ctx)) - val newCtx = if (c.companion.isDefined) && !ctx.ignored(c.sym) then ctx else ctx.addClsDefn(c) - val staticMtdInfo = c.companion.fold(Map.empty): - case value => value.methods.flatMap(f => createLiftInfoFn(f, newCtx)) - - defns.flatMap(f => createLiftInfoCont(f, S(c), newCtx)).toMap - ++ c.methods.flatMap(f => createLiftInfoFn(f, newCtx)) - ++ staticMtdInfo + LifterMetadata(ignored ++ newUnliftable, modules, firstClsFns) // This rewrites code so that it's valid when lifted to the top level. // This way, no piece of code must be traversed by a BlockRewriter more than once. // Remark: This is why so much prior analysis is needed and is the main source of complexity in the lifter. - class BlockRewriter(inScopeIsyms: Set[InnerSymbol], ctx: LifterCtx) extends BlockTransformerShallow(SymbolSubst()): - def iSymInScope(l: InnerSymbol) = inScopeIsyms.contains(l) - + class BlockRewriter(using ctx: LifterCtxNew) extends ScopeRewriter: // Closure symbols that point to an initialized closure in this scope var activeClosures: Set[Local] = Set.empty // Map from block member symbols to initialized closures @@ -637,6 +321,9 @@ class Lifter()(using State, Raise, Config): // instantiations are rewritten. // // Does *not* rewrite references to non-lifted BMS symbols. + // + // References to methods and unlifted classes nested inside classes/modules are + // always rewritten using `this.defnName` (when accessed internally) or `object.defnName`. def rewriteBms(b: Block) = // BMS's that need to be created val syms: LinkedHashMap[FunSyms[?], Local] = LinkedHashMap.empty @@ -645,53 +332,83 @@ class Lifter()(using State, Raise, Config): val walker = new BlockDataTransformer(SymbolSubst()): // only scan within the block. don't traverse - override def applyResult(r: Result)(k: Result => Block): Block = r match + def resolveDefnRef(l: BlockMemberSymbol, d: DefinitionSymbol[?], r: RewrittenScope[?]) = + ctx.defnsMap.get(d) match + case Some(defnRef) => S(defnRef.read) + case None => r.obj match + case c: ScopedObject.Class if c.isObj => + S(ctx.symbolsMap(c.cls.isym).read) + case r: ScopedObject.Referencable[?] => + // rewrite the parent + r.owner.flatMap(ctx.symbolsMap.get(_)) match + case Some(value) => + S(Select(value.read, Tree.Ident(l.nme))(S(d))) + case None => N + case _ => N + + override def applyResult(r: Result)(k: Result => Block): Block = + r match // if possible, directly rewrite the call using the efficient version - case c @ Call(RefOfBms(l, S(d)), args) => ctx.bmsReqdInfo.get(l) match - case Some(info) if !ctx.isModOrObj(l) => - val extraArgs = ctx.defns.get(l) match - // If it's a class, we need to add the isMut parameter. - // Instantiation without `new mut` is always immutable - case Some(c: ClsLikeDefn) => Value.Lit(Tree.BoolLit(false)).asArg :: getCallArgs(FunSyms(l, d), ctx) - case _ => getCallArgs(FunSyms(l, d), ctx) - applyListOf(args, applyArg(_)(_)): newArgs => - k(Call(info.singleCallBms.asPath, extraArgs ++ newArgs)(c.isMlsFun, c.mayRaiseEffects, c.explicitTailCall)) - case _ => super.applyResult(r)(k) - case c @ Instantiate(mut, InstSel(l, S(d)), args) => - ctx.bmsReqdInfo.get(l) match - case Some(info) if !ctx.isModOrObj(l) => - val extraArgs = Value.Lit(Tree.BoolLit(mut)).asArg :: getCallArgs(FunSyms(l, d), ctx) - applyListOf(args, applyArg(_)(_)): newArgs => - k(Call(info.singleCallBms.asPath, extraArgs ++ newArgs)(true, config.checkInstantiateEffect, false)) - case _ => super.applyResult(r)(k) - // LEGACY CODE: We previously directly created the closure and assigned it to the - // variable here. But, since this closure may be re-used later, this doesn't work - // in general, so we will always create a TempSymbol for it. - // case RefOfBms(l) if ctx.bmsReqdInfo.contains(l) && !ctx.isModOrObj(l) => - // createCall(l, ctx) + case c @ Call(RefOfBms(l, S(d)), args) => + ctx.rewrittenScopes.get(d) match + case N => super.applyResult(r)(k) // external call, or have not yet traversed that function + case S(r) => + applyArgs(args): newArgs => + def join2: Block = + resolveDefnRef(l, d, r) match + case Some(value) => k(c.copy(fun = value, args = newArgs)(c.isMlsFun, c.mayRaiseEffects, c.explicitTailCall)) + case None => super.applyResult(c)(k) + r match + // function call + case f: LiftedFunc => k(f.rewriteCall(c, newArgs)) + // ctor call (without using `new`) + case ctor: RewrittenClassCtor => ctor.getRewrittenCls match + case cls: LiftedClass => + k(cls.rewriteCall(c, newArgs)) + case _ => join2 + case _ => join2 + case inst @ Instantiate(mut, RefOfBms(l, S(d)), args) => + applyArgs(args): newArgs => + def join = + if args is newArgs then inst + else inst.copy(args = newArgs) + val res = ctx.rewrittenScopes.get(d) match + case N => join + case S(c: LiftedClass) => c.rewriteInstantiate(inst, newArgs) + case S(r) => resolveDefnRef(l, d, r) match + case Some(value) => Instantiate(inst.mut, value, newArgs) + case None => join + k(res) case _ => super.applyResult(r)(k) // extract the call - override def applyPath(p: Path)(k: Path => Block): Block = - p match - case RefOfBms(l, S(d)) if ctx.bmsReqdInfo.contains(l) && !ctx.isModOrObj(l) => - val newSym = closureMap.get(l) match - case None => - // $this was previously used, but it may be confused with the `this` keyword - // let's use $here instead - val newSym = TempSymbol(N, l.nme + "$here") - extraLocals.add(newSym) - syms.addOne(FunSyms(l, d) -> newSym) // add to `syms`: this closure will be initialized in `applyBlock` - closureMap.addOne(l -> newSym) // add to `closureMap`: `newSym` refers to the closure and can be used later - newSym - - // symbol exists, and is initialized - case Some(value) if activeClosures.contains(value) => value - // symbol exists, needs initialization - case Some(value) => - syms.addOne(FunSyms(l, d) -> value) - value - k(Value.Ref(newSym, S(d))) + override def applyPath(p: Path)(k: Path => Block): Block = p match + case r @ RefOfBms(l, S(d)) => ctx.rewrittenScopes.get(d) match + case S(f: LiftedFunc) => + if f.isTrivial then k(r) + else + val newSym = closureMap.get(l) match + case None => + val newSym = TempSymbol(N, l.nme + "$here") + extraLocals.add(newSym) + syms.addOne(FunSyms(l, d) -> newSym) // add to `syms`: this closure will be initialized in `applyBlock` + closureMap.addOne(l -> newSym) // add to `closureMap`: `newSym` refers to the closure and can be used later + newSym + + // symbol exists, and is initialized + case Some(value) if activeClosures.contains(value) => value + // symbol exists, needs initialization + case Some(value) => + syms.addOne(FunSyms(l, d) -> value) + value + k(Value.Ref(newSym, N)) + + // Other naked references to BlockMemberSymbols. + case S(r) => resolveDefnRef(l, d, r) match + case Some(value) => k(value) + case None => super.applyPath(p)(k) + case _ => super.applyPath(p)(k) + case _ => super.applyPath(p)(k) (walker.applyBlock(b), syms.toList, extraLocals) end rewriteBms @@ -702,17 +419,16 @@ class Lifter()(using State, Raise, Config): activeClosures = curActive ret - override def applyBlock(b: Block): Block = + override def applyBlock(b: Block): Block = // extract references to BlockMemberSymbols in the block which now may // need to be enriched with aux parameters val (rewritten, syms, extras) = rewriteBms(b) extraLocals.addAll(extras) val pre = syms.foldLeft(blockBuilder): - case (blk, (bms, local)) => - val initial = blk.assign(local, createCall(bms, ctx)) - ctx.defns(bms.b) match - case c: ClsLikeDefn => initial.assignFieldN(local.asPath, Tree.Ident("class"), bms.asPath) - case _ => initial + case (blk, (funSym, local)) => + ctx.liftedScopes(funSym.d) match + case l: LiftedFunc => blk.assign(local, l.rewriteRef) + case _ => die // Rewrite the rest val remaining = rewritten match @@ -736,607 +452,866 @@ class Lifter()(using State, Raise, Config): if (scrut2 is scrut) && (arms2 is arms) && (dflt2 is dflt) && (rst2 is rst) - then b else Match(scrut2, arms2, dflt2, rst2) + then rewritten else Match(scrut2, arms2, dflt2, rst2) - case Label(lbl, loop, bod, rst) => + case Label(lbl, false, bod, rst) => val lbl2 = lbl.subst val bod2 = applySubBlockAndReset(bod) val rst2 = applySubBlock(rst) - if (lbl2 is lbl) && (bod2 is bod) && (rst2 is rst) then b else Label(lbl2, loop, bod2, rst2) + if (lbl2 is lbl) && (bod2 is bod) && (rst2 is rst) then rewritten else Label(lbl2, false, bod2, rst2) case TryBlock(sub, fin, rst) => val sub2 = applySubBlockAndReset(sub) val fin2 = applySubBlockAndReset(fin) val rst2 = applySubBlock(rst) - if (sub2 is sub) && (fin2 is fin) && (rst2 is rst) then b else TryBlock(sub2, fin2, rst2) - - // Detect private field usages - case Assign(t: TermSymbol, rhs, rest) if t.owner.isDefined => - ctx.resolveIsymPath(t.owner.get) match - case Some(value) if !iSymInScope(t.owner.get) => - if (t.k is syntax.LetBind) && !t.owner.forall(_.isInstanceOf[semantics.TopLevelSymbol]) then - // TODO: improve the error message - raise(ErrorReport( - msg"Uses of private fields cannot yet be lifted." -> N :: Nil, - N, Diagnostic.Source.Compilation - )) - applyResult(rhs): newRhs => - AssignField(value.read, t.id, newRhs, applyBlock(rest))(N) - case _ => super.applyBlock(rewritten) + if (sub2 is sub) && (fin2 is fin) && (rst2 is rst) then rewritten else TryBlock(sub2, fin2, rst2) // Assignment to variables - case Assign(lhs, rhs, rest) => ctx.getLocalCaptureSym(lhs) match - case Some(captureSym) => - applyResult(rhs): newRhs => - AssignField(ctx.getLocalClosPath(lhs).get.read, captureSym.id, newRhs, applyBlock(rest))(N) - case None => ctx.getLocalPath(lhs) match - case None => super.applyBlock(rewritten) - case Some(value) => - applyResult(rhs): newRhs => - value.assign(newRhs, applyBlock(rest)) - - // rewrite ValDefns (in ctors) - case Define(d: ValDefn, rest: Block) if d.owner.isDefined => - ctx.getIsymPath(d.owner.get) match - case Some(value) if !iSymInScope(d.owner.get) => - applyResult(d.rhs): newRhs => - AssignField(value.read, Tree.Ident(d.sym.nme), newRhs, applyBlock(rest))(S(d.sym)) - case _ => super.applyBlock(rewritten) + case Assign(lhs, rhs, rest) => ctx.symbolsMap.get(lhs) match + case Some(path) => applyResult(rhs): rhs2 => + path.assign(rhs2, applySubBlock(rest)) + case _ => super.applyBlock(rewritten) - // rewrite object definitions, assigning to the given symbol in modObjLocals - case Define(d: ClsLikeDefn, rest: Block) => ctx.modObjLocals.get(d.sym) match - case Some(sym) if !ctx.ignored(d.sym) => ctx.getBmsReqdInfo(d.sym) match - case Some(_) => // has args - extraLocals.add(sym) - blockBuilder - .assign(sym, Instantiate(mut = false, d.sym.asPath, getCallArgs(FunSyms(d.sym, d.isym), ctx))) - .rest(applyBlock(rest)) - case None => // has no args - // Objects with no parameters are instantiated statically - blockBuilder - .assign(sym, d.sym.asPath) - .rest(applyBlock(rest)) - case _ => ctx.replacedDefns.get(d.sym) match - case Some(value) => Define(value, applyBlock(rest)) - case None => super.applyBlock(rewritten) - + // rewrite object definitions, assigning to the saved symbol + case Define(d @ ClsLikeDefn(k = syntax.Obj), rest: Block) => ctx.liftedScopes.get(d.isym) match + case Some(l: LiftedClass) if l.obj.isObj => + ctx.symbolsMap(l.cls.isym).assign(l.instObject, applySubBlock(rest)) + case _ => super.applyBlock(rewritten) case _ => super.applyBlock(rewritten) pre.rest(remaining) - override def applyPath(p: Path)(k: Path => Block): Block = - p match - // These two cases rewrites `this.whatever` when referencing an outer class's fields. - case Value.Ref(l: InnerSymbol, _) => - ctx.resolveIsymPath(l) match - case Some(value) if !iSymInScope(l) => k(value.read) + override def applyPath(p: Path)(k: Path => Block): Block = p match + // This rewrites naked references to locals, + case Value.Ref(l, _) => ctx.symbolsMap.get(l) match + case Some(value) => k(value.read) case _ => super.applyPath(p)(k) - case Value.Ref(t: TermSymbol, _) if t.owner.isDefined => - ctx.resolveIsymPath(t.owner.get) match - case Some(value) if !iSymInScope(t.owner.get) => - if (t.k is syntax.LetBind) && !t.owner.forall(_.isInstanceOf[semantics.TopLevelSymbol]) then - // TODO: improve the error message - raise(ErrorReport( - msg"Uses of private fields cannot yet be lifted." -> N :: Nil, - N, Diagnostic.Source.Compilation - )) - k(Select(value.read, t.id)(N)) - case _ => super.applyPath(p)(k) - - // For objects inside classes: When an object is nested inside a class, its defn will be - // replaced by a symbol, to which the object instance is assigned. This rewrites references - // from the objects BlockMemberSymbol to that new symbol. - case s @ Select(qual, ident) => - s.symbol.flatMap(ctx.getLocalPath) match - case Some(LocalPath.Sym(value: DefinitionSymbol[?])) => - k(Select(qual, Tree.Ident(value.nme))(S(value))) - case _ => super.applyPath(p)(k) - - // This is to rewrite references to classes that are not lifted (when their BlockMemberSymbol - // reference is passed as function parameters). - case RefOfBms(l, disamb) if ctx.ignored(l) && ctx.isRelevant(l) => ctx.getIgnoredBmsPath(l) match - case Some(value) => k(value.readDisamb(disamb)) - case None => super.applyPath(p)(k) - // This rewrites naked references to locals. If a function is in a capture, then we select that value - // from the capture; otherwise, we see if that local is passed directly as a parameter to this defn. - case Value.Ref(l, _) => ctx.getLocalCaptureSym(l) match - case Some(captureSym) => - k(Select(ctx.getLocalClosPath(l).get.read, captureSym.id)(N)) - case None => ctx.getLocalPath(l) match - case Some(value) => k(value.read) - case None => super.applyPath(p)(k) case _ => super.applyPath(p)(k) - - // When calling a lifted function or constructor, we need to pass, as arguments, the local variables, - // inner symbols, etc that it needs to access. This function creates those arguments for that in - // the correct order. - def getCallArgs(sym: FunSyms[?], ctx: LifterCtx) = - val info = ctx.getBmsReqdInfo(sym.b).get - val localsArgs = info.reqdVars.map(s => ctx.getLocalPath(s).get.asArg) - val capturesArgs = info.reqdCaptures.map(ctx.getCapturePath(_).get.asArg) - val iSymArgs = info.reqdInnerSyms.map(ctx.getIsymPath(_).get.asArg) - val bmsArgs = info.reqdBms.map(ctx.getIgnoredBmsPath(_).get.asArg) - bmsArgs ++ iSymArgs ++ localsArgs ++ capturesArgs - // This creates a call to a lifted function or constructor. - def createCall(syms: FunSyms[?], ctx: LifterCtx): Call = - val info = ctx.getBmsReqdInfo(syms.b).get - val callSym = info.fakeCtorBms match - case Some(v) => v - case None => syms - Call(callSym.asPath, getCallArgs(syms, ctx))(false, false, false) - - /* - * Explanation of liftOutDefnCont, liftDefnsInCls, liftDefnsInFn: - * - * The initial call is to liftDefnsInFn or liftDefnsInCls: - * - liftDefnsInFn rewrites a function's body so that it references variables correctly, and calls liftOutDefnCont - * on its nested definitions and lifts them (if they're not ignored). - * - liftDefnsInCls does the same but for classes by rewriting their constructors and methods. Notably, it directly - * calls liftDefnsInFn on its member functions. - * - * liftOutDefnCont's purpose is to rewrite definitions' signatures so that they make sense after being lifted. This - * includes adding the parameter lists which take in variables, captures, references to inner symbols etc. If a - * definition has been marked as "ignored" (not lifted), or if the definition is so simple that it doesn't need, - * extra parameter lists, it will directly call liftDefnsInFn or liftDefnsInCls on that definition. - */ - def liftOutDefnCont(base: Defn, d: Defn, ctx: LifterCtx): Lifted[Defn] = ctx.getBmsReqdInfo(d.sym) match - case N => d match - case f: FunDefn => liftDefnsInFn(f, ctx) - case c: ClsLikeDefn => liftDefnsInCls(c, ctx) - case _ => Lifted(d, Nil) - case S(LiftedInfo(includedCaptures, includedLocals, clsCaptures, reqdBms, fakeCtorBms, singleCallBms)) => - - def createSymbolsUpdateCtx[T <: LocalPath](createSym: String => (VarSymbol, T)) - : (List[Param], LifterCtx, List[(Local, (VarSymbol, T))]) - = - val capturesSymbols = includedCaptures.map: sym => - (sym, createSym(sym.nme + "$capture")) - - val localsSymbols = includedLocals.map: sym => - (sym, createSym(sym.nme)) - - val isymSymbols = clsCaptures.map: sym => - (sym, createSym(sym.nme + "$instance")) - - val bmsSymbols = reqdBms.map: sym => - (sym, createSym(sym.nme + "$member")) - - val extraParamsCaptures = capturesSymbols.map: // parameter list - case (d, (sym, _)) => Param(FldFlags.empty, sym, N, Modulefulness.none) - val newCapturePaths = capturesSymbols.map: // mapping from sym to param symbol - case (d, (_, lp)) => d -> lp - .toMap - - val extraParamsLocals = localsSymbols.map: // parameter list - case (d, (sym, _)) => Param(FldFlags.empty, sym, N, Modulefulness.none) - val newLocalsPaths = localsSymbols.map: // mapping from sym to param symbol - case (d, (_, lp)) => d -> lp - .toMap - - val extraParamsIsyms = isymSymbols.map: // parameter list - case (d, (sym, _)) => Param(FldFlags.empty, sym, N, Modulefulness.none) - val newIsymPaths = isymSymbols.map: // mapping from sym to param symbol - case (d, (_, lp)) => d -> lp - .toMap + given ignoredScopes: IgnoredScopes = IgnoredScopes(N) + val data = ScopeData(topLevelBlk) + val metadata = data.root.children.foldLeft(LifterMetadata.empty)(_ ++ createMetadata(_)) + + def asDSym(s: ClsSym | ModuleOrObjSym): DefinitionSymbol[?] = s + val ignored: Set[ScopedInfo] = metadata.unliftable.map(asDSym) + ignoredScopes.ignored = S(ignored) + + val usedVars = UsedVarAnalyzer(topLevelBlk, data) + + // for debugging + def printMap[T, V](m: Map[T, V]) = + println("Map(") + for case (k, v) <- m do + print(" ") + print(k) + print(" -> ") + println(v) + println(")") + + + /* + println("accessesShallow") + printMap(usedVars.shallowAccesses) + println("accesses") + printMap(usedVars.accessMap) + printMap(usedVars.accessMapWithIgnored) + println("usedVars") + printMap(usedVars.reqdCaptures) + */ + + case class LifterResult[+T](liftedDefn: T, extraDefns: List[Defn]) + case class LifterCtxNew( + liftedScopes: MutMap[LiftedSym, LiftedScope[?]] = MutMap.empty, + rewrittenScopes: MutMap[ScopedInfo, RewrittenScope[?]] = MutMap.empty, + var symbolsMap: Map[Local, LocalPath] = Map.empty, + var defnsMap: Map[DefinitionSymbol[?], DefnRef] = Map.empty, + var capturesMap: Map[ScopedInfo, Path] = Map.empty + ) + + /** + * Creates a capture class for a function consisting of its mutable (and possibly immutable) local variables. + * @param f The function to create the capture class for. + * @param ctx The lifter context. Determines which variables will be captured. + * @return The tuple (defn, varsMap), where `defn` is the capture class's definition, and + * `varsMap` maps the function's locals to the corresponding `VarSymbol` (for the class parameters) in the correct order. + */ + def createCaptureCls(s: ScopedObject) + : (ClsLikeDefn, List[(Symbol, TermSymbol)]) = + val nme = "Capture$" + s.nme - val extraParamsBms = bmsSymbols.map: // parameter list - case (d, (sym, _)) => Param(FldFlags.empty, sym, N, Modulefulness.none) - val newBmsPaths = bmsSymbols.map: // mapping from sym to param symbol - case (d, (_, lp)) => d -> lp - .toMap + val clsSym = ClassSymbol( + Tree.DummyTypeDef(syntax.Cls), + Tree.Ident(nme) + ) - val extraParams = extraParamsBms ++ extraParamsIsyms ++ extraParamsLocals ++ extraParamsCaptures + val cap = usedVars.reqdCaptures(s.toInfo) - val newCtx = ctx - .replCapturePaths(newCapturePaths) - .replLocalPaths(newLocalsPaths) - .addIsymPaths(newIsymPaths) - .replIgnoredBmsPaths(newBmsPaths) + val fresh = FreshInt() + + val sortedVars: Array[(ctorSyms: (local: Local, vs: VarSymbol), param: Param, valDefn: ValDefn)] = cap.toArray.sortBy(_.uid).map: sym => + val id = fresh.make + val nme = sym.nme + "$" + id + + val ident = new Tree.Ident(nme) + val varSym = VarSymbol(ident) + val fldSym = BlockMemberSymbol(nme, Nil) + val tSym = TermSymbol(syntax.MutVal, S(clsSym), ident) + + val p = Param(FldFlags.empty.copy(isVal = true), varSym, N, Modulefulness.none) + varSym.decl = S(p) // * Currently this is only accessed to create the class' toString method + + val vd = ValDefn( + tSym, + fldSym, + Value.Ref(varSym) + ) + + (sym -> varSym, p, vd) + + val defn = ClsLikeDefn( + None, clsSym, BlockMemberSymbol(nme, Nil), + S(TermSymbol(syntax.Fun, S(clsSym), clsSym.id)), + syntax.Cls, + N, + PlainParamList(sortedVars.iterator.map(_.param).toList) :: Nil, None, Nil, Nil, + Nil, + End(), + sortedVars.iterator.foldLeft[Block](End()): + case (acc, (_, _, vd)) => Define(vd, acc), + N, + N, + ) + + (defn, sortedVars.iterator.map(x => (x.ctorSyms.local, x.valDefn.tsym)).toList) + + class ScopeRewriter(using ctx: LifterCtxNew) extends BlockTransformerShallow(SymbolSubst()): + + val extraDefns: ListBuffer[Defn] = ListBuffer.empty + + def applyRewrittenScope[T](r: RewrittenScope[T]): T = + val LifterResult(rewritten, defns) = liftNestedScopes(r) + extraDefns ++= defns + rewritten + + override def applyBlock(b: Block): Block = b match + case s: Scoped => + val uid = data.getUID(s) + applyRewrittenScope(ctx.rewrittenScopes(uid)) match + case b: Block => b + case _ => die + case l: Label if l.loop => + val node = data.getNode(l.label) + val blk = applyRewrittenScope(ctx.rewrittenScopes(l.label)) match + case b: Block => b + case _ => die + l.copy(body = blk) + case Define(defn, rest) => + val dsym = defn match + case f: FunDefn => f.dSym + case v: ValDefn => v.tsym + case c: ClsLikeDefn => c.isym + ctx.liftedScopes.get(dsym) match + case Some(_) => applySubBlock(rest) + case None => super.applyBlock(b) + case _ => super.applyBlock(b) + override def applyFunDefn(fun: FunDefn) = + applyRewrittenScope(ctx.rewrittenScopes(fun.dSym)) match + case f: FunDefn => f + case _ => die + override def applyDefn(defn: Defn)(k: Defn => Block) = defn match + case f: FunDefn => k(applyFunDefn(f)) + case c: ClsLikeDefn => + val newCls = applyRewrittenScope(ctx.rewrittenScopes(c.isym)) match + case c: ClsLikeDefn => c + case _ => die + val newComp = c.companion.map(comp => applyRewrittenScope(ctx.rewrittenScopes(comp.isym))) match + case Some(c: ClsLikeBody) => S(c) + case Some(_) => die + case None => N - (extraParams, newCtx, capturesSymbols ++ localsSymbols ++ isymSymbols ++ bmsSymbols) - - d match - case f: FunDefn => - def createSym(nme: String) = - val vsym = VarSymbol(Tree.Ident(nme)) - (vsym, LocalPath.Sym(vsym)) - val (extraParams, newCtx, _) = createSymbolsUpdateCtx(createSym) - - // create second param list with different symbols - val extraParamsCpy = extraParams.map(p => p.copy(sym = VarSymbol(p.sym.id))) + k(newCls.copy(companion = newComp)) + case _ => super.applyDefn(defn)(k) - val headPlistCopy = f.params.headOption match - case None => PlainParamList(Nil) - case Some(value) => ParamList(value.flags, value.params.map(p => p.copy(sym = VarSymbol(p.sym.id))), value.restParam) - - val flatPlist = f.params match - case head :: next => ParamList(head.flags, extraParams ++ head.params, head.restParam) :: next - case Nil => PlainParamList(extraParams) :: Nil - - val newDef = FunDefn( - base.owner, f.sym, f.dSym, PlainParamList(extraParams) :: f.params, f.body - )(f.forceTailRec) - val Lifted(lifted, extras) = liftDefnsInFn(newDef, newCtx) - - val args1 = extraParamsCpy.map(p => p.sym.asPath.asArg) - val args2 = headPlistCopy.params.map(p => p.sym.asPath.asArg) - - val bdy = blockBuilder - .ret(Call(singleCallBms.asPath, args1 ++ args2)(true, true, false)) // TODO: restParams not considered + /** + * Represents a scoped object that will be rewritten to reference the lifted version of objects and variables. + */ + sealed abstract class RewrittenScope[T](val obj: TScopedObject[T]): + val node = obj.node.get + + protected final val thisCapturedLocals = usedVars.reqdCaptures(obj.toInfo) + val hasCapture = !thisCapturedLocals.isEmpty + + // These are lazy, because we don't necessarily need a captrue + private final lazy val captureInfo: (ClsLikeDefn, List[(Local, TermSymbol)]) = createCaptureCls(obj) + + lazy val captureClass = captureInfo._1 + lazy val captureMap = captureInfo._2.toMap + lazy val liftedObjsMap: Map[InnerSymbol, LocalPath] + + lazy val capturePath: Path - val mainDefn = FunDefn(f.owner, f.sym, f.dSym, PlainParamList(extraParamsCpy) :: headPlistCopy :: Nil, bdy)(false) - val auxDefn = FunDefn(N, singleCallBms.b, singleCallBms.d, flatPlist, lifted.body)(forceTailRec = f.forceTailRec) - - if ctx.firstClsFns.contains(f.sym) then - Lifted(mainDefn, auxDefn :: extras) - else - Lifted(auxDefn, extras) // we can include just the flattened defn - case c: ClsLikeDefn => - val fresh = FreshInt() - def createSym(nme: String): (VarSymbol, LocalPath.PubField) = - ( - VarSymbol(Tree.Ident(nme)), - LocalPath.PubField(c.isym, BlockMemberSymbol(nme, Nil, true)) + protected def rewriteImpl: LifterResult[T] + + protected final def addExtraSyms(b: Block, captureSym: Local, objSyms: Iterable[Local], define: Bool): Block = + if hasCapture then + val undef = Value.Lit(Tree.UnitLit(false)).asArg + val inst = Instantiate( + true, + Value.Ref(captureClass.sym, S(captureClass.isym)), + captureInfo._2.map: + case (sym, _) => sym.asPath.asArg + ) + val assign = Assign(captureSym, inst, b) + if define then + Scoped( + Set(captureSym) ++ objSyms, + assign + ) + else assign + else + if define then Scoped(objSyms.toSet, b) else b + + /** + * Rewrites the contents of this scoped object to reference the lifted versions of variables. + * + * @return The rewritten scoped object, plus any extra scoped definitions arising from lifting the nested scoped objects. + */ + final def rewrite = + if hasCapture then + val LifterResult(defn, extra) = rewriteImpl + LifterResult(defn, captureClass :: extra) + else rewriteImpl + + /** The path to access locals defined by this object. The primary purpose of this is to rewrite accesses + * to locals that have been moved to a capture. + */ + protected final def pathsFromThisObj: Map[Local, LocalPath] = + // Remove child BlockMemberSymbols; we will use their definition symbols instead + + // Locals introduced by this object + val fromThisObj = node.localsWithoutBms + .map: s => + s -> s.asLocalPath + .toMap + // Locals introduced by this object that are inside this object's capture + val fromCap = thisCapturedLocals + .map: s => + val tSym = captureMap(s) + s -> LocalPath.InCapture(capturePath, tSym) + .toMap + // Inner symbols of nested modules and objects + val isyms = node.children + .collect: + case ScopeNode(obj = c: ScopedObject.Companion) => + val s: Local = c.comp.isym + s -> LocalPath.BmsRef(c.bsym, c.comp.isym) + case ScopeNode(obj = c: ScopedObject.Class) if c.isObj => + c.cls.isym -> (liftedObjsMap.get(c.cls.isym) match + case Some(value) => value // lifted + case None => LocalPath.BmsRef(c.bsym, c.cls.isym) // not lifted ) - val (extraParams, newCtx, flds) = createSymbolsUpdateCtx(createSym) - - // add aux params, private fields, update preCtor - val newAuxParams = c.auxParams.appended(PlainParamList(extraParams)) - - val pubFieldsPairs = flds.map: - case (_, (vs, LocalPath.PubField(isym, sym))) => vs -> sym - val newPubFields = c.publicFields ::: pubFieldsPairs.map(_._2).map(bsym => bsym -> - TermSymbol(syntax.MutVal, S(c.isym), Tree.Ident(bsym.nme))) - - val newCtor = pubFieldsPairs.foldRight(c.ctor): - case ((sym, bms), blk) => Define(ValDefn.mk(S(c.isym), syntax.MutVal, bms, sym.asPath), blk) - - if modOrObj(c) then // module or object - // force it to be a class - val newK = c.k match - case syntax.Obj => syntax.Cls - case _ => wat("unreachable", c.k) - - val newDef = c.copy( - k = newK, paramsOpt = N, - owner = N, auxParams = PlainParamList(extraParams) :: Nil, - publicFields = newPubFields, - ctor = newCtor - ) - liftDefnsInCls(newDef, newCtx) - else // normal class - - val newDef = c.copy( - owner = N, - auxParams = newAuxParams, - publicFields = newPubFields, - ctor = newCtor - ) - - val Lifted(lifted, extras) = liftDefnsInCls(newDef, newCtx) - - val bms = fakeCtorBms.get - - // create the fake ctor here - inline def mapParams(ps: ParamList) = ps.params.map(p => VarSymbol(p.sym.id)) - - val paramSyms = c.paramsOpt.map(mapParams) // what is defined in paramsOpt - val auxSyms = c.auxParams.map(mapParams) // the original class's aux params - val extraSyms = extraParams.map(p => VarSymbol(p.sym.id)) // these will be added to the aux params - - // pop one list fromm auxSyms if paramsOpt is empty - // these are for creating the body only - val (newParamSyms, newAuxSyms) = paramSyms match - case None => auxSyms match - case head :: next => (S(head), next.appended(extraSyms)) - case Nil => (S(extraSyms), Nil) - case Some(value) => (paramSyms, auxSyms.appended(extraSyms)) - - val paramArgs = newParamSyms.getOrElse(Nil).map(_.asPath.asArg) - - inline def toPaths(l: List[Local]) = l.map(_.asPath) - - val isMutSym = VarSymbol(Tree.Ident("isMut")) - - def instInner(isMut: Bool) = - Instantiate(mut = isMut, Value.Ref(c.sym, S(c.isym)), paramArgs) - - val initSym = TempSymbol(None, "tmp") - - def go(cur: List[List[VarSymbol]], curSym: Symbol, acc: Block => Block): Block = cur match - case ps :: rst => - val call = Call(curSym.asPath, ps.map(_.asPath.asArg))(true, - rst === Nil && config.checkInstantiateEffect, false) - val thisSym = TempSymbol(None, "tmp") - go(rst, thisSym, acc.assignScoped(thisSym, call)) - case Nil => acc.ret(curSym.asPath) - - val bod = go(newAuxSyms, initSym, blk => Match( - isMutSym.asPath, - Case.Lit(Tree.BoolLit(true)) -> Assign(initSym, instInner(true), End()) :: Nil, - S(Assign(initSym, instInner(false), End())), - blk - )) - - inline def toPlist(ls: List[VarSymbol]) = - PlainParamList(ls.map(s => Param(FldFlags.empty, s, N, Modulefulness.none))) - - val paramPlist = paramSyms.map(toPlist) - val auxPlist = auxSyms.map(toPlist) - // isMut determines whether the instantiation is `new` or `new mut` - val extraPlist = toPlist(isMutSym :: extraSyms) - - // NOTE: The fake ctor was to support first-class classes. - // These are currently unused. - - /* - val plist = paramPlist match - case None => extraPlist :: PlainParamList(Nil) :: auxPlist - case Some(value) => extraPlist :: value :: auxPlist - - val fakeCtorDefn = FunDefn( - None, bms, plist, bod - ) - */ - - val paramSym2 = paramSyms.getOrElse(Nil) - val auxSym2 = auxSyms.flatMap(l => l) - val allSymsMp = (paramSym2 ++ auxSym2 ++ extraSyms).map(s => s -> VarSymbol(s.id)).toMap - val subst = new SymbolSubst(): - override def mapVarSym(s: VarSymbol): VarSymbol = allSymsMp.get(s) match - case None => s - case Some(value) => value - - val (headParams, newAuxPlist) = paramPlist match - case None => auxPlist match - case head :: next => (ParamList(head.flags, extraPlist.params ++ head.params, head.restParam), next) - case Nil => (extraPlist, auxPlist) - - case Some(value) => (ParamList(value.flags, extraPlist.params ++ value.params, value.restParam), auxPlist) - - val auxCtorDefn_ = FunDefn(None, singleCallBms.b, singleCallBms.d, headParams :: newAuxPlist, Scoped(Set.single(initSym), bod))(false) - val auxCtorDefn = BlockTransformer(subst).applyFunDefn(auxCtorDefn_) - - // Lifted(lifted, extras ::: (fakeCtorDefn :: auxCtorDefn :: Nil)) - Lifted(lifted, extras ::: (auxCtorDefn :: Nil)) - case _ => Lifted(d, Nil) + .toMap + // Note: the order here is important, as fromCap must override keys from + // fromThisObj. + isyms ++ fromThisObj ++ fromCap + + lazy val capturePaths = + if thisCapturedLocals.isEmpty then Map.empty + else Map(obj.toInfo -> capturePath) + + // BMS refs from ignored defns (including child defns of modules) + // Note that we map the DefinitionSymbol to the disambiguated BMS. + protected val defnPathsFromThisObj: Map[DefinitionSymbol[?], DefnRef] = + node.children.filter: + case s @ ScopeNode(obj = r: ScopedObject.Class) if r.isObj => false + case _ => true + .collect: + case s @ ScopeNode(obj = r: ScopedObject.Referencable[?]) => !s.isLifted + val path = r.owner match + case Some(isym) => DefnRef.Field(isym, r.bsym, r.sym) + case None => DefnRef.InScope(r.bsym, r.sym) + r.sym -> path + .toMap + + lazy val defnPaths: Map[DefinitionSymbol[?], DefnRef] = defnPathsFromThisObj + + lazy val symbolsMap: Map[Local, LocalPath] = pathsFromThisObj - end liftOutDefnCont + /** Represents a scoped object that is to be rewritten and lifted. */ + sealed abstract class LiftedScope[T <: Defn](override val obj: ScopedObject.Liftable[T])(using ctx: LifterCtxNew) extends RewrittenScope[T](obj): + private val AccessInfo(accessed, _, refdScopes) = usedVars.accessMap(obj.toInfo) + private val AccessInfo(_, _, allRefdScopes) = usedVars.accessMapWithIgnored(obj.toInfo) + private val refdDSyms = refdScopes.collect: + case d: LiftedSym => d + .toSet + + /** Symbols that this object will lose access to once lifted, and therefore must receive + * as a parameter. Does not include neighbouring objects that this definition may lose + * access to. Those are in a separate list. + * + * Includes symbols introduced by modules and objects. + */ + final val reqSymbols = accessed ++ allRefdScopes.map(data.getNode(_).obj) + .collect: + case s: ScopedObject.Referencable[?] if s.owner.isDefined => s.owner.get + .collect: + case d: DefinitionSymbol[?] => d + .filter: d => + data.getNode(d).obj match + case _: ScopedObject.Companion => true + case c: ScopedObject.Class if c.isObj => true + case _ => false + + private val (reqPassedSymbols, captures) = reqSymbols + .partitionMap: s => + usedVars.capturesMap.get(s) match + case Some(info) => R((s, info)) + case None => L(s) + + /** Locals that are directly passed to this object, i.e. not via a capture. */ + final val passedSyms: Set[Local] = reqPassedSymbols + /** Maps locals to the scope where they were defined. */ + final val capturesOrigin: Map[Local, ScopedInfo] = captures.toMap + /** Locals that are inside captures. */ + final val inCaptureSyms: Set[Local] = captures.map(_._1) + /** Scopes whose captures this object requires. */ + final val reqCaptures: Set[ScopedInfo] = captures.map(_._2) + /** + * Neighbouring objects that this definition may lose access to + * once lifted, referenced by their *definition symbol* (not BMS). + */ + final val reqDefns = node.reqCaptureObjs + .filter: + case f: ScopedObject.Func if f.isMethod.isDefined => false + case _ => true + .map(_.sym) + .toSet.intersect(refdDSyms) + + /** Maps directly passed locals to the path representing that local within this object. */ + protected val passedSymsMap: Map[Local, LocalPath] + /** Maps scopes to the path representing their captures within this object. */ + protected val capSymsMap: Map[ScopedInfo, Path] + /** Maps definition symbols to the path representing that definition. */ + protected val passedDefnsMap: Map[DefinitionSymbol[?], DefnRef] + + protected lazy val capturesOrdered: List[ScopedInfo] + protected final lazy val passedSymsOrdered: List[Local] = reqPassedSymbols.toList.sortBy(_.uid) + protected final lazy val passedDefnsOrdered: List[DefinitionSymbol[?]] = reqDefns.toList.sortBy(_.uid) + + override lazy val capturePaths = + if thisCapturedLocals.isEmpty then capSymsMap + else capSymsMap + (obj.toInfo -> capturePath) + + // Note: we have to make this lazy because Scala's type system is unsound and + // lets you access the above two fields before they are initialized + // (since this constructor runs before the child classes' constructors) + + /** Maps symbols to the path representing that local within this object. + * Includes locals defined by this object's parents, and this object's own defined locals. + */ + override lazy val symbolsMap: Map[Local, LocalPath] = + val fromParents = reqSymbols + .map: s => + passedSymsMap.get(s) match + // The symbol is passed directly + case Some(value) => s -> value + // The symbol is passed in a capture + case None => + val fromScope = capturesOrigin(s) + val capSym = capSymsMap(fromScope) + val tSym = ctx.rewrittenScopes(fromScope).captureMap(s) + s -> LocalPath.InCapture(capSym, tSym) + .toMap + fromParents ++ pathsFromThisObj + + override lazy val defnPaths: Map[DefinitionSymbol[?], DefnRef] = + val fromParents = reqDefns + .map: s => + s -> passedDefnsMap(s) + .toMap + defnPathsFromThisObj ++ fromParents + + final def formatArgs: List[Arg] = + val defnsArgs = passedDefnsOrdered.map(d => ctx.defnsMap(d).asArg) + val captureArgs = capturesOrdered.map(c => ctx.capturesMap(c).asArg) + val localArgs = passedSymsOrdered.map(l => ctx.symbolsMap(l).asArg) + defnsArgs ::: captureArgs ::: localArgs - def removeDefnsFromScope(b: Block, defns: List[Defn]) = b match - case Scoped(syms, body) => Scoped(syms.toSet -- defns.map(_.sym), body) - case _ => b + /* MIXINS */ - def liftDefnsInCls(c: ClsLikeDefn, ctx: LifterCtx): Lifted[ClsLikeDefn] = - val ctxx = if c.companion.isDefined then ctx.inModule(c) else ctx // TODO: refine handling of companions - - // =========================================================== - // STEP 1: lift out definitions nested in the ctor and prector - // and deal with class defns in the companion class - - val (preCtor, preCtorDefns) = c.preCtor.floatOut(ctxx) - val (ctor, ctorDefns) = c.ctor.floatOut(ctxx) - val (cCtor, cCtorDefns) = c.companion.fold((None, Nil)): - case value => - val (a, b) = value.ctor.floatOut(ctxx) - (S(a), b) + /** + * A rewritten scope with a generic VarSymbol capture symbol. + */ + sealed trait GenericRewrittenScope[T] extends RewrittenScope[T]: + lazy val captureSym = VarSymbol(Tree.Ident(obj.nme + "$cap")) + override lazy val capturePath = captureSym.asPath + protected val liftedObjsSyms: Map[InnerSymbol, VarSymbol] = node.liftedObjSyms.map: s => + s -> VarSymbol(Tree.Ident(s.nme + "$")) + .toMap + override lazy val liftedObjsMap: Map[InnerSymbol, LocalPath] = liftedObjsSyms.map: + case k -> v => k -> v.asLocalPath + + protected def addExtraSyms(b: Block): Block = addExtraSyms(b, captureSym, liftedObjsSyms.values, true) + + /** + * A rewritten scope with a TermSymbol capture symbol. + */ + sealed trait ClsLikeRewrittenScope[T](sym: InnerSymbol) extends RewrittenScope[T]: + lazy val captureSym = TermSymbol(syntax.ImmutVal, S(sym), Tree.Ident(obj.nme + "$cap")) + override lazy val capturePath = captureSym.asPath + protected val liftedObjsSyms: Map[InnerSymbol, TermSymbol] = node.liftedObjSyms.map: s => + s -> TermSymbol(syntax.ImmutVal, S(sym), Tree.Ident(s.nme + "$")) + .toMap + override lazy val liftedObjsMap: Map[InnerSymbol, LocalPath] = liftedObjsSyms.map: + case k -> v => k -> v.asLocalPath + + // some helpers + private def dupParam(p: Param): Param = p.copy(sym = VarSymbol(Tree.Ident(p.sym.nme))) + private def dupParams(plist: List[Param]): List[Param] = plist.map(dupParam) + private def dupParamList(plist: ParamList): ParamList = + plist.copy(params = dupParams(plist.params), restParam = plist.restParam.map(dupParam)) + + /* CONCRETE IMPLS */ + + class RewrittenScopedBlock(override val obj: ScopedObject.ScopedBlock)(using ctx: LifterCtxNew) extends RewrittenScope[Block](obj) with GenericRewrittenScope[Block]: + override def rewriteImpl: LifterResult[Block] = + val rewriter = new BlockRewriter + + // Remove symbols belonging to lifted scopes + val liftedChildSyms = node.children.collect: + case s @ ScopeNode(obj = l: ScopedObject.Liftable[?]) if s.isLifted => l.defn.sym + + val (syms, rewritten) = (obj.block.syms.toSet -- liftedChildSyms, rewriter.rewrite(obj.block.body)) + val withCapture = addExtraSyms(rewritten) + LifterResult(Scoped(syms, withCapture), rewriter.extraDefns.toList) + + class RewrittenLoop(override val obj: ScopedObject.Loop)(using ctx: LifterCtxNew) extends RewrittenScope[Block](obj) with GenericRewrittenScope[Block]: + override def rewriteImpl: LifterResult[Block] = + val rewriter = new BlockRewriter + + val rewritten = rewriter.rewrite(obj.body) + val withCapture = addExtraSyms(rewritten) + LifterResult(withCapture, rewriter.extraDefns.toList) + + class RewrittenFunc(override val obj: ScopedObject.Func)(using ctx: LifterCtxNew) extends RewrittenScope[FunDefn](obj) with GenericRewrittenScope[FunDefn]: + override def rewriteImpl: LifterResult[FunDefn] = + val rewriter = new BlockRewriter + + val rewritten = rewriter.rewrite(obj.fun.body) + val withCapture = addExtraSyms(rewritten) + LifterResult(obj.fun.copy(body = withCapture)(obj.fun.forceTailRec), rewriter.extraDefns.toList) + + private def rewriteMethods(node: ScopeNode, methods: List[FunDefn])(using ctx: LifterCtxNew) = + val mtds = node.children + .map: c => + ctx.rewrittenScopes(c.obj.toInfo) + .collect: + case r: RewrittenFunc if r.obj.isMethod.isDefined => r + val (liftedMtds, extras) = mtds.map(liftNestedScopes).unzip(using l => (l.liftedDefn, l.extraDefns)) + LifterResult(liftedMtds, extras.flatten) + + class RewrittenClassCtor(override val obj: ScopedObject.ClassCtor)(using ctx: LifterCtxNew) extends RewrittenScope[Unit](obj): + override lazy val capturePath: Path = lastWords("tried to create a capture class for a class ctor") + override lazy val liftedObjsMap: Map[InnerSymbol, LocalPath] = lastWords("tried to create obj syms for a class ctor") - val allCtorDefns = preCtorDefns ++ ctorDefns ++ cCtorDefns + override protected def rewriteImpl: LifterResult[Unit] = LifterResult((), Nil) // dummy + + def getRewrittenCls = ctx.rewrittenScopes(obj.cls.isym) + + class RewrittenClass(override val obj: ScopedObject.Class)(using ctx: LifterCtxNew) + extends RewrittenScope[ClsLikeDefn](obj) + with ClsLikeRewrittenScope[ClsLikeDefn](obj.cls.isym): - // ctorIgnored: definitions within the class (i.e. ctor) that we don't lift - // ctorIncluded: ditto, but lifted - val (ctorIgnored, ctorIncluded) = allCtorDefns.partition(d => ctxx.ignored(d.sym)) + private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.cls.isym), Tree.Ident(obj.nme + "$cap")) + override lazy val capturePath: Path = captureSym.asPath + + override def rewriteImpl: LifterResult[ClsLikeDefn] = + val rewriterCtor = new BlockRewriter + val rewriterPreCtor = new BlockRewriter + val rewrittenCtor = rewriterCtor.rewrite(obj.cls.ctor) + val rewrittenPrector = rewriterPreCtor.rewrite(obj.cls.preCtor) + val ctorWithCap = addExtraSyms(rewrittenCtor, captureSym, Nil, false) + + val LifterResult(newMtds, extras) = rewriteMethods(node, obj.cls.methods) + val newCls = obj.cls.copy( + ctor = ctorWithCap, + preCtor = rewrittenPrector, + privateFields = captureSym :: liftedObjsSyms.values.toList ::: obj.cls.privateFields, + methods = newMtds, + ) + LifterResult(newCls, rewriterCtor.extraDefns.toList ::: rewriterPreCtor.extraDefns.toList ::: extras) - // Symbols containing refernces to nested classes and nested objects are here - - // Deals with references to lifted objects defined within the class - val nestedClsPaths: Map[Local, LocalPath] = ctorIncluded.map: - case c: ClsLikeDefn if modOrObj(c) => ctxx.modObjLocals.get(c.sym) match - case Some(sym) => S(c.isym -> LocalPath.Sym(sym)) - case _ => S(c.sym -> LocalPath.Sym(c.sym)) - case _ => None - .collect: - case Some(x) => x + class RewrittenCompanion(override val obj: ScopedObject.Companion)(using ctx: LifterCtxNew) + extends RewrittenScope[ClsLikeBody](obj) + with ClsLikeRewrittenScope[ClsLikeBody](obj.comp.isym): + + private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.comp.isym), Tree.Ident(obj.nme + "$cap")) + override lazy val capturePath: Path = captureSym.asPath + + override def rewriteImpl: LifterResult[ClsLikeBody] = + val rewriterCtor = new BlockRewriter + val rewrittenCtor = rewriterCtor.rewrite(obj.comp.ctor) + val ctorWithCap = addExtraSyms(rewrittenCtor, captureSym, Nil, false) + val LifterResult(newMtds, extras) = rewriteMethods(node, obj.comp.methods) + val newComp = obj.comp.copy( + ctor = ctorWithCap, + privateFields = captureSym :: liftedObjsSyms.values.toList ::: obj.comp.privateFields, + methods = newMtds + ) + LifterResult(newComp, rewriterCtor.extraDefns.toList ::: extras) + + class LiftedFunc(override val obj: ScopedObject.Func)(using ctx: LifterCtxNew) extends LiftedScope[FunDefn](obj) with GenericRewrittenScope[FunDefn]: + private val passedSymsMap_ : Map[Local, VarSymbol] = passedSyms.map: s => + s -> VarSymbol(Tree.Ident(s.nme)) + .toMap + private val capSymsMap_ : Map[ScopedInfo, VarSymbol] = reqCaptures.map: i => + val nme = data.getNode(i).obj.nme + i -> VarSymbol(Tree.Ident(nme + "$cap")) .toMap + private val defnSymsMap_ : Map[DefinitionSymbol[?], VarSymbol] = reqDefns.map: i => + val nme = data.getNode(i).obj.nme + i -> VarSymbol(Tree.Ident(nme + "$")) + .toMap + + override lazy val capturesOrdered: List[ScopedInfo] = reqCaptures.toList.sortBy(c => capSymsMap_(c).uid) + + override protected val passedSymsMap = passedSymsMap_.view.mapValues(_.asLocalPath).toMap + override protected val capSymsMap = capSymsMap_.view.mapValues(_.asPath).toMap + override protected val passedDefnsMap = defnSymsMap_.view.mapValues(_.asDefnRef).toMap + + val auxParams: List[Param] = + (passedDefnsOrdered.map(defnSymsMap_) ::: capturesOrdered.map(capSymsMap_) ::: passedSymsOrdered.map(passedSymsMap_)) + .map: s => + val decl = Param(FldFlags.empty.copy(isVal = false), s, N, Modulefulness.none) + s.decl = S(decl) + decl + + // Whether this can be lifted without the need to pass extra parameters. + val isTrivial = auxParams.isEmpty - val newCtx_ = ctxx - // references to lifted objects - .addLocalPaths(nestedClsPaths) - // references to variables defined in the class, including ones in the companion obj - .addLocalPaths(getVars(c).map(s => s -> LocalPath.Sym(s)).toMap) - // reference to unlifted BMS's - .addIgnoredBmsPaths(ctorIgnored.map(d => d.sym -> LocalPath.Sym(d.sym)).toMap) - .addIsymPath(c.isym, LocalPath.Sym(c.isym)) - .inISym(c.isym) + val fun = obj.fun + + val (mainSym, mainDsym) = (fun.sym, fun.dSym) + val auxSym = BlockMemberSymbol(fun.sym.nme + "$", Nil, fun.sym.nameIsMeaningful) + val auxDsym = TermSymbol.fromFunBms(auxSym, fun.owner) + + // Definition with the auxiliary parameters merged into the first parameter list. + private def mkFlattenedDefn: LifterResult[FunDefn] = + val newPlists = fun.params match + case head :: next => head.copy(params = auxParams ::: head.params) :: next + case Nil => PlainParamList(auxParams) :: Nil + val rewriter = new BlockRewriter + val newBod = rewriter.rewrite(fun.body) + val withCapture = addExtraSyms(newBod) + val newDefn = fun.copy(owner = N, sym = mainSym, dSym = mainDsym, params = newPlists, body = withCapture)(fun.forceTailRec) + LifterResult(newDefn, rewriter.extraDefns.toList) + + // Definition with the auxiliary parameters merged into the second parameter list. + private def mkAuxDefn: FunDefn = + val newPList = PlainParamList(dupParams(auxParams)) + val (newPlists, syms, restSym) = fun.params match + case head :: _ => + val duped = dupParamList(head) + ( + newPList :: duped :: Nil, + newPList.params.map(_.sym) ::: duped.params.map(_.sym), + duped.restParam.map(_.sym)) + // we need to append an empty param list so calling this function returns a lambda + case Nil => + ( + newPList :: PlainParamList(Nil) :: Nil, + newPList.params.map(_.sym), + N + ) + val args = restSym match + case Some(value) => + val tail = Arg(S(true), value.asPath) :: Nil + syms.foldLeft(tail): + case (acc, sym) => Arg(N, sym.asPath) :: acc + case None => syms.map(s => Arg(N, s.asPath)) - // add the reference to `this` if it has a companion object - val newCtx = c.companion match - case None => newCtx_ - case Some(value) => newCtx_.addIsymPath(value.isym, LocalPath.Sym(c.sym)).inISym(value.isym) - - // lifts the liftable definitions - // lifted defns can no longer access currently in-scope isyms, so call resetScope - val ctorDefnsLifted = ctorIncluded.flatMap: defn => - val Lifted(liftedDefn, extraDefns) = liftOutDefnCont(c, defn, newCtx.flushModules.resetScope) - liftedDefn :: extraDefns - - // we still need to lift out definitions within unliftable defns - val ctorIgnoredLift = ctorIgnored.map: defn => - liftOutDefnCont(c, defn, newCtx) - - // we still need to rewrite definitions that aren't lifted - // this map tells us how to rewrite the Defns - val ctorIgnoredExtra = ctorIgnoredLift.flatMap(_.extraDefns) - val ctorIgnoredRewrite = ctorIgnoredLift.map: lifted => - lifted.liftedDefn.sym -> lifted.liftedDefn + val call = Call(Value.Ref(fun.sym, S(fun.dSym)), args)(true, true, false) + val bod = Return(call, false) + + FunDefn( + N, + auxSym, + auxDsym, + newPlists, + bod + )(false) + + def rewriteCall(c: Call, args: List[Arg])(using ctx: LifterCtxNew): Call = + if isTrivial then c + else + Call( + Value.Ref(mainSym, S(mainDsym)), + formatArgs ::: args + )( + isMlsFun = true, + mayRaiseEffects = c.mayRaiseEffects, + explicitTailCall = c.explicitTailCall + ) + + def rewriteRef(using ctx: LifterCtxNew): Call = + Call( + Value.Ref(auxSym, S(auxDsym)), + formatArgs + )( + isMlsFun = true, + mayRaiseEffects = false, + explicitTailCall = false + ) + + def rewriteImpl: LifterResult[FunDefn] = + val LifterResult(lifted, extra) = mkFlattenedDefn + if isTrivial then LifterResult(lifted, extra) + else LifterResult(lifted, mkAuxDefn :: extra) + class LiftedClass(override val obj: ScopedObject.Class)(using ctx: LifterCtxNew) + extends LiftedScope[ClsLikeDefn](obj) + with ClsLikeRewrittenScope[ClsLikeDefn](obj.cls.isym): + + private val captureSym = TermSymbol(syntax.ImmutVal, S(obj.cls.isym), Tree.Ident(obj.nme + "$cap")) + override lazy val capturePath: Path = captureSym.asPath + + private val passedSymsMap_ : Map[Local, (vs: VarSymbol, ts: TermSymbol)] = passedSyms.map: s => + s -> + ( + VarSymbol(Tree.Ident(s.nme)), + TermSymbol(syntax.MutVal, S(obj.cls.isym), Tree.Ident(s.nme)) + ) + .toMap + private val capSymsMap_ : Map[ScopedInfo, (vs: VarSymbol, ts: TermSymbol)] = reqCaptures.map: i => + val nme = data.getNode(i).obj.nme + "$cap" + i -> + ( + VarSymbol(Tree.Ident(nme)), + TermSymbol(syntax.ImmutVal, S(obj.cls.isym), Tree.Ident(nme)) + ) .toMap + private val defnSymsMap_ : Map[DefinitionSymbol[?], (vs: VarSymbol, ts: TermSymbol)] = reqDefns.map: i => + i -> + ( + VarSymbol(Tree.Ident(i.nme + "$")), + TermSymbol(syntax.ImmutVal, S(obj.cls.isym), Tree.Ident(i.nme + "$")) + ) + .toMap + + override lazy val capturesOrdered: List[ScopedInfo] = reqCaptures.toList.sortBy(c => capSymsMap_(c).vs.uid) + + override protected val passedSymsMap = passedSymsMap_.view.mapValues(_.ts.asLocalPath).toMap + override protected val capSymsMap = capSymsMap_.view.mapValues(_.ts.asPath).toMap + override protected val passedDefnsMap = defnSymsMap_.view.mapValues(_.ts.asDefnRef).toMap + + val auxParams: List[Param] = + (passedDefnsOrdered.map(x => defnSymsMap_(x).vs) ::: capturesOrdered.map(x => capSymsMap_(x).vs) ::: passedSymsOrdered.map(x => passedSymsMap_(x).vs)) + .map(Param.simple(_)) + + // Whether this can be lifted without the need to pass extra parameters. + val isTrivial = auxParams.isEmpty - val replacedDefnsCtx = newCtx.addreplacedDefns(ctorIgnoredRewrite) - val rewriter = BlockRewriter(newCtx.inScopeISyms, replacedDefnsCtx) - val newPreCtor = removeDefnsFromScope(rewriter.rewrite(preCtor), ctorIncluded) - val newCtor = removeDefnsFromScope(rewriter.rewrite(ctor), ctorIncluded) - val newCCtor = cCtor.map(blk => removeDefnsFromScope(rewriter.rewrite(blk), ctorIncluded)) + val cls = obj.cls - // =========================================================== - // STEP 2: rewrite non-static class methods + val flattenedSym = BlockMemberSymbol(obj.cls.sym.nme + "$", Nil, true) + val flattenedDSym = TermSymbol.fromFunBms(flattenedSym, N) - val fLifted = c.methods.map(liftDefnsInFn(_, newCtx)) - val methods = fLifted.collect: - case Lifted(liftedDefn, extraDefns) => liftedDefn - val fExtra = fLifted.flatMap: - case Lifted(liftedDefn, extraDefns) => extraDefns + def mkFlattenedDefn: Opt[FunDefn] = + if isTrivial || obj.isObj then return N + val auxSyms = auxParams.map(p => VarSymbol(Tree.Ident(p.sym.nme))) + val main = obj.cls.paramsOpt match + case Some(value) => dupParamList(value) + case None => obj.cls.auxParams.headOption match + case Some(value) => dupParamList(value) + case None => PlainParamList(Nil) + val mainSyms = main.params.map(_.sym) + val restSym = main.restParam.map(_.sym) + val argList1_ = (restSym match + case Some(value) => mainSyms.appended(value) + case None => mainSyms + ).map(s => s.asPath.asArg) + val argList2_ = auxSyms.map(s => s.asPath.asArg) - // =========================================================== - // STEP 3: rewrite companion class methods - - val cfLifted = c.companion.fold(Nil)(_.methods.map(liftDefnsInFn(_, newCtx))) + val clsIsParamless = cls.paramsOpt.isEmpty && cls.auxParams.length == 0 - val cMethods = cfLifted.collect: - case Lifted(liftedDefn, extraDefns) => liftedDefn - val cfExtra = cfLifted.flatMap: - case Lifted(liftedDefn, extraDefns) => extraDefns + val argList1 = + if cls.paramsOpt.isEmpty && cls.auxParams.length == 0 then argList2_ + else argList1_ + val argList2 = argList2_ - val newCompanion = c.companion.fold(None): - case value => S(value.copy(methods = cMethods, ctor = newCCtor.get)) - - val extras = (ctorDefnsLifted ++ fExtra ++ cfExtra ++ ctorIgnoredExtra).map: - case f: FunDefn => f.copy(owner = N)(forceTailRec = f.forceTailRec) - case c: ClsLikeDefn => c.copy(owner = N) - case d => d - - def rewriteExtends(p: Path): Path = p match - case RefOfBms(b, _) if !ctx.ignored(b) && ctx.isRelevant(b) => b.asPath - case _ => return p + val isMut = VarSymbol(Tree.Ident("isMut")) + val params = ParamList( + ParamListFlags.empty, + Param.simple(isMut) :: auxSyms.map(Param.simple(_)) ::: main.params, + main.restParam + ) + val tmp = TempSymbol(N) + val ref = Value.Ref(obj.cls.sym, S(obj.cls.isym)) + val instMut = Assign(tmp, Instantiate(true, ref, argList1), End()) + val inst = Assign(tmp, Instantiate(false, ref, argList1), End()) + val ret = + if clsIsParamless then Return(tmp.asPath, false) + else Return(Call(tmp.asPath, argList2)(true, config.checkInstantiateEffect, false), false) + val bod = Scoped(Set(tmp), Match( + isMut.asPath, + Case.Lit(Tree.BoolLit(true)) -> instMut :: Nil, + S(inst), + ret + )) + + S(FunDefn(N, flattenedSym, flattenedDSym, params :: Nil, bod)(false)) - // if this class extends something, rewrite - val newPar = c.parentPath.map(rewriteExtends) - - val newDef = c.copy( - methods = methods, - preCtor = newPreCtor, - parentPath = newPar, - ctor = newCtor, - companion = newCompanion, - ) + def instObject = Instantiate(true, Value.Ref(cls.sym, S(cls.isym)), formatArgs) - Lifted(newDef, extras) - - end liftDefnsInCls - - def liftDefnsInFn(f: FunDefn, ctx: LifterCtx): Lifted[FunDefn] = - val (captureCls, varsMap, varsList) = createCaptureCls(f, ctx) + def rewriteInstantiate(inst: Instantiate, args: List[Arg]): Result = + if obj.isObj then lastWords("tried to rewrite instantiate for an object") + if isTrivial then + if inst.args is args then inst + else inst.copy(args = args) + else + Call( + Value.Ref(flattenedSym, S(flattenedDSym)), + Value.Lit(Tree.BoolLit(inst.mut)).asArg :: formatArgs ::: args + )(true, config.checkInstantiateEffect, false) - val (blk, nested) = f.body.floatOut(ctx) - - val (ignored, included) = nested.partition(d => ctx.ignored(d.sym)) - - val modPaths: Map[Local, LocalPath] = nested.map: - case c: ClsLikeDefn if modOrObj(c) => ctx.modObjLocals.get(c.sym) match - case Some(sym) => S(c.sym -> LocalPath.Sym(sym)) - case _ => S(c.sym -> LocalPath.Sym(c.sym)) - case _ => None - .collect: - case Some(x) => x - .toMap - - val thisVars = ctx.usedLocals(f.sym) - // add the mapping from this function's locals to the capture's symbols and the capture path - val captureSym = TempSymbol(N, "capture") - val captureCtx = ctx - .addLocalCaptureSyms(varsMap) // how to access locals via the capture class from now on - .addCapturePath(f.sym, LocalPath.Sym(captureSym)) // the path to this function's capture - .addLocalPaths((thisVars.vars.toSet -- thisVars.reqCapture).map(s => s -> LocalPath.Sym(s)).toMap) - .addLocalPaths(modPaths) - .addIgnoredBmsPaths(ignored.map(d => d.sym -> LocalPath.Sym(d.sym)).toMap) - val nestedCtx = captureCtx.addFnLocals(captureCtx.usedLocals(f.sym)) - - // lift out the nested defns - // for lifted definitions, any accessible isyms go out of scope, so call resetScope - val nestedLifted = included.map(liftOutDefnCont(f, _, nestedCtx.flushModules.resetScope)) - val ignoredLifted = ignored.map(liftOutDefnCont(f, _, nestedCtx)) - val ignoredExtra = ignoredLifted.flatMap(_.extraDefns) - val newDefns = ignoredExtra ++ nestedLifted.flatMap: - case Lifted(liftedDefn, extraDefns) => liftedDefn :: extraDefns + def rewriteCall(c: Call, args: List[Arg])(using ctx: LifterCtxNew): Call = + if obj.isObj then lastWords("tried to rewrite instantiate for an object") + if isTrivial then + if c.args is args then c + else c.copy(args = args)(c.isMlsFun, c.mayRaiseEffects, c.explicitTailCall) + else + Call( + Value.Ref(flattenedSym, S(flattenedDSym)), + Value.Lit(Tree.BoolLit(false)).asArg :: formatArgs ::: args + )( + isMlsFun = true, + mayRaiseEffects = c.mayRaiseEffects, + explicitTailCall = c.explicitTailCall + ) + + def rewriteImpl: LifterResult[ClsLikeDefn] = + val rewriterCtor = new BlockRewriter + val rewriterPreCtor = new BlockRewriter + val rewrittenCtor = rewriterCtor.rewrite(obj.cls.ctor) + val rewrittenPrector = rewriterPreCtor.rewrite(obj.cls.preCtor) - val ignoredRewrite = ignoredLifted.map: lifted => - lifted.liftedDefn.sym -> lifted.liftedDefn - .toMap - - val transformed = BlockRewriter(ctx.inScopeISyms, captureCtx.addreplacedDefns(ignoredRewrite)).rewrite(blk) - val newScopedBlk = removeDefnsFromScope(transformed, included) - - if thisVars.reqCapture.size == 0 then - Lifted(FunDefn(f.owner, f.sym, f.dSym, f.params, newScopedBlk)(forceTailRec = f.forceTailRec), newDefns) - else - // move the function's parameters to the capture - val paramsSet = f.params.flatMap(_.paramSyms) - val paramsList = varsList.map: s => - (if paramsSet.contains(s) then s.asPath else Value.Lit(Tree.UnitLit(true))).asArg - // moved when the capture is instantiated - val bod = blockBuilder - .assignScoped(captureSym, Instantiate(mut = true, // * Note: `mut` is needed for capture classes - captureCls.sym.asPath, paramsList)) - .rest(newScopedBlk) - val withScope = Scoped(Set(captureSym), bod) - Lifted(FunDefn(f.owner, f.sym, f.dSym, f.params, withScope)(forceTailRec = f.forceTailRec), captureCls :: newDefns) - - end liftDefnsInFn - - // top-level - def transform(_blk: Block) = - // this is already done once in the lowering, but the handler lowering adds lambdas currently - // so we need to desugar them again - val blk = LambdaRewriter.desugar(_blk) - - val analyzer = UsedVarAnalyzer(blk) - val ctx = LifterCtx - .withLocals(analyzer.findUsedLocals) - .withDefns(analyzer.defnsMap) - .withNestedDefns(analyzer.nestedDefns) - .withAccesses(analyzer.accessMap) - .withInScopes(analyzer.inScopeDefns) - .withCompanionMap(analyzer.companionMap) - - val walker1 = new BlockTransformerShallow(SymbolSubst()): - override def applyBlock(b: Block): Block = - b match - case Define(d, rest) => - val LifterMetadata(unliftable, modules, objects, firstClsFns) = createMetadata(d, ctx) - - val modObjLocals = (modules ++ objects).map: c => - analyzer.nestedIn.get(c.sym) match - case Some(bms) => - val nestedIn = analyzer.defnsMap(bms) - nestedIn match - // These will be the names of the objects/modules after being lifted - // We should use the nested object/module's **original name** if nested inside a class, - // so they can be accesed directly by name from the outside. - // For example, if a class C has an object M, (new C).M as a dynamic selection works - case cls: ClsLikeDefn => S(c.sym -> TermSymbol(syntax.ImmutVal, S(cls.isym), Tree.Ident(c.sym.nme))) - case _ => S(c.sym -> VarSymbol(Tree.Ident(c.sym.nme + "$"))) - case _ => N - .collect: - case S(v) => v - .toMap - - val ctxx = ctx - .addIgnored(unliftable) - .withModObjLocals(modObjLocals) - .withFirstClsFns(firstClsFns) - - val Lifted(lifted, extra) = d match - case f: FunDefn => - val ctxxx = ctxx.withDefnsCur(analyzer.nestedDeep(d.sym)) - liftDefnsInFn(f, ctxxx.addBmsReqdInfo(createLiftInfoFn(f, ctxxx))) - case c: ClsLikeDefn => - val ctxxx = ctxx.withDefnsCur(analyzer.nestedDeep(d.sym)) - liftDefnsInCls(c, ctxxx.addBmsReqdInfo(createLiftInfoCls(c, ctxxx))) - case _ => return super.applyBlock(b) - val newDefns = lifted :: extra - val newBms = extra.map(_.sym) - val newBlk = newDefns.foldLeft(applyBlock(rest))((acc, defn) => Define(defn, acc)) - Scoped(newBms.toSet, newBlk) - case _ => super.applyBlock(b) - walker1.applyBlock(blk) + val ctorWithCap = addExtraSyms(rewrittenCtor, captureSym, Nil, false) + + // Assign passed locals and captures + val ctorWithPassed = passedSymsOrdered.foldRight(ctorWithCap): + case (sym, acc) => + val (vs, ts) = passedSymsMap_(sym) + Assign(ts, vs.asPath, acc) + val ctorWithCaps = capturesOrdered.foldRight(ctorWithPassed): + case (sym, acc) => + val (vs, ts) = capSymsMap_(sym) + Assign(ts, vs.asPath, acc) + val ctorWithDefns = passedDefnsOrdered.foldRight(ctorWithCaps): + case (sym, acc) => + val (vs, ts) = defnSymsMap_(sym) + Assign(ts, vs.asPath, acc) + + val newAuxList = + if isTrivial then cls.auxParams + else PlainParamList(auxParams) :: cls.auxParams + + val LifterResult(newMtds, extras) = rewriteMethods(node, obj.cls.methods) + val newCls = obj.cls.copy( + owner = N, + k = syntax.Cls, // turn objects into classes + ctor = ctorWithDefns, + preCtor = rewrittenPrector, + privateFields = captureSym :: liftedObjsSyms.values.toList ::: obj.cls.privateFields, + methods = newMtds, + auxParams = newAuxList + ) + val extrasDefns = rewriterCtor.extraDefns.toList ::: rewriterPreCtor.extraDefns.toList ::: extras + mkFlattenedDefn match + case Some(value) => LifterResult(newCls, value :: extrasDefns) + case None => LifterResult(newCls, extrasDefns) + + private def createRewritten[T](s: TScopeNode[T])(using ctx: LifterCtxNew): RewrittenScope[T] = s.obj match + case _: ScopedObject.Top => lastWords("tried to rewrite the top-level scope") + case o: ScopedObject.Class => + if s.isLifted && !s.isTopLevel then LiftedClass(o) + else RewrittenClass(o) + case o: ScopedObject.Companion => RewrittenCompanion(o) + case o: ScopedObject.ClassCtor => RewrittenClassCtor(o) + case o: ScopedObject.Func => + if s.isLifted && !s.isTopLevel then LiftedFunc(o) + else RewrittenFunc(o) + case o: ScopedObject.Loop => RewrittenLoop(o) + case o: ScopedObject.ScopedBlock => + RewrittenScopedBlock(o) + + // Note: we must write this as a definition here to have tighter types + private def rewriteScope[T <: Defn](l: LiftedScope[T])(using ctx: LifterCtxNew) = + val LifterResult[T](d1, d2) = liftNestedScopes[T](l) + (d1, d2) + + /** + * Lifts scopes nested within `s`, and then rewrites `s`. + * + * @param s The scope to be rewritten. + * @param r The rewritten scope associated with `s`. + * @param ctx The lifter context. + * @return The rewritten scope with the additional definitions. + */ + private def liftNestedScopesImpl[T](scope: RewrittenScope[T])(using ctx: LifterCtxNew): LifterResult[T] = + val node = scope.node + + // Add the symbols map of the current scope + // Note: this will be reset to the original value in liftNestedScopes + ctx.symbolsMap ++= scope.symbolsMap + ctx.capturesMap ++= scope.capturePaths + ctx.defnsMap ++= scope.defnPaths + + val rewrittenScopes = node.children.map(createRewritten) + // The scopes in `lifted` will be rewritten right now + // The scopes in `ignored` will be rewritten in-place when traversing the block + val (lifted, ignored) = rewrittenScopes.partitionMap: + case s: LiftedScope[?] => L(s) + case s => R(s) + for r <- rewrittenScopes do + ctx.rewrittenScopes.put(r.obj.toInfo, r) + for l <- lifted do + ctx.liftedScopes.put(l.obj.sym, l) + + val LifterResult(rewrittenObj, extraDefns) = scope.rewrite + val (res1, res2) = lifted.map(rewriteScope).unzip + val defns = res1 ++ res2.flatten ++ extraDefns + LifterResult(rewrittenObj, defns) + + + def liftNestedScopes[T](r: RewrittenScope[T])(using ctx: LifterCtxNew): LifterResult[T] = + val curSyms = ctx.symbolsMap + val curCaptures = ctx.capturesMap + val curDefns = ctx.defnsMap + val ret = liftNestedScopesImpl(r) + ctx.symbolsMap = curSyms + ctx.capturesMap = curCaptures + ctx.defnsMap = curDefns + ret + + def transform = + given ctx: LifterCtxNew = new LifterCtxNew + val root = data.root + + val children = root.children + children.foreach: c => + ctx.rewrittenScopes.put(c.obj.toInfo, createRewritten(c)) + + val topLevelRewriter = new ScopeRewriter + + val (syms, top) = root.obj.contents match + case Scoped(syms, body) => + (syms.toSet, body) + case b => (Set.empty, b) + + val transformed = topLevelRewriter.applyBlock(top) + val newSyms = syms ++ topLevelRewriter.extraDefns.map(_.sym) + val withDefns = topLevelRewriter.extraDefns.foldLeft(transformed): + case (acc, d) => Define(d, acc) + Scoped(newSyms, withDefns) + + \ No newline at end of file diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 800e696a27..16da7399b8 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -1046,15 +1046,14 @@ class Lowering()(using Config, TL, Raise, State, Ctx): val withHandlers1 = config.effectHandlers.fold(desug): opt => HandlerLowering(handlerPaths, opt).translateHandleBlocks(desug) - // TODO: Refactor the lifter so it does not require flattened scopes - val shouldFlattenScopes = config.effectHandlers.isDefined || config.liftDefns.isDefined + val shouldFlattenScopes = config.effectHandlers.isDefined val scopeFlattened = if shouldFlattenScopes then ScopeFlattener().applyBlock(withHandlers1) else withHandlers1 val lifted = - if lift then Lifter().transform(scopeFlattened) + if lift then Lifter(scopeFlattened, handlerPaths).transform else scopeFlattened val (withHandlers2, stackSafetyInfo) = config.effectHandlers.fold((lifted, Map.empty)): opt => diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/ScopeData.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/ScopeData.scala new file mode 100644 index 0000000000..398ae8004f --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/ScopeData.scala @@ -0,0 +1,399 @@ +package hkmc2 + +import mlscript.utils.*, shorthands.* +import utils.* + +import hkmc2.codegen.* +import hkmc2.semantics.* +import hkmc2.ScopeData.* +import hkmc2.semantics.Elaborator.State + +import hkmc2.syntax.Tree +import hkmc2.codegen.llir.FreshInt +import java.util.IdentityHashMap +import scala.collection.mutable.Map as MutMap +import scala.collection.mutable.Set as MutSet + +object ScopeData: + opaque type ScopeUID = BigInt + val dummyUID: ScopeUID = 0 + class FreshUID: + private val underlying = FreshInt() + def make: ScopeUID = underlying.make + + type ScopedInfo = DefinitionSymbol[?] | LabelSymbol | ScopeUID | Unit + + // ScopeData requires the set of ignored scopes to compute certain things, but + // the lifter requires the scope tree to generate the metadata. To solve this, + // we generate the scope tree then populate the metadata later. + case class IgnoredScopes(var ignored: Opt[Set[ScopedInfo]]) + + type ScopedObject = ScopedObject.ScopedObject[?] + type TScopedObject[T] = ScopedObject.ScopedObject[T] + + type LiftedSym = DefinitionSymbol[?] + + extension (d: DefinitionSymbol[?]) + def asBmsRef = Value.Ref(d.asBlkMember.get, S(d)) + + // These cannot be hashed + object ScopedObject: + // T: The actual contents of the scoped object + sealed abstract class ScopedObject[T]: + var node: Opt[TScopeNode[T]] = N + lazy val toInfo: ScopedInfo = this match + case Top(_) => () + case Class(cls, _) => cls.isym + case Companion(comp, par) => comp.isym + case ClassCtor(cls) => cls.ctorSym.get + case Func(fun, _) => fun.dSym + case ScopedBlock(uid, block) => uid + case Loop(sym, _) => sym + + // note: not unique + lazy val nme = this match + case Top(b) => "top" + case Class(cls, _) => cls.isym.nme + case Companion(comp, par) => comp.isym.nme + "_mod" + case ClassCtor(cls) => cls.isym.nme // should be unused + case Func(fun, isMethod) => fun.dSym.nme + case Loop(sym, block) => "loop$" + sym.uid.toString() + case ScopedBlock(uid, block) => "scope" + uid + + // Locals defined by a scoped object. + lazy val definedLocals: Set[Local] = this match + // we want definedLocals for the top level scope to be empty, because otherwise, + // the lifter may try to capture those locals. + case Top(b) => Set.empty + case Class(cls, _) => + // Public fields are not included, as they are accessed using + // a field selection rather than directly using the BlockMemberSymbol. + val paramsSet: Set[Local] = cls.paramsOpt match + case Some(value) => value.params.map(_.sym).toSet + case None => Set.empty + val auxSet: Set[Local] = cls.auxParams.flatMap: p => + p.params.map(_.sym) + .toSet + paramsSet ++ auxSet ++ cls.privateFields + cls.isym + case Companion(comp, par) => + comp.privateFields.toSet + comp.isym + case _: ClassCtor => Set.empty + case Func(fun, _) => fun.params.flatMap: p => + p.params.map(_.sym) + .toSet + case ScopedBlock(_, block) => block.syms.toSet + case _: Loop => Set.empty + + def contents: T = this match + case Top(b) => b + case Class(cls, _) => cls + case Companion(comp, par) => comp + case ClassCtor(cls) => () + case Func(fun, _) => fun + case ScopedBlock(_, block) => block + case Loop(_, blk) => blk + + // Scoped nodes which may be referenced using a symbol. + sealed abstract class Referencable[T] extends TScopedObject[T]: + def sym: LiftedSym = this match + case Class(cls, _) => cls.isym + case Companion(comp, par) => comp.isym + case Func(fun, isMethod) => fun.dSym + case ClassCtor(cls) => cls.ctorSym.get + def bsym: BlockMemberSymbol = this match + case Class(cls, _) => cls.sym + case Companion(comp, par) => par.sym + case Func(fun, isMethod) => fun.sym + case ClassCtor(cls) => cls.sym + def owner: Opt[InnerSymbol] = this match + case Class(cls, _) => cls.owner + case Companion(comp, par) => par.owner + case ClassCtor(cls) => cls.owner + case Func(fun, isMethod) => fun.owner + + + // Scoped nodes which could possibly be lifted to the top level. + sealed abstract class Liftable[T <: Defn] extends Referencable[T]: + val defn: T + + // The top-level scope. + case class Top(b: Block) extends ScopedObject[Block] // b may be a scoped block, in which case, its variables represent the top-level variables. + case class Class(cls: ClsLikeDefn, isObj: Bool) extends Liftable[ClsLikeDefn]: + val defn = cls + case class Companion(comp: ClsLikeBody, par: ClsLikeDefn) extends Referencable[ClsLikeBody] + // We model it like this: the ctor is just another function in the same scope as the class and initializes the corresponding class + case class ClassCtor(cls: ClsLikeDefn) extends Referencable[Unit] + // isMethod: + // N = not a method + // S(false) = module method + // S(true) = class or object method + case class Func(fun: FunDefn, isMethod: Opt[Bool]) extends Liftable[FunDefn]: + val defn = fun + // The purpose of `Loop` is to enforce the rule that the control flow remains linear when we enter + // a scoped block. + case class Loop(sym: LabelSymbol, body: Block) extends ScopedObject[Block] + case class ScopedBlock(uid: ScopeUID, block: Scoped) extends ScopedObject[Block] + + extension (traverser: BlockTraverser) + def applyScopedObject(obj: ScopedObject) = + extension (s: Symbol) def traverse = + traverser.applySymbol(s) + obj match + case ScopedObject.Top(b) => traverser.applyBlock(b) + case ScopedObject.Class(ClsLikeDefn(own, isym, sym, ctorSym, k, paramsOpt, auxParams, parentPath, methods, + privateFields, publicFields, preCtor, ctor, mod, bufferable), _) + => + // do not traverse the companion -- it is a separate kind of scoped object + // and will therefore be traversed separately + own.foreach(_.traverse) + isym.traverse + sym.traverse + ctorSym.foreach(_.traverse) + paramsOpt.foreach(traverser.applyParamList) + auxParams.foreach(traverser.applyParamList) + parentPath.foreach(traverser.applyPath) + methods.foreach(traverser.applyFunDefn) + privateFields.foreach(_.traverse) + publicFields.foreach: f => + f._1.traverse; f._2.traverse + traverser.applySubBlock(preCtor) + traverser.applySubBlock(ctor) + case ScopedObject.Companion(comp, par) => traverser.applyClsLikeBody(comp) + case ScopedObject.Func(fun, isMethod) => traverser.applyFunDefn(fun) + case ScopedObject.ScopedBlock(uid, block) => traverser.applyBlock(block) + case ScopedObject.ClassCtor(c) => () + case ScopedObject.Loop(_, b) => traverser.applyBlock(b) + + // A simple tree data structure representing the nesting relation of definitions and scopes. + class NestedScopeTree(val root: TScopeNode[Block]): + val nodesMap: Map[ScopedInfo, ScopeNode] = root.allChildNodes.map(n => n.obj.toInfo -> n).toMap + + type ScopeNode = ScopeNode.ScopeNode[?] + type TScopeNode[T] = ScopeNode.ScopeNode[T] + object ScopeNode: + case class ScopeNode[T](obj: TScopedObject[T], var parent: Opt[ScopeNode[?]], children: List[ScopeNode[?]])(using ignoredScopes: IgnoredScopes): + + lazy val allParents: List[ScopeNode[?]] = parent match + case Some(value) => this :: value.allParents + case None => this :: Nil + + lazy val parentsSet = allParents.map(_.obj.toInfo).toSet + + def inSubtree(root: ScopedInfo) = parentsSet.contains(root) + + // note: includes itself + lazy val allChildNodes: List[ScopeNode[?]] = this :: children.flatMap(_.allChildNodes) + lazy val allChildren: List[ScopedObject] = allChildNodes.map(_.obj) + + // does not include variables introduced by itself + lazy val existingVars: Set[Local] = parent match + case Some(value) => value.existingVars ++ value.obj.definedLocals ++ value.nestedObjSyms + case None => Set.empty + + lazy val isTopLevel: Bool = parent match + case Some(ScopeNode(obj = _: ScopedObject.Top)) => true + case _ => false + + lazy val isModOrTopLevel: Bool = parent match + case Some(par) => par.obj match + case _: ScopedObject.Companion => true + case _ => false + case None => true + + // Scoped blocks include the BlockMemberSymbols of their nested definitions. This removes them. + lazy val localsWithoutBms: Set[Local] = obj match + case s: ScopedObject.ScopedBlock => + val rmv = children.collect: + case c @ ScopeNode(obj = s: ScopedObject.Referencable[?]) => s.bsym + obj.definedLocals -- rmv + case _ => obj.definedLocals + + lazy val nestedObjSyms: Set[InnerSymbol] = children.collect: + case ScopeNode(obj = c: ScopedObject.Class) if c.isObj => c.cls.isym + .toSet + + lazy val liftedObjSyms: Set[InnerSymbol] = children.collect: + case n @ ScopeNode(obj = c: ScopedObject.Class) if c.isObj && n.isLifted => c.cls.isym + .toSet + + lazy val inScopeISyms: Set[Local] = + val parVals = parent match + case Some(value) => value.inScopeISyms + case None => Set.empty + + obj match + case c: ScopedObject.Class => parVals + c.cls.isym + case c: ScopedObject.Companion => parVals + c.comp.isym + case _ => parVals + + // All of the following must not be called until ignoredScopes is populated with the relevant data. + + lazy val isLifted: Bool = + def impl: Bool = + val ignored = ignoredScopes.ignored match + case Some(value) => value + case None => lastWords("isLifted accessed before the set of ignored scopes was set") + + // check if the owner is a module + obj match + case r: ScopedObject.Referencable[?] => r.owner match + case Some(value) => if value.asMod.isDefined then return false + case _ => () + case _ => () + + parent.map(_.obj) match + case Some(_: ScopedObject.Companion) => false // there is no need to lift objects nested inside a module + case _ => + obj match + // case _: ScopedObject.Companion => false + // case c: ScopedObject.Class if c.cls.companion.isDefined => false + case _ if ignored.contains(obj.toInfo) => false + case _ if isModOrTopLevel => false + case ScopedObject.Func(isMethod = S(true)) => false + case _: ScopedObject.Loop | _: ScopedObject.ClassCtor | _: ScopedObject.ScopedBlock | _: ScopedObject.Companion => false + case _ => true + impl + + lazy val liftedChildNodes: List[ScopeNode[?]] = + if isLifted then this :: Nil + else children.flatMap(_.liftedChildNodes) + + // Finds the first parent that is a lifted object, i.e. a non-ignored definition, or the top level + lazy val firstLiftedParent: ScopedObject = + if !isLifted then + parent match + case Some(value) => value.firstLiftedParent + case None => obj // unreachable + else obj + + // When a node is lifted, some neighbouring ignored definitions may become out of scope. This computes + // the list of these definitions, and they could be passed to this node as a parameter once lifted. + private lazy val reqCaptureObjsImpl: List[ScopedObject.Referencable[?]] = obj match + case _: ScopedObject.Top => List.empty + case _ => + // All unlifted neighbour nodes ::: parent's reqCaptureObjsImpl + val initial = parent.get.children.collect: + case c @ ScopeNode(obj = t: ScopedObject.Referencable[?]) if !c.isLifted => t + initial ::: parent.get.reqCaptureObjsImpl + + lazy val reqCaptureObjs: List[ScopedObject.Referencable[?]] = obj match + case _: ScopedObject.Top => List.empty + case _ => + if isLifted then reqCaptureObjsImpl + else parent.get.reqCaptureObjsImpl + + + def dSymUnapply(data: ScopeData, v: DefinitionSymbol[?] | Option[DefinitionSymbol[?]]) = v match + case Some(d) if data.contains(d) => S(d) + case d: DefinitionSymbol[?] if data.contains(d) => S(d) + case _ => None + +class ScopeData(b: Block)(using State, IgnoredScopes): + import ScopeData.* + + private val fresh = FreshUID() + + val scopeTree = NestedScopeTree(makeScopeTreeRec(ScopedObject.Top(b))) + val root = scopeTree.root + val allBms = root.allChildren.collect: + case s: ScopedObject.Referencable[?] => s.bsym + + def contains(s: ScopedInfo) = scopeTree.nodesMap.contains(s) + + private val scopedMap: IdentityHashMap[Scoped, ScopeUID] = new IdentityHashMap + for + case ScopeNode(obj = ScopedObject.ScopedBlock(uid, blk)) <- scopeTree.root.allChildNodes + do + scopedMap.put(blk, uid) + def getNode(x: ScopedInfo): ScopeNode = scopeTree.nodesMap(x) + def getNode(defn: ClsLikeDefn): ScopeNode = getNode(defn.isym) + def getNode(companion: ClsLikeBody): ScopeNode = getNode(companion.isym) + def getNode(defn: FunDefn): ScopeNode = getNode(defn.dSym) + def getUID(blk: Scoped): ScopeUID = + if scopedMap.containsKey(blk) then scopedMap.get(blk) + else lastWords("getUID: key not found") + def getNode(blk: Scoped): ScopeNode = getNode(getUID(blk)) + // From the input block or definition, traverses until a function, class or new scoped block is found and appends them. + class ScopeFinder extends BlockTraverserShallow: + var objs: List[ScopedObject] = Nil + override def applyBlock(b: Block): Unit = b match + case s: Scoped => + val id = fresh.make + objs ::= ScopedObject.ScopedBlock(id, s) + case l: Label if l.loop => + objs ::= ScopedObject.Loop(l.label, l.body) + case _ => super.applyBlock(b) + override def applyFunDefn(fun: FunDefn): Unit = + objs ::= ScopedObject.Func(fun, N) + override def applyDefn(defn: Defn): Unit = defn match + case f: FunDefn => applyFunDefn(f) + case c: ClsLikeDefn => + objs ::= ScopedObject.Class(c, c.k === syntax.Obj) + c.ctorSym match + case Some(value) => objs ::= ScopedObject.ClassCtor(c) + case None => () + c.companion.map: comp => + objs ::= ScopedObject.Companion(comp, c) + + case _ => super.applyDefn(defn) + + + def scopeFinder = new ScopeFinder() + + def makeScopeTreeRec[T](obj: TScopedObject[T]): TScopeNode[T] = + // An annoying thing with class ctors: + // Sometimes, nested classes/objects appear inside the top-level scoped block of class/module ctor, + // despite being a child of the class, but we want these to be a direct child of the class/module. + val finder = scopeFinder + val ctorFinder = scopeFinder + val ctorScoped = obj match + case ScopedObject.Class(cls, _) => cls.ctor match + case s: Scoped => S((s, cls.isym)) + case _ => N + case ScopedObject.Companion(comp, _) => comp.ctor match + case s: Scoped => S((s, comp.isym)) + case _ => N + case _ => N + obj match + case ScopedObject.Top(s: Scoped) => finder.applyBlock(s.body) + case ScopedObject.Top(b) => finder.applyBlock(b) + case ScopedObject.Class(cls, _) => + finder.applyBlock(cls.preCtor) + ctorScoped match + case Some((value, _)) => ctorFinder.applyBlock(value.body) + case None => finder.applyBlock(cls.ctor) + case ScopedObject.Companion(comp, par) => + ctorScoped match + case Some((value, _)) => ctorFinder.applyBlock(value.body) + case None => finder.applyBlock(comp.ctor) + case ScopedObject.Func(fun, _) => + finder.applyBlock(fun.body) + case ScopedObject.ScopedBlock(_, block) => + finder.applyBlock(block.body) + case ScopedObject.ClassCtor(c) => () + case ScopedObject.Loop(_, b) => finder.applyBlock(b) + val mtdObjs = obj match + case ScopedObject.Class(cls, _) => cls.methods.map(ScopedObject.Func(_, S(true))) + case ScopedObject.Companion(comp, par) => comp.methods.map(ScopedObject.Func(_, S(false))) + case _ => Nil + + // This extracts owned definitions from the ctor and makes them a descendant of the class scope node, + // while making the other scoped objects in the ctor a descendant of the ctor's scoped block node. + val (ctorNode, ctorObjs) = ctorScoped match + case Some((ctor, isym)) => + val ctorBlkObj = ScopedObject.ScopedBlock(fresh.make, ctor) + val (ctorObjs, ctorBlkChildren) = ctorFinder.objs.partitionMap: + case a: ScopedObject.Referencable[?] if a.owner.isDefined && a.owner.get === isym => L(a) + case a => R(a) + val children = ctorBlkChildren.map(makeScopeTreeRec) + val ctorNde = ScopeNode.ScopeNode(ctorBlkObj, N, children) + ctorBlkObj.node = S(ctorNde) + for c <- children do c.parent = S(ctorNde) + (S(ctorNde), ctorObjs) + case None => (N, Nil) + + val children = (ctorObjs ::: mtdObjs ::: finder.objs).map(makeScopeTreeRec).prependedAll(ctorNode) + val retNode = ScopeNode.ScopeNode(obj, N, children) + obj.node = S(retNode) + for c <- children do c.parent = S(retNode) + retNode diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala index 4b5c976354..ded3a92436 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala @@ -9,9 +9,29 @@ import hkmc2.codegen.* import hkmc2.semantics.* import hkmc2.Message.* import hkmc2.semantics.Elaborator.State +import hkmc2.ScopeData.* +import hkmc2.Lifter.* import scala.collection.mutable.Map as MutMap - +import scala.collection.mutable.Set as MutSet +import scala.jdk.CollectionConverters.* +import java.util.IdentityHashMap +import java.util.Collections +import scala.collection.mutable.Buffer + +object UsedVarAnalyzer: + case class MutAccessInfo( + accessed: MutSet[Local], + mutated: MutSet[Local], + refdDefns: MutSet[ScopedInfo] + ): + def toIMut = AccessInfo(accessed.toSet, mutated.toSet, refdDefns.toSet) + object MutAccessInfo: + def empty = MutAccessInfo( + MutSet.empty, + MutSet.empty, + MutSet.empty + ) /** * Analyzes which variables have been used and mutated by which functions. * Also finds which variables can be passed to a capture class without a heap @@ -19,280 +39,325 @@ import scala.collection.mutable.Map as MutMap * * Assumes the input trees have no lambdas. */ -class UsedVarAnalyzer(b: Block)(using State): - import Lifter.* - - private case class DefnMetadata( - definedLocals: Map[BlockMemberSymbol, Set[Local]], // locals defined explicitly by that function - defnsMap: Map[BlockMemberSymbol, Defn], // map bms to defn - existingVars: Map[BlockMemberSymbol, Set[Local]], // variables already existing when that defn is defined - inScopeDefns: Map[BlockMemberSymbol, Set[BlockMemberSymbol]], // definitions that are in scope and not nested within this defn, and not including itself - nestedDefns: Map[BlockMemberSymbol, List[Defn]], // definitions that are a successor of the current defn - nestedDeep: Map[BlockMemberSymbol, Set[BlockMemberSymbol]], // definitions nested within another defn, including that defn (deep) - nestedIn: Map[BlockMemberSymbol, BlockMemberSymbol], // the definition that a definition is directly nested in - companionMap: Map[InnerSymbol, InnerSymbol], // a (bijective) map between companion object symbols and class symbols - ) - private def createMetadata: DefnMetadata = - var defnsMap: Map[BlockMemberSymbol, Defn] = Map.empty - var definedLocals: Map[BlockMemberSymbol, Set[Local]] = Map.empty - var existingVars: Map[BlockMemberSymbol, Set[Local]] = Map.empty - var inScopeDefns: Map[BlockMemberSymbol, Set[BlockMemberSymbol]] = Map.empty - var nestedDefns: Map[BlockMemberSymbol, List[Defn]] = Map.empty - var nestedDeep: Map[BlockMemberSymbol, Set[BlockMemberSymbol]] = Map.empty - var nestedIn: Map[BlockMemberSymbol, BlockMemberSymbol] = Map.empty - var companionMap: Map[InnerSymbol, InnerSymbol] = Map.empty - - def createMetadataFn(f: FunDefn, existing: Set[Local], inScope: Set[BlockMemberSymbol]): Unit = - var nested: Set[BlockMemberSymbol] = Set.empty - - existingVars += (f.sym -> existing) - val thisVars = Lifter.getVars(f) -- existing - val newExisting = existing ++ thisVars - - val thisScopeDefns: List[Defn] = f.body.gatherDefns() - - nestedDefns += f.sym -> thisScopeDefns - - val newInScope = inScope ++ thisScopeDefns.map(_.sym) - for s <- thisScopeDefns do - inScopeDefns += s.sym -> (newInScope - s.sym) - nested += s.sym - - defnsMap += (f.sym -> f) - definedLocals += (f.sym -> thisVars) - - for d <- thisScopeDefns do - nestedIn += (d.sym -> f.sym) - createMetadataDefn(d, newExisting, newInScope) - nested ++= nestedDeep(d.sym) - - nestedDeep += f.sym -> nested - - def createMetadataDefn(d: Defn, existing: Set[Local], inScope: Set[BlockMemberSymbol]): Unit = - d match - case f: FunDefn => - createMetadataFn(f, existing, inScope) - case c: ClsLikeDefn => - createMetadataCls(c, existing, inScope) - case d => Map.empty - - def createMetadataCls(c: ClsLikeDefn, existing: Set[Local], inScope: Set[BlockMemberSymbol]): Unit = - var nested: Set[BlockMemberSymbol] = Set.empty - - existingVars += (c.sym -> existing) - val thisVars = Lifter.getVars(c) -- existing - val newExisting = existing ++ thisVars - - val thisScopeDefns: List[Defn] = c.methods ++ c.preCtor.gatherDefns() - ++ c.ctor.gatherDefns() ++ c.companion.fold(Nil)(comp => comp.ctor.gatherDefns() ++ comp.methods) - - nestedDefns += c.sym -> thisScopeDefns - - val newInScope = inScope ++ thisScopeDefns.map(_.sym) - for s <- thisScopeDefns do - inScopeDefns += s.sym -> (newInScope - s.sym) - nested += s.sym - - defnsMap += (c.sym -> c) - definedLocals += (c.sym -> thisVars) - - for d <- thisScopeDefns do - nestedIn += (d.sym -> c.sym) - createMetadataDefn(d, newExisting, newInScope) - nested ++= nestedDeep(d.sym) - - nestedDeep += c.sym -> nested - - c.companion match - case None => - case Some(value) => companionMap += (value.isym -> c.isym) - +class UsedVarAnalyzer(b: Block, scopeData: ScopeData)(using State, IgnoredScopes): + import UsedVarAnalyzer.* + + object SDSym: + def unapply(v: DefinitionSymbol[?] | Option[DefinitionSymbol[?]]) = dSymUnapply(scopeData, v) + + // Finds the locals that this block accesses/mutates, and the definitions which it could use. + private def blkAccessesShallow(b: Block): AccessInfo = + var accessed: MutAccessInfo = MutAccessInfo.empty new BlockTraverserShallow: - // If there's any variables available at the top-level we need to explicitly ignore, - // then we add them here - val ignoredVars = b.definedVars applyBlock(b) - override def applyDefn(defn: Defn): Unit = - inScopeDefns += defn.sym -> Set.empty - createMetadataDefn(defn, ignoredVars, Set.empty) - DefnMetadata(definedLocals, defnsMap, existingVars, inScopeDefns, nestedDefns, nestedDeep, nestedIn, companionMap) - - val DefnMetadata(definedLocals, defnsMap, existingVars, - inScopeDefns, nestedDefns, nestedDeep, nestedIn, companionMap) = createMetadata - - private val blkMutCache: MutMap[Local, AccessInfo] = MutMap.empty - private def blkAccessesShallow(b: Block, cacheId: Opt[Local] = N): AccessInfo = - cacheId.flatMap(blkMutCache.get) match - case Some(value) => value - case None => - var accessed: AccessInfo = AccessInfo.empty - new BlockTraverserShallow: - applyBlock(b) - - override def applyBlock(b: Block): Unit = b match - case Assign(lhs, rhs, rest) => - accessed = accessed.addMutated(lhs) - applyResult(rhs) - applyBlock(rest) - case Label(label, loop, body, rest) => - accessed ++= blkAccessesShallow(body, S(label)) - applyBlock(rest) - case _ => super.applyBlock(b) + + override def applyBlock(b: Block): Unit = b match + case s: Scoped => + accessed.refdDefns.add(scopeData.getUID(s)) + case Assign(lhs, rhs, rest) => + accessed.accessed.add(lhs) + accessed.mutated.add(lhs) + applyResult(rhs) + applyBlock(rest) + case l: Label if l.loop => + accessed.refdDefns.add(l.label) + case d: Define => d.defn match + case v: ValDefn => + applyDefn(v) + applySubBlock(d.rest) + case c @ ClsLikeDefn(k = syntax.Obj) => + accessed.refdDefns.add(c.isym) + case _ => applySubBlock(d.rest) - override def applyValue(v: Value): Unit = v match - case Value.Ref(_: BuiltinSymbol, _) => super.applyValue(v) - case RefOfBms(l, _) => - accessed = accessed.addRefdDefn(l) - case Value.Ref(l, _) => - accessed = accessed.addAccess(l) - case _ => super.applyValue(v) - - cacheId match - case None => () - case Some(value) => blkMutCache.addOne(value -> accessed) + case _ => super.applyBlock(b) - accessed - - private val accessedCache: MutMap[BlockMemberSymbol, AccessInfo] = MutMap.empty - + override def applyPath(p: Path): Unit = p match + case Value.Ref(_: BuiltinSymbol, _) => super.applyPath(p) + case RefOfBms(_, SDSym(dSym)) => + // Check if it's referencing a class method. + // If so, then it requires reading the class symbol + val node = scopeData.getNode(dSym) + node.obj match + // for definitions nested inside a class: they need the InnerSymbol of the class instance + case f @ ScopedObject.Func(isMethod = S(true)) => accessed.accessed.add(f.fun.owner.get) + // definitions that access a module's method directly need an edge to that method + case ScopedObject.Func(isMethod = N | S(false)) => + accessed.refdDefns.add(node.obj.toInfo) + case c: ScopedObject.Class if c.isObj => accessed.accessed.add(c.cls.isym) + case _: ScopedObject.Class | _: ScopedObject.ClassCtor | _: ScopedObject.Companion => accessed.refdDefns.add(node.obj.toInfo) + case _ => () + + case Value.Ref(l, _) => + accessed.accessed.add(l) + case _ => super.applyPath(p) + accessed.toIMut + /** - * Finds the variables which this definition could possibly mutate, excluding mutations through - * calls to other functions and, in the case of functions, mutations of its own variables. - * - * @param defn The definition to search through. + * Finds the variables belonging to a parent scope which this scoped object could possibly + * access or mutate, excluding mutations through calls to other functions and mutations + * of their own variables. Also finds the other scoped objects that this definition may enter. + * + * @param obj The scoped object to search through. * @return The variables which this definition could possibly mutate. */ - private def findAccessesShallow(defn: Defn): AccessInfo = - def create = defn match - case f: FunDefn => - val fVars = definedLocals(f.sym) - blkAccessesShallow(f.body).withoutLocals(fVars) - case c: ClsLikeDefn => - val methodSyms = c.methods.map(_.sym).toSet - c.methods.foldLeft(blkAccessesShallow(c.preCtor) ++ blkAccessesShallow(c.ctor)): - case (acc, fDefn) => - // class methods do not need to be lifted, so we don't count calls to their methods. - // a previous reference to this class's block member symbol is enough to assume any - // of the class's methods could be called. - // - // however, we must keep references to the class itself! - val defnAccess = findAccessesShallow(fDefn) - acc ++ defnAccess.withoutBms(methodSyms) - case _: ValDefn => AccessInfo.empty - - accessedCache.getOrElseUpdate(defn.sym, create) - - // MUST be called from a top-level defn - private def findAccesses(d: Defn): Map[BlockMemberSymbol, AccessInfo] = - var defns: mutable.Buffer[Defn] = mutable.Buffer.empty - var definedVarsDeep: Set[Local] = Set.empty + private def findAccessesShallow(obj: ScopedObject): AccessInfo = + val accessed = obj match + case ScopedObject.Top(b) => b match + case s: Scoped => blkAccessesShallow(s.body) + case _ => blkAccessesShallow(b) + case ScopedObject.Func(f, _) => + blkAccessesShallow(f.body) + case ScopedObject.Class(c, _) => + // We must assume that classes may access all their methods. + // When the class symbol is referenced once, that symbol may be used in + // arbitrary ways, which includes calling any of this class's methods. + val res = blkAccessesShallow(c.preCtor) ++ blkAccessesShallow(c.ctor) + res.copy(refdDefns = res.refdDefns ++ c.methods.map(_.dSym)) + case ScopedObject.ClassCtor(cls) => + // Recall that we interpret the ctor as just another function in the same scope + // as the corresponding class, and initializes the class. + AccessInfo.empty.addRefdScopedObj(scopeData.getNode(cls).obj.toInfo) + case ScopedObject.ScopedBlock(uid, b) => blkAccessesShallow(b.body) + case ScopedObject.Companion(c, _) => + // There likely won't be nested companion classes in the future, but for now, + // just assume they may access all their methods + val res = blkAccessesShallow(c.ctor) + res.copy(refdDefns = res.refdDefns ++ c.methods.map(_.dSym)) + case ScopedObject.Loop(_, b) => blkAccessesShallow(b) + // Variables introduced by this scoped object do not belong to a parent scope, so + // we remove them + accessed.withoutLocals(obj.definedLocals) + + private def combineInfos(m1: Map[ScopedInfo, AccessInfo], m2: Map[ScopedInfo, AccessInfo]): Map[ScopedInfo, AccessInfo] = + if m2.size < m1.size then combineInfos(m2, m1) + else m1.foldLeft(m2): + case (acc, info -> accesses) => m2.get(info) match + case Some(value) => acc + (info -> (accesses ++ value)) + case None => acc + (info -> accesses) + + val shallowAccesses: Map[ScopedInfo, AccessInfo] = + scopeData.scopeTree.root.allChildren.map(obj => obj.toInfo -> findAccessesShallow(obj)).toMap + + // Optimization: Find all nodes which are accessed by their children + // See the comment for findAccesses + private val allEdges = + for + (src, accesses) <- shallowAccesses + refd <- accesses.refdDefns + if src =/= refd + yield + (src, refd) + private val accessedByChild = allEdges + .groupBy(_._2) // group by edge destination + .map: + case (_: Unit) -> _ => () -> false + case d -> edges => + val par = scopeData.getNode(d).parent.get.obj.toInfo + d -> edges.exists: + case a -> b => a =/= par + .collect: + case d -> true => d + .toSet - new BlockTraverser: - applyDefn(d) - - override def applyFunDefn(f: FunDefn): Unit = - defns += f - definedVarsDeep ++= definedLocals(f.sym) - super.applyFunDefn(f) - - override def applyDefn(defn: Defn): Unit = - defn match - case c: ClsLikeDefn => - defns += c - definedVarsDeep ++= definedLocals(c.sym) - case _ => - super.applyDefn(defn) + // Find: + // - Map 1: + // - Variables that each scoped object has accessed, either through itself or a nested scoped object. + // - Variables that each scoped object has mutated, either through itself or a nested scoped object. + // - Scoped objects that each object accesses, either through itself or a nested scoped object. + // - Map 2: + // - Variables that each scoped object has accessed, either through itself or a *lifted* scoped object. + // - Variables that each scoped object has mutated, either through itself or a lifted nested scoped object. + // - Scoped objects that each object accesses, either through itself or a lifted nested scoped object. + // + // The former includes ignored objects, and is used to do the readers/writers analysis. The latter is used to determine + // whether we actually need to allocate a capture for the object. In particular, we never need to allocate a capture + // for a variable if only nested scopes mutate it. + // + // Note that it is possible for a lifted scoped object to be reached by traversing through an ignored object. + // + // Also observe that if a node is not accesed from any of its children, then we can re-use the result of its parent's analysis. + private def findAccesses(s: ScopeNode): (Map[ScopedInfo, AccessInfo], Map[ScopedInfo, AccessInfo]) = + // Note: these include `s` + val children = s.allChildren + val childInfo = children.map(_.toInfo).toSet + + // Traverses the node's children, and stops when a child that is accessed by one of its children is found. + // The analysis will be performed on *all* of the traversed nodes simultaneously. + // We will later recurse on the children of all these nodes. + val nexts: Buffer[ScopeNode] = Buffer.empty + def findNodes(s: ScopeNode): List[ScopeNode] = s :: s.children.flatMap: child => + if accessedByChild(child.obj.toInfo) then + nexts.addOne(child) + List.empty + else findNodes(child) + val nodes = findNodes(s) + + val allLocals = nodes.flatMap(node => node.obj.definedLocals).toSet - val defnSyms = defns.iterator.map(_.sym).toSet - val accessInfo = defns.map: d => - val AccessInfo(accessed, mutated, refdDefns) = findAccessesShallow(d) - d.sym -> AccessInfo( - accessed.intersect(definedVarsDeep), - mutated.intersect(definedVarsDeep), - refdDefns.intersect(defnSyms) // only care about definitions nested in this top-level definition + val accessInfo = children.map: obj => + val a @ AccessInfo(accessed, mutated, refdDefns) = shallowAccesses(obj.toInfo) + obj.toInfo -> AccessInfo( + accessed = accessed.intersect(allLocals), + mutated = mutated.intersect(allLocals), + refdDefns = refdDefns.intersect(childInfo) ) val accessInfoMap = accessInfo.toMap - - val edges = + val edges: Set[(ScopedInfo, ScopedInfo)] = for - (sym, AccessInfo(_, _, refd)) <- accessInfo + (src, AccessInfo(_, _, refd)) <- accessInfo r <- refd - if defnSyms.contains(r) - yield sym -> r + // remove self-edges: they do not affect this analysis + if src =/= r + // very important: we only care about edges that flow into the subtree rooted at `s` + if childInfo.contains(r) && r =/= s.obj.toInfo + yield src -> r .toSet // (sccs, sccEdges) forms a directed acyclic graph (DAG) - val algorithms.SccsInfo(sccs, sccEdges, inDegs, outDegs) = algorithms.sccsWithInfo(edges, defnSyms) + val algorithms.SccsInfo(sccs, sccEdges, inDegs, outDegs) = algorithms.sccsWithInfo(edges, childInfo) + + val rootInfo = s.obj.toInfo + val (rootId, rootElems) = sccs.find: + case (id, elems) => elems.contains(rootInfo) + .get + if rootElems.size != 1 then lastWords("SCC containing root had a degree other than 1.") - // all defns in the same scc must have at least the same accesses as each other - val base = for (id, scc) <- sccs yield id -> - scc.foldLeft(AccessInfo.empty): - case (acc, sym) => acc ++ accessInfoMap(sym) + // With respect to the current scoped object `s`, we may "ignore" one of its children `c` if and only if + // it is ignored (not lifted), and `s` is in the subtree rooted at the first lifted parent of `c`. We + // "ignore" `c` in the sense that it does not need to capture `s`'s scoped object's variables, nor does + // it require the current scoped object to create a capture class for its accessed variables. + def isIgnored(c: ScopedInfo) = + s.inSubtree(scopeData.getNode(c).firstLiftedParent.toInfo) + + // All objects in the same scc must have at least the same accesses as each other + def go(includeIgnored: Bool) = + val base = for (id, scc) <- sccs yield + // If all objects in this SCC are ignored, then we treat it as if it does not access anything, + // unless we explicitly want to count ignored items (for the readers-mutators analysis) + if !includeIgnored && scc.forall(isIgnored) then id -> AccessInfo.empty + else id -> scc.foldLeft(AccessInfo.empty): + case (acc, sym) => acc ++ accessInfoMap(sym) + + // dp on DAG + val dp: MutMap[Int, AccessInfo] = MutMap.empty + def sccAccessInfo(scc: Int): AccessInfo = dp.get(scc) match + case Some(value) => value + case None => + val ret = sccEdges(scc).foldLeft(base(scc)): + case (acc, nextScc) => acc ++ sccAccessInfo(nextScc) + dp.addOne(scc -> ret) + ret + + for + (id, scc) <- sccs + sym <- scc + yield + sym -> sccAccessInfo(id).withoutLocals(scopeData.getNode(sym).obj.definedLocals) - // dp on DAG - val dp: MutMap[Int, AccessInfo] = MutMap.empty - def sccAccessInfo(scc: Int): AccessInfo = dp.get(scc) match - case Some(value) => value - case None => - val ret = sccEdges(scc).foldLeft(base(scc)): - case (acc, nextScc) => acc ++ sccAccessInfo(nextScc) - dp.addOne(scc -> ret) - ret + // Remove locals that are not yet defined + def removeUnused(m: Map[ScopedInfo, AccessInfo]) = m.map: + case k -> v => + val node = scopeData.getNode(k) + k -> v.intersectLocals(node.existingVars) - for - (id, scc) <- sccs - sym <- scc - yield sym -> (sccAccessInfo(id).intersectLocals(existingVars(sym))) - - private def findAccessesTop = - var accessMap: Map[BlockMemberSymbol, AccessInfo] = Map.empty - new BlockTraverserShallow: - applyBlock(b) - override def applyDefn(defn: Defn): Unit = defn match - case _: FunDefn | _: ClsLikeDefn => - accessMap ++= findAccesses(defn) - case _ => super.applyDefn(defn) + val (m1, m2) = (removeUnused(go(true)), removeUnused(go(false))) - accessMap + val subCases = nexts.map(findAccesses) + subCases.foldLeft((m1, m2)): + case ((acc1, acc2), (new1, new2)) => (combineInfos(acc1, new1), combineInfos(acc2, new2)) - val accessMap = findAccessesTop + // Searching from the root makes no sense. We instead start searching from each scope nested in the top-level + private val (m1, m2) = scopeData.scopeTree.root.children.map(findAccesses).unzip + val accessMapWithIgnored = m1.foldLeft[Map[ScopedInfo, AccessInfo]](Map.empty)(_ ++ _) + val accessMap = m2.foldLeft[Map[ScopedInfo, AccessInfo]](Map.empty)(_ ++ _) + + private def reqdCaptureLocals(s: ScopeNode): Map[ScopedInfo, Set[Local]] = + val blk = s.obj match + case ScopedObject.Top(b) => lastWords("reqdCaptureLocals called on top block") + case ScopedObject.ClassCtor(cls) => return Map.empty + (s.obj.toInfo -> Set.empty) + case ScopedObject.Class(cls, _) => Begin(cls.preCtor, cls.ctor) + case ScopedObject.Companion(comp, _) => comp.ctor + case ScopedObject.Func(fun, _) => fun.body + case ScopedObject.ScopedBlock(uid, block) => block + case ScopedObject.Loop(sym, block) => block + + // traverse all scoped blocks and loops + val nexts: Buffer[ScopeNode] = Buffer.empty + def findNodes(s: ScopeNode): List[ScopeNode] = s :: s.children.flatMap: + case c @ ScopeNode(obj = obj: (ScopedObject.ScopedBlock | ScopedObject.Loop)) => findNodes(c) + case c => + nexts.addOne(c) + List.empty + val nodes = findNodes(s) + + val locals = nodes.flatMap(_.obj.definedLocals).toSet + + val cap = reqdCaptureLocalsBlk(blk, nexts.toList, s.obj.definedLocals, locals) + + // In a class, all variables that are mutated by a child scope and accessed by a lifted class must be captured + val additional = s.obj match + case _: ScopedObject.Companion | _: ScopedObject.Class => + val (a, b) = s.children.map: c => + val acc = accessMap(c.obj.toInfo) + val accAll = accessMapWithIgnored(c.obj.toInfo) + (accAll.mutated, acc.accessed) + .unzip + a.flatten.toSet.intersect(b.flatten.toSet) + case _ => Set.empty - // TODO: let declarations inside loops (also broken without class lifting) - // I'll fix it once it's fixed in the IR since we will have more tools to determine - // what locals belong to what block. - private def reqdCaptureLocals(f: FunDefn) = - var defns = f.body.gatherDefns() - val defnSyms = defns.collect: - case f: FunDefn => f.sym -> f - case c: ClsLikeDefn => c.sym -> c + val newCap = cap ++ additional + + val cur: Map[ScopedInfo, Set[Local]] = nodes.map: n => + n.obj.toInfo -> newCap.intersect(n.obj.definedLocals) .toMap + + nexts.foldLeft(cur): + case (mp, acc) => mp ++ reqdCaptureLocals(acc) - val thisVars = definedLocals(f.sym) - - case class CaptureInfo(reqCapture: Set[Local], hasReader: Set[Local], hasMutator: Set[Local]) + // readers-mutators analysis + private def reqdCaptureLocalsBlk(b: Block, nextNodes: List[ScopeNode], startingVars: Set[Local], thisVars: Set[Local]): Set[Local] = + val scopeInfos: Map[ScopedInfo, ScopeNode] = nextNodes.map(node => node.obj.toInfo -> node).toMap - def go(b: Block, reqCapture_ : Set[Local], hasReader_ : Set[Local], hasMutator_ : Set[Local]): CaptureInfo = + case class CaptureInfo(reqCapture: Set[Local], hasReader: Set[Local], hasMutator: Set[Local], mutated: Set[Local]) + + // linearVars denotes the variables defined inside the scopes up to the nearest loop or the top level block. + // If a loop modifies a non-linear variable and then one of its nested definitions accesses it, we must put put + // that variable in a capture. + def go(b: Block, reqCapture_ : Set[Local], hasReader_ : Set[Local], hasMutator_ : Set[Local], mutated_ : Set[Local])(using linearVars: Set[Local]): CaptureInfo = var reqCapture = reqCapture_ var hasReader = hasReader_ var hasMutator = hasMutator_ + // note: the meaning of `mutated` is a bit strange: it basically means variables which are currently not linear that have been mutated + // if a variable is in this set but is linear, then it's ignored + var mutated = mutated_ inline def merge(c: CaptureInfo) = reqCapture ++= c.reqCapture hasReader ++= c.hasReader hasMutator ++= c.hasMutator + mutated ++= c.mutated - def rec(blk: Block) = - go(blk, reqCapture, hasReader, hasMutator) + def rec(blk: Block)(using linearVars: Set[Local]) = + go(blk, reqCapture, hasReader, hasMutator, mutated_) new BlockTraverserShallow: applyBlock(b) override def applyBlock(b: Block): Unit = b match + // Note that we traverse directly into scoped blocks without using handleCalledScope + case s: Scoped => + rec(s.body)(using linearVars = linearVars ++ s.syms) |> merge + case l: Label if l.loop => + rec(l.body)(using linearVars = Set.empty) |> merge + applySubBlock(l.rest) case Assign(lhs, rhs, rest) => applyResult(rhs) if hasReader.contains(lhs) || hasMutator.contains(lhs) then reqCapture += lhs - applyBlock(rest) - + if !linearVars.contains(lhs) then mutated += lhs + applySubBlock(rest) + case Define(c @ ClsLikeDefn(k = syntax.Obj), rest) => + handleCalledScope(c.isym) + applySubBlock(rest) case Match(scrut, arms, dflt, rest) => applyPath(scrut) val infos = arms.map: @@ -300,81 +365,89 @@ class UsedVarAnalyzer(b: Block)(using State): val dfltInfo = dflt.map: case arm => rec(arm) - infos.map(merge) // IMPORTANT: rec all first, then merge, since each branch is mutually exclusive - dfltInfo.map(merge) - applyBlock(rest) - case Label(label, loop, body, rest) => - // for now, if the loop body mutates a variable and that variable is accessed or mutated by a defn, - // or if it reads a variable that is later mutated by an instance inside the loop, - // we put it in a capture. this preserves the current semantics of the IR (even though it's incorrect). - // See the above TODO - val c @ CaptureInfo(req, read, mut) = rec(body) - merge(c) - reqCapture ++= read.intersect(blkAccessesShallow(body, S(label)).mutated) - reqCapture ++= mut.intersect(body.freeVars) - applyBlock(rest) + infos.foreach(merge) // IMPORTANT: rec all first, then merge, since each branch is mutually exclusive + dfltInfo.foreach(merge) + applySubBlock(rest) case Begin(sub, rest) => rec(sub) |> merge - applyBlock(rest) + applySubBlock(rest) case TryBlock(sub, finallyDo, rest) => // sub and finallyDo could be executed sequentially, so we must merge rec(sub) |> merge rec(finallyDo) |> merge - applyBlock(rest) + applySubBlock(rest) case Return(res, false) => applyResult(res) hasReader = Set.empty hasMutator = Set.empty case _ => super.applyBlock(b) - def handleCalledBms(called: BlockMemberSymbol): Unit = defnSyms.get(called) match + def handleCalledScope(called: ScopedInfo): Unit = scopeInfos.get(called) match case None => () - case Some(defn) => - val AccessInfo(accessed, muted, refd) = accessMap(defn.sym) + case Some(node) => + node.obj match + // ignore method calls to class or object methods + case ScopedObject.Func(_, S(true)) => return + case _ => () + + val AccessInfo(accessed, muted, refd) = accessMapWithIgnored(called) val muts = muted.intersect(thisVars) - val reads = defn.freeVars.intersect(thisVars) -- muts - // this not a naked reference. if it's a ref to a class, this can only ever create once instance - // so the "one writer" rule applies + val reads = accessed.intersect(thisVars) -- muts + val refdExcl = refd.filter: sym => + scopeData.getNode(sym).obj match + case s: ScopedObject.ScopedBlock => false + case ScopedObject.Func(_, S(true)) => false + case _ => true + + // This not a naked reference. If it's a ref to a class, this can only ever create once instance + // so the "one writer" rule applies. + // However, if the control flow is not linear, we are forced to add all the mutated variables for l <- muts do - if hasReader.contains(l) || hasMutator.contains(l) || defn.isInstanceOf[FunDefn] then + if hasReader.contains(l) || hasMutator.contains(l) || !linearVars.contains(l) then reqCapture += l + hasReader += l hasMutator += l + mutated += l for l <- reads do if hasMutator.contains(l) then reqCapture += l + if mutated.contains(l) && !linearVars.contains(l) then + reqCapture += l hasReader += l // if this defn calls another defn that creates a class or has a naked reference to a // function, we must capture the latter's mutated variables in a capture, as arbitrarily // many mutators could be created from it for - sym <- refd - l <- accessMap(sym).mutated + sym <- refdExcl + l <- accessMapWithIgnored(sym).mutated do reqCapture += l hasMutator += l - override def applyResult(r: Result): Unit = r match - case Call(RefOfBms(l, _), args) => - args.map(super.applyArg(_)) - handleCalledBms(l) - case Instantiate(mut, InstSel(l), args) => - args.map(super.applyArg) - handleCalledBms(l._1) + override def applyResult(r: Result): Unit = + r match + case Call(RefOfBms(_, SDSym(d)), args) => + args.foreach(super.applyArg(_)) + handleCalledScope(d) + case Instantiate(mut, InstSel(_, S(d)), args) => + args.foreach(super.applyArg) + handleCalledScope(d) case _ => super.applyResult(r) override def applyPath(p: Path): Unit = p match - case RefOfBms(l, _) => - defnSyms.get(l) match + case RefOfBms(_, SDSym(d)) => + scopeInfos.get(d) match case None => super.applyPath(p) case Some(defn) => - val isMod = defn match - case c: ClsLikeDefn => modOrObj(c) + val isModOrObj = defn.obj match + case c: ScopedObject.Companion => true + case c: ScopedObject.Class => c.isObj case _ => false - if isMod then super.applyPath(p) + if isModOrObj then super.applyPath(p) else - val AccessInfo(accessed, muted, refd) = accessMap(defn.sym) + val AccessInfo(accessed, muted, refd) = accessMapWithIgnored(d) val muts = muted.intersect(thisVars) - val reads = defn.freeVars.intersect(thisVars) -- muts + val reads = accessed.intersect(thisVars) -- muts // this is a naked reference, we assume things it mutates always needs a capture for l <- muts do reqCapture += l @@ -382,13 +455,15 @@ class UsedVarAnalyzer(b: Block)(using State): for l <- reads do if hasMutator.contains(l) then reqCapture += l + if mutated.contains(l) && !linearVars.contains(l) then + reqCapture += l hasReader += l // if this defn calls another defn that creates a class or has a naked reference to a // function, we must capture the latter's mutated variables in a capture, as arbitrarily // many mutators could be created from it for sym <- refd - l <- accessMap(sym).mutated + l <- accessMapWithIgnored(sym).mutated do reqCapture += l hasMutator += l @@ -399,52 +474,21 @@ class UsedVarAnalyzer(b: Block)(using State): override def applyDefn(defn: Defn): Unit = defn match case c: ClsLikeDefn if modOrObj(c) => - handleCalledBms(c.sym) + handleCalledScope(c.isym) // TODO: use new system super.applyDefn(defn) case _ => super.applyDefn(defn) - CaptureInfo(reqCapture, hasReader, hasMutator) - - val reqCapture = go(f.body, Set.empty, Set.empty, Set.empty).reqCapture - val usedVars = defns.flatMap(_.freeVars.intersect(thisVars)).toSet - (usedVars, reqCapture) - - // the current problem is that we need extra code to find which variables were really defined by a function - // this may be resolved in the future when the IR gets explicit variable declarations - private def findUsedLocalsFn(f: FunDefn): Map[BlockMemberSymbol, FreeVars] = - val thisVars = definedLocals(f.sym) - - val (vars, cap) = reqdCaptureLocals(f) - - var usedMap: Map[BlockMemberSymbol, FreeVars] = Map.empty - usedMap += (f.sym -> Lifter.FreeVars(vars.intersect(thisVars), cap.intersect(thisVars))) - for d <- nestedDefns(f.sym) do - usedMap ++= findUsedLocalsDefn(d) - usedMap - - private def findUsedLocalsDefn(d: Defn) = - d match - case f: FunDefn => - findUsedLocalsFn(f) - case c: ClsLikeDefn => - findUsedLocalsCls(c) - case d => Map.empty + CaptureInfo(reqCapture, hasReader, hasMutator, mutated) + + val reqCapture = go(b, Set.empty, Set.empty, Set.empty, Set.empty)(using linearVars = startingVars).reqCapture + reqCapture.intersect(thisVars) - private def findUsedLocalsCls(c: ClsLikeDefn): Map[BlockMemberSymbol, FreeVars] = - nestedDefns(c.sym).foldLeft(Map.empty): - case (acc, d) => acc ++ findUsedLocalsDefn(d) + val reqdCaptures: Map[ScopedInfo, Set[Local]] = scopeData.root.children.foldLeft(Map.empty): + case (acc, node) => acc ++ reqdCaptureLocals(node) - /** - * Finds the used locals of functions which have been used by their nested definitions. - * - * @param b - * @return - */ - def findUsedLocals: Lifter.UsedLocalsMap = - var usedMap: Map[BlockMemberSymbol, FreeVars] = Map.empty - new BlockTraverserShallow: - applyBlock(b) - override def applyDefn(defn: Defn): Unit = - usedMap ++= findUsedLocalsDefn(defn) - - Lifter.UsedLocalsMap(usedMap) + // For local inside a capture, finds the node to which this local belongs. + val capturesMap = + for + case (info -> reqCap) <- reqdCaptures + s <- reqCap + yield s -> info diff --git a/hkmc2/shared/src/test/mlscript/HkScratch.mls b/hkmc2/shared/src/test/mlscript/HkScratch.mls index ef38e8b363..b09a983866 100644 --- a/hkmc2/shared/src/test/mlscript/HkScratch.mls +++ b/hkmc2/shared/src/test/mlscript/HkScratch.mls @@ -5,8 +5,9 @@ // :elt :global -// :d -// :todo - +:d +:todo +class Eff +//│ Elab: { Cls Eff { }; } diff --git a/hkmc2/shared/src/test/mlscript/HkScratch2.mls b/hkmc2/shared/src/test/mlscript/HkScratch2.mls new file mode 100644 index 0000000000..16a486ebaf --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/HkScratch2.mls @@ -0,0 +1,96 @@ +:js +// :de +// :sjs +// :pt +// :elt + +:global +:d +:todo +:noSanityCheck + + +:lift +:ssjs +module M with + data class A(x) with + class B() with + fun newA_B(y) = + newA_B(2) + A(y) + fun newB = B() +M.A(2).newB +//│ Elab: { Mod M { Cls AParamList(‹›,List(Param(‹val›,x‹695›,None,Modulefulness(None))),None) { val member:x‹698› = x‹695›#0; Cls BParamList(‹›,List(),None) { method fun member:newA_B‹686›(y‹705›) = { (class:B‹701›#666.)newA_B‹member:newA_B‹686››(2); (module:M‹691›#666.)A‹member:A‹689››(y‹705›#666) }; }; method fun member:newB‹687› = (class:A‹693›#666.)B‹member:B‹688››(); }; }; member:M‹690›#666.A‹member:A‹689››(2).newB } +//│ JS: +//│ let B1, M1, tmp, B$; +//│ B$ = function B$(isMut, A$, A$1, B$1, M2) { +//│ let tmp1; +//│ if (isMut === true) { +//│ tmp1 = new B1.class(); +//│ } else { +//│ tmp1 = globalThis.Object.freeze(new B1.class()); +//│ } +//│ return tmp1(A$, A$1, B$1, M2) +//│ }; +//│ B1 = function B() { +//│ return (A$, A$1, B$1, M2) => { +//│ return globalThis.Object.freeze(new B.class()(A$, A$1, B$1, M2)); +//│ } +//│ }; +//│ (class B { +//│ static { +//│ B1.class = this +//│ } +//│ constructor() { +//│ return (A$, A$1, B$1, M2) => { +//│ this.A$ = A$; +//│ this.A$ = A$1; +//│ this.B$ = B$1; +//│ this.M = M2; +//│ return this; +//│ } +//│ } +//│ #B$cap; +//│ newA_B(y) { +//│ let tmp1; +//│ tmp1 = this.newA_B(2); +//│ return this.A$(y) +//│ } +//│ toString() { return runtime.render(this); } +//│ static [definitionMetadata] = ["class", "B", []]; +//│ }); +//│ (class M { +//│ static { +//│ M1 = this +//│ } +//│ constructor() { +//│ runtime.Unit; +//│ } +//│ static #M_mod$cap; +//│ static { +//│ this.A = function A(x) { +//│ return globalThis.Object.freeze(new A.class(x)); +//│ }; +//│ (class A { +//│ static { +//│ M.A.class = this +//│ } +//│ constructor(x) { +//│ this.x = x; +//│ } +//│ #A$cap; +//│ get newB() { +//│ return B$(false, M.A.class, M.A, this.B, M); +//│ } +//│ toString() { return runtime.render(this); } +//│ static [definitionMetadata] = ["class", "A", ["x"]]; +//│ }); +//│ } +//│ #M$cap; +//│ toString() { return runtime.render(this); } +//│ static [definitionMetadata] = ["class", "M"]; +//│ }); +//│ tmp = M1.A(2); +//│ block$res1 = tmp.newB; +//│ undefined +//│ = B() diff --git a/hkmc2/shared/src/test/mlscript/backlog/Lifter.mls b/hkmc2/shared/src/test/mlscript/backlog/Lifter.mls index 6572c17e73..1b2772d1eb 100644 --- a/hkmc2/shared/src/test/mlscript/backlog/Lifter.mls +++ b/hkmc2/shared/src/test/mlscript/backlog/Lifter.mls @@ -2,23 +2,6 @@ :lift :todo - -// The following are problems with lifting functions inside other functions. - - -// Lifting functions with spread arguments is broken. -:expect [1] -fun f(x) = - fun g(...rest) = - print(x) - rest - let a = g - a(1) -f(2) -//│ > 2 -//│ ═══[RUNTIME ERROR] Expected: '[1]', got: '[]' -//│ = [] - // The following are problems with lifting classes inside other definitions. // We use the optional `symbol` parameter of `Select` to detect references to the @@ -68,6 +51,25 @@ test(2) //│ ═══[RUNTIME ERROR] Error: Access to required field 'get' yielded 'undefined' //│ ═══[RUNTIME ERROR] Expected: '2', got: 'undefined' +// This is due to `super` not being called with multiple parameter lists. +:expect "B()" +fun test(x) = + class A() with + fun get = x + class B() extends A() with + fun bar = x + B() +test(0).toString() +//│ ═══[RUNTIME ERROR] Expected: '"B()"', got: '"(x) => { this.x = x; return this; }"' +//│ = "(x) => { this.x = x; return this; }" + +// In fact, the correct way (using the current "constructor returns a lambda" hack) is this: +// return (params) => { +// super(parentMainParams); +// // ctor +// return this(parentAuxParams); +// } + /// The following are related to first-class classes, instance checks, and private fields. /// :w @@ -127,27 +129,6 @@ test(2) //│ ═══[WARNING] Cannot yet lift definition `B` as it extends an unliftable class. //│ = 2 -:expect 2 -class A(a) with - let x = 2 - fun f() = - fun g() = x - g() -A(2).f() -//│ ═══[COMPILATION ERROR] Uses of private fields cannot yet be lifted. -//│ ═══[RUNTIME ERROR] Error: MLscript call unexpectedly returned `undefined`, the forbidden value. -//│ ═══[RUNTIME ERROR] Expected: '2', got: 'undefined' - -:expect 2 -module M with - let x = 2 - fun f() = - fun g() = x - g -M.f()() -//│ ═══[COMPILATION ERROR] Uses of private fields cannot yet be lifted. -//│ ═══[RUNTIME ERROR] Expected: '2', got: '()' - /// The following are related to modules and objects. /// :todo @@ -158,7 +139,9 @@ fun foo(x, y) = set y = 2 x + y + test M.foo() -//│ ═══[WARNING] Modules are not yet lifted. +//│ ╔══[WARNING] Modules are not yet lifted. +//│ ║ l.136: module M with +//│ ╙── ^ :expect 14 foo(10, 0) @@ -171,24 +154,33 @@ fun foo(x, y) = x + y fun foo = M.foo() foo -//│ ═══[WARNING] Modules are not yet lifted. -//│ ═══[COMPILATION ERROR] No definition found in scope for member 'M$' +//│ ╔══[WARNING] Modules are not yet lifted. +//│ ║ l.151: module M with +//│ ╙── ^ :expect 12 foo(10, 0) -//│ ═══[RUNTIME ERROR] ReferenceError: M$ is not defined -//│ ═══[RUNTIME ERROR] Expected: '12', got: 'undefined' +//│ = 12 data class A(x) with module M with fun getB() = x fun getA() = M.getB() -//│ ═══[WARNING] Modules are not yet lifted. +//│ ╔══[WARNING] Modules are not yet lifted. +//│ ║ l.167: module M with +//│ ╙── ^ +//│ ╔══[COMPILATION ERROR] No definition found in scope for member 'M' +//│ ╟── which references the symbol introduced here +//│ ║ l.167: module M with +//│ ║ ^^^^^^ +//│ ║ l.168: fun getB() = x +//│ ╙── ^^^^^^^^^^^^^^^^^^ :expect 2 A(2).getA() -//│ = 2 +//│ ═══[RUNTIME ERROR] ReferenceError: M is not defined +//│ ═══[RUNTIME ERROR] Expected: '2', got: 'undefined' // TODO: Foo needs to be put in a mutable capture. Also, we need to pass the Foo instance itself into Foo fun foo(x) = @@ -199,12 +191,11 @@ fun foo(x) = (new Bar).self2 foo(2) //│ ╔══[ERROR] Expected a statically known class; found reference of type Foo. -//│ ║ l.198: class Bar extends Foo +//│ ║ l.190: class Bar extends Foo //│ ║ ^^^ //│ ╙── The 'new' keyword requires a statically known class; use the 'new!' operator for dynamic instantiation. -//│ ═══[COMPILATION ERROR] No definition found in scope for member 'Foo$' -//│ ═══[COMPILATION ERROR] No definition found in scope for member 'Foo$' -//│ ═══[RUNTIME ERROR] TypeError: Class extends value [object Object] is not a constructor or null +//│ ═══[WARNING] Cannot yet lift definition `Bar` as it extends an expression. +//│ = Bar { Foo: undefined } // `h` is lifted out, but then cannot access the BlockMemberSymbol M. :fixme @@ -215,7 +206,8 @@ fun f = fun h = M.g h M.g -//│ ═══[WARNING] Modules are not yet lifted. -//│ /!!!\ Uncaught error: hkmc2.InternalError: `this` not in scope: class:M +//│ ╔══[WARNING] Modules are not yet lifted. +//│ ║ l.204: module M with +//│ ╙── ^ diff --git a/hkmc2/shared/src/test/mlscript/codegen/ScopedBlocksAndHandlers.mls b/hkmc2/shared/src/test/mlscript/codegen/ScopedBlocksAndHandlers.mls index f753b520d8..928b0b2560 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/ScopedBlocksAndHandlers.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/ScopedBlocksAndHandlers.mls @@ -19,8 +19,13 @@ fun f(x) = let z = 3 fun g() = z //│ JS (unsanitized): -//│ let f, g$; +//│ let g, f, g$; //│ g$ = function g$(z) { +//│ return () => { +//│ return g(z) +//│ } +//│ }; +//│ g = function g(z) { //│ return z //│ }; //│ f = function f(x) { @@ -110,8 +115,13 @@ fun f(x) = let z = 3 fun g() = z ) //│ JS (unsanitized): -//│ let f4, g$3; +//│ let g5, f4, g$3; //│ g$3 = function g$(z) { +//│ return () => { +//│ return g5(z) +//│ } +//│ }; +//│ g5 = function g(z) { //│ return z //│ }; //│ f4 = function f(x) { diff --git a/hkmc2/shared/src/test/mlscript/handlers/Debugging.mls b/hkmc2/shared/src/test/mlscript/handlers/Debugging.mls index 297681f6c2..2fb2bcf659 100644 --- a/hkmc2/shared/src/test/mlscript/handlers/Debugging.mls +++ b/hkmc2/shared/src/test/mlscript/handlers/Debugging.mls @@ -57,7 +57,7 @@ fun f() = //│ k = 2000; //│ runtime.resumeValue = Predef.equals(i, 0); //│ if (runtime.curEffect !== null) { -//│ return runtime.unwind(f, 5, "Debugging.mls:17:6", f$debugInfo, null, 1, 0, 6, i, j, k, scrut, tmp1, tmp2) +//│ return runtime.unwind(f, 5, "Debugging.mls:17:8", f$debugInfo, null, 1, 0, 6, i, j, k, scrut, tmp1, tmp2) //│ } //│ pc = 5; //│ continue main; @@ -118,7 +118,7 @@ lambda_test(() => 100) //│ ═══[RUNTIME ERROR] Error: Unhandled effect FatalEffect //│ at lambda (Debugging.mls:117:3) -//│ at lambda_test (Debugging.mls:114:3) +//│ at lambda_test (pc=1) import "../../mlscript-compile/Runtime.mls" @@ -158,38 +158,30 @@ module Test with fun main()(using Debugger) = test(12) + test(34) -// Currently this test fails due to lifter issue. Commenting out :lift at the top of this file will make this work. -// The lifter currently does not correctly consider the member variable as a local that could be captured. -:fixme let res = handle h = Debugger with fun break(payload)(resume) = resume using Debugger = h Runtime.try(() => Test.main()) -//│ ╔══[COMPILATION ERROR] No definition found in scope for member 'instance$Ident(Debugger)' -//│ ╟── which references the symbol introduced here -//│ ║ l.167: using Debugger = h -//│ ╙── ^^^^^^^^^^^^^^^^^^ -//│ ═══[RUNTIME ERROR] ReferenceError: instance$Ident is not defined -//│ res = undefined +//│ res = EffectHandle(_) :re res.raise() -//│ ═══[RUNTIME ERROR] TypeError: Cannot read properties of undefined (reading 'raise') +//│ ═══[RUNTIME ERROR] Error: Unhandled effect Handler$h$ +//│ at Test.test (Debugging.mls:156:18) +//│ at Test.main (pc=3) -:fixme set res = res.resumeWith(42) -//│ ═══[RUNTIME ERROR] TypeError: Cannot read properties of undefined (reading 'resumeWith') :re res.raise() -//│ ═══[RUNTIME ERROR] TypeError: Cannot read properties of undefined (reading 'raise') +//│ ═══[RUNTIME ERROR] Error: Unhandled effect Handler$h$ +//│ at Test.test (Debugging.mls:156:18) +//│ at Test.main (pc=1) -:fixme :expect 0.33676533676533676 res.resumeWith(666) -//│ ═══[RUNTIME ERROR] TypeError: Cannot read properties of undefined (reading 'resumeWith') -//│ ═══[RUNTIME ERROR] Expected: '0.33676533676533676', got: 'undefined' +//│ = 0.33676533676533676 let i = 100 fun f() = @@ -205,8 +197,8 @@ fun f() = f() //│ > Stack Trace: -//│ > at f (Debugging.mls:199:3) with locals: j=200 +//│ > at f (Debugging.mls:191:3) with locals: j=200 //│ > Stack Trace: -//│ > at f (Debugging.mls:201:3) +//│ > at f (Debugging.mls:193:3) //│ > Stack Trace: //│ > at tail position diff --git a/hkmc2/shared/src/test/mlscript/handlers/EffectInHandler.mls b/hkmc2/shared/src/test/mlscript/handlers/EffectInHandler.mls index 00853d2e4e..61fd0825ad 100644 --- a/hkmc2/shared/src/test/mlscript/handlers/EffectInHandler.mls +++ b/hkmc2/shared/src/test/mlscript/handlers/EffectInHandler.mls @@ -24,7 +24,7 @@ h.aux(3) //│ ╔══[ERROR] Name not found: h //│ ║ l.19: h.p(x) //│ ╙── ^ -//│ ═══[RUNTIME ERROR] Error: MLscript call unexpectedly returned `undefined`, the forbidden value. +//│ FAILURE: Unexpected lack of runtime error handle h1 = PrintEffect with fun p(x)(k) = diff --git a/hkmc2/shared/src/test/mlscript/handlers/Effects.mls b/hkmc2/shared/src/test/mlscript/handlers/Effects.mls index a18551b989..4eb93ce2b4 100644 --- a/hkmc2/shared/src/test/mlscript/handlers/Effects.mls +++ b/hkmc2/shared/src/test/mlscript/handlers/Effects.mls @@ -71,9 +71,13 @@ in h.perform("5") result //│ JS (unsanitized): -//│ let result, h7, tmp7, handleBlock$7, Handler$h$perform7, Handler$h$15, Handler$h$perform$7; -//│ result = ""; -//│ Handler$h$perform$7 = function Handler$h$perform$(Handler$h$$instance, arg, k) { +//│ let result, h7, tmp7, handleBlock$7, Handler$h$perform7, Handler$h$15, Handler$h$perform$4; +//│ Handler$h$perform$4 = function Handler$h$perform$(arg) { +//│ return (k) => { +//│ return Handler$h$perform7(arg, k) +//│ } +//│ }; +//│ Handler$h$perform7 = function Handler$h$perform(arg, k) { //│ let v, pc; //│ if (runtime.resumePc === -1) { //│ pc = 0; @@ -87,7 +91,7 @@ result //│ case 0: //│ runtime.resumeValue = runtime.safeCall(k(arg)); //│ if (runtime.curEffect !== null) { -//│ return runtime.unwind(Handler$h$perform$7, 1, "Effects.mls:63:13", null, null, 1, 3, Handler$h$$instance, arg, k, 0) +//│ return runtime.unwind(Handler$h$perform7, 1, null, null, null, 1, 2, arg, k, 0) //│ } //│ pc = 1; //│ continue main; @@ -101,11 +105,7 @@ result //│ break; //│ } //│ }; -//│ Handler$h$perform7 = function Handler$h$perform(Handler$h$$instance, arg) { -//│ return (k) => { -//│ return Handler$h$perform$7(Handler$h$$instance, arg, k) -//│ } -//│ }; +//│ result = ""; //│ (class Handler$h$14 extends Effect1 { //│ static { //│ Handler$h$15 = this @@ -118,9 +118,10 @@ result //│ } //│ tmp8; //│ } +//│ #Handler$h$$cap; //│ perform(arg) { //│ let Handler$h$perform$here; -//│ Handler$h$perform$here = runtime.safeCall(Handler$h$perform7(this, arg)); +//│ Handler$h$perform$here = Handler$h$perform$4(arg); //│ return runtime.mkEffect(this, Handler$h$perform$here) //│ } //│ toString() { return runtime.render(this); } @@ -441,10 +442,18 @@ fun foo(h) = handle h = Eff with fun perform()(k) = k(()) foo(h) -//│ ═══[WARNING] Modules are not yet lifted. -//│ ═══[WARNING] Modules are not yet lifted. -//│ ═══[WARNING] Modules are not yet lifted. -//│ ═══[WARNING] Modules are not yet lifted. +//│ ╔══[WARNING] Modules are not yet lifted. +//│ ║ l.438: module A with +//│ ╙── ^ +//│ ╔══[WARNING] Modules are not yet lifted. +//│ ║ l.434: module A with +//│ ╙── ^ +//│ ╔══[WARNING] Modules are not yet lifted. +//│ ║ l.430: module A with +//│ ╙── ^ +//│ ╔══[WARNING] Modules are not yet lifted. +//│ ║ l.426: module A with +//│ ╙── ^ //│ /!!!\ Uncaught error: hkmc2.InternalError: Unexpected nested class: lambdas may not function correctly. // Access superclass fields diff --git a/hkmc2/shared/src/test/mlscript/handlers/EffectsHygiene.mls b/hkmc2/shared/src/test/mlscript/handlers/EffectsHygiene.mls index 8723c6aec2..ea423aff62 100644 --- a/hkmc2/shared/src/test/mlscript/handlers/EffectsHygiene.mls +++ b/hkmc2/shared/src/test/mlscript/handlers/EffectsHygiene.mls @@ -14,8 +14,12 @@ fun foo(h): module M = else module A A -//│ ═══[WARNING] Modules are not yet lifted. -//│ ═══[WARNING] Modules are not yet lifted. +//│ ╔══[WARNING] Modules are not yet lifted. +//│ ║ l.15: module A +//│ ╙── ^ +//│ ╔══[WARNING] Modules are not yet lifted. +//│ ║ l.12: module A +//│ ╙── ^ //│ /!!!\ Uncaught error: hkmc2.InternalError: Unexpected nested class: lambdas may not function correctly. diff --git a/hkmc2/shared/src/test/mlscript/handlers/EffectsInClasses.mls b/hkmc2/shared/src/test/mlscript/handlers/EffectsInClasses.mls index 6c515939be..ae9508dcc6 100644 --- a/hkmc2/shared/src/test/mlscript/handlers/EffectsInClasses.mls +++ b/hkmc2/shared/src/test/mlscript/handlers/EffectsInClasses.mls @@ -32,6 +32,7 @@ data class Lol(h) with //│ } //│ tmp1; //│ } +//│ #Lol$cap; //│ toString() { return runtime.render(this); } //│ static [definitionMetadata] = ["class", "Lol", ["h"]]; //│ }); diff --git a/hkmc2/shared/src/test/mlscript/handlers/HandlersScratch.mls b/hkmc2/shared/src/test/mlscript/handlers/HandlersScratch.mls index 67b143ed20..bd9ed791c1 100644 --- a/hkmc2/shared/src/test/mlscript/handlers/HandlersScratch.mls +++ b/hkmc2/shared/src/test/mlscript/handlers/HandlersScratch.mls @@ -2,7 +2,5 @@ :effectHandlers :lift -abstract class Effect - diff --git a/hkmc2/shared/src/test/mlscript/handlers/NonLocalReturns.mls b/hkmc2/shared/src/test/mlscript/handlers/NonLocalReturns.mls index c7afa3ca81..61e9af72b6 100644 --- a/hkmc2/shared/src/test/mlscript/handlers/NonLocalReturns.mls +++ b/hkmc2/shared/src/test/mlscript/handlers/NonLocalReturns.mls @@ -2,7 +2,6 @@ :effectHandlers :lift - fun f() = (() => return 100)() print("Bad") @@ -67,7 +66,7 @@ f()() :e return 100 //│ ╔══[ERROR] Return statements are not allowed outside of functions. -//│ ║ l.68: return 100 +//│ ║ l.67: return 100 //│ ╙── ^^^^^^^^^^ :e @@ -75,14 +74,14 @@ if true do while false do let f = () => return 100 //│ ╔══[ERROR] Return statements are not allowed outside of functions. -//│ ║ l.76: let f = () => return 100 +//│ ║ l.75: let f = () => return 100 //│ ╙── ^^^^^^^^^^ :e fun f() = type A = return "lol" //│ ╔══[ERROR] Return statements are not allowed in this context. -//│ ║ l.83: type A = return "lol" +//│ ║ l.82: type A = return "lol" //│ ╙── ^^^^^^^^^^^^ @@ -108,8 +107,7 @@ fun f() = return 100 4 f() -//│ ═══[COMPILATION ERROR] No definition found in scope for member 'nonLocalRetHandler$f' -//│ ═══[RUNTIME ERROR] ReferenceError: nonLocalRetHandler$f is not defined +//│ ═══[RUNTIME ERROR] Error: Effect $_non$_local$_return$_effect$_18 is raised in a constructor //│ ═══[RUNTIME ERROR] Expected: '100', got: 'undefined' // ctor cannot raise effect, so error is expected. @@ -130,8 +128,7 @@ fun f() = (() => return 100)() 4 f() -//│ ═══[COMPILATION ERROR] No definition found in scope for member 'nonLocalRetHandler$f' -//│ ═══[RUNTIME ERROR] ReferenceError: nonLocalRetHandler$f is not defined +//│ ═══[RUNTIME ERROR] Error: Effect $_non$_local$_return$_effect$_22 is raised in a constructor //│ ═══[RUNTIME ERROR] Expected: '100', got: 'undefined' :fixme diff --git a/hkmc2/shared/src/test/mlscript/handlers/RecursiveHandlers.mls b/hkmc2/shared/src/test/mlscript/handlers/RecursiveHandlers.mls index 4fdeb1a1af..d37fafb8a8 100644 --- a/hkmc2/shared/src/test/mlscript/handlers/RecursiveHandlers.mls +++ b/hkmc2/shared/src/test/mlscript/handlers/RecursiveHandlers.mls @@ -155,187 +155,188 @@ if true do h1.perform(()) str //│ JS (unsanitized): -//│ let str, scrut, h11, tmp11, tmp12; -//│ str = ""; -//│ scrut = true; -//│ if (scrut === true) { -//│ let handleBlock$9, Handler$h2$perform1, Handler$h2$3, handleBlock$10, Handler$h1$perform1, Handler$h1$3, Handler$h1$perform$1, handleBlock$$2, Handler$h2$perform$1; -//│ Handler$h1$perform$1 = function Handler$h1$perform$(Handler$h1$$instance, arg, k) { -//│ let tmp13, tmp14, tmp15, pc; -//│ if (runtime.resumePc === -1) { -//│ pc = 0; -//│ } else { -//│ let saveOffset; -//│ pc = runtime.resumePc; -//│ runtime.resumePc = -1; +//│ let str, scrut, h11, tmp11, tmp12, handleBlock$9, Handler$h2$perform1, Handler$h2$3, handleBlock$10, Handler$h1$perform1, Handler$h1$3, handleBlock$$2, Handler$h2$perform$1, Handler$h1$perform$1; +//│ Handler$h1$perform$1 = function Handler$h1$perform$(arg) { +//│ return (k) => { +//│ return Handler$h1$perform1(arg, k) +//│ } +//│ }; +//│ Handler$h1$perform1 = function Handler$h1$perform(arg, k) { +//│ let tmp13, tmp14, tmp15, pc; +//│ if (runtime.resumePc === -1) { +//│ pc = 0; +//│ } else { +//│ let saveOffset; +//│ pc = runtime.resumePc; +//│ runtime.resumePc = -1; +//│ } +//│ main: while (true) { +//│ switch (pc) { +//│ case 0: +//│ tmp13 = str + "A"; +//│ str = tmp13; +//│ runtime.resumeValue = runtime.safeCall(k(arg)); +//│ if (runtime.curEffect !== null) { +//│ return runtime.unwind(Handler$h1$perform1, 1, null, null, null, 1, 2, arg, k, 0) +//│ } +//│ pc = 1; +//│ continue main; +//│ break; +//│ case 1: +//│ tmp14 = runtime.resumeValue; +//│ tmp15 = str + "A"; +//│ str = tmp15; +//│ return runtime.Unit; +//│ break; //│ } -//│ main: while (true) { -//│ switch (pc) { -//│ case 0: -//│ tmp13 = str + "A"; -//│ str = tmp13; -//│ runtime.resumeValue = runtime.safeCall(k(arg)); -//│ if (runtime.curEffect !== null) { -//│ return runtime.unwind(Handler$h1$perform$1, 1, "RecursiveHandlers.mls:147:7", null, null, 1, 3, Handler$h1$$instance, arg, k, 0) -//│ } -//│ pc = 1; -//│ continue main; -//│ break; -//│ case 1: -//│ tmp14 = runtime.resumeValue; -//│ tmp15 = str + "A"; -//│ str = tmp15; -//│ return runtime.Unit; -//│ break; -//│ } -//│ break; +//│ break; +//│ } +//│ }; +//│ Handler$h2$perform$1 = function Handler$h2$perform$(arg) { +//│ return (k) => { +//│ return Handler$h2$perform1(arg, k) +//│ } +//│ }; +//│ Handler$h2$perform1 = function Handler$h2$perform(arg, k) { +//│ let tmp13, tmp14, tmp15, tmp16, tmp17, pc; +//│ if (runtime.resumePc === -1) { +//│ pc = 0; +//│ } else { +//│ let saveOffset; +//│ pc = runtime.resumePc; +//│ runtime.resumePc = -1; +//│ } +//│ main: while (true) { +//│ switch (pc) { +//│ case 0: +//│ tmp13 = str + "B"; +//│ tmp14 = str + tmp13; +//│ str = tmp14; +//│ runtime.resumeValue = runtime.safeCall(k(arg)); +//│ if (runtime.curEffect !== null) { +//│ return runtime.unwind(Handler$h2$perform1, 1, null, null, null, 1, 2, arg, k, 0) +//│ } +//│ pc = 1; +//│ continue main; +//│ break; +//│ case 1: +//│ tmp15 = runtime.resumeValue; +//│ tmp16 = str + "B"; +//│ tmp17 = str + tmp16; +//│ str = tmp17; +//│ return runtime.Unit; +//│ break; //│ } -//│ }; -//│ Handler$h1$perform1 = function Handler$h1$perform(Handler$h1$$instance, arg) { -//│ return (k) => { -//│ return Handler$h1$perform$1(Handler$h1$$instance, arg, k) +//│ break; +//│ } +//│ }; +//│ handleBlock$$2 = (undefined, function (h22) { +//│ return () => { +//│ return handleBlock$9(h22) +//│ } +//│ }); +//│ (class Handler$h2$2 extends Effect1 { +//│ static { +//│ Handler$h2$3 = this +//│ } +//│ constructor() { +//│ let tmp13; +//│ tmp13 = super(); +//│ if (runtime.curEffect !== null) { +//│ tmp13 = runtime.illegalEffect("in a constructor"); //│ } -//│ }; -//│ (class Handler$h1$2 extends Effect1 { -//│ static { -//│ Handler$h1$3 = this +//│ tmp13; +//│ } +//│ #Handler$h2$$cap; +//│ perform(arg) { +//│ let Handler$h2$perform$here; +//│ Handler$h2$perform$here = Handler$h2$perform$1(arg); +//│ return runtime.mkEffect(this, Handler$h2$perform$here) +//│ } +//│ toString() { return runtime.render(this); } +//│ static [definitionMetadata] = ["class", "Handler$h2$"]; +//│ }); +//│ handleBlock$9 = (undefined, function (h22) { +//│ let tmp13, pc; +//│ if (runtime.resumePc === -1) { +//│ pc = 0; +//│ } else { +//│ let saveOffset; +//│ pc = runtime.resumePc; +//│ runtime.resumePc = -1; +//│ } +//│ main: while (true) { +//│ switch (pc) { +//│ case 0: +//│ runtime.resumeValue = runtime.safeCall(h22.perform(runtime.Unit)); +//│ if (runtime.curEffect !== null) { +//│ return runtime.unwind(handleBlock$9, 1, "RecursiveHandlers.mls:154:5", null, null, 1, 1, h22, 0) +//│ } +//│ pc = 1; +//│ continue main; +//│ break; +//│ case 1: +//│ tmp13 = runtime.resumeValue; +//│ return runtime.safeCall(h11.perform(runtime.Unit)); +//│ break; //│ } -//│ constructor() { -//│ let tmp13; -//│ tmp13 = super(); -//│ if (runtime.curEffect !== null) { -//│ tmp13 = runtime.illegalEffect("in a constructor"); -//│ } -//│ tmp13; +//│ break; +//│ } +//│ }); +//│ (class Handler$h1$2 extends Effect1 { +//│ static { +//│ Handler$h1$3 = this +//│ } +//│ constructor() { +//│ let tmp13; +//│ tmp13 = super(); +//│ if (runtime.curEffect !== null) { +//│ tmp13 = runtime.illegalEffect("in a constructor"); //│ } -//│ perform(arg) { -//│ let Handler$h1$perform$here; -//│ Handler$h1$perform$here = runtime.safeCall(Handler$h1$perform1(this, arg)); -//│ return runtime.mkEffect(this, Handler$h1$perform$here) +//│ tmp13; +//│ } +//│ #Handler$h1$$cap; +//│ perform(arg) { +//│ let Handler$h1$perform$here; +//│ Handler$h1$perform$here = Handler$h1$perform$1(arg); +//│ return runtime.mkEffect(this, Handler$h1$perform$here) +//│ } +//│ toString() { return runtime.render(this); } +//│ static [definitionMetadata] = ["class", "Handler$h1$"]; +//│ }); +//│ handleBlock$10 = (undefined, function () { +//│ let h22, tmp13, handleBlock$$here, pc; +//│ if (runtime.resumePc === -1) { +//│ pc = 0; +//│ } else { +//│ let saveOffset; +//│ pc = runtime.resumePc; +//│ runtime.resumePc = -1; +//│ } +//│ main: while (true) { +//│ switch (pc) { +//│ case 0: +//│ h22 = new Handler$h2$3(); +//│ handleBlock$$here = handleBlock$$2(h22); +//│ runtime.resumeValue = runtime.enterHandleBlock(h22, handleBlock$$here); +//│ if (runtime.curEffect !== null) { +//│ return runtime.unwind(handleBlock$10, 1, null, null, null, 1, 0, 0) +//│ } +//│ pc = 1; +//│ continue main; +//│ break; +//│ case 1: +//│ tmp13 = runtime.resumeValue; +//│ return tmp13; +//│ break; //│ } -//│ toString() { return runtime.render(this); } -//│ static [definitionMetadata] = ["class", "Handler$h1$"]; -//│ }); +//│ break; +//│ } +//│ }); +//│ str = ""; +//│ scrut = true; +//│ if (scrut === true) { //│ h11 = new Handler$h1$3(); -//│ Handler$h2$perform$1 = function Handler$h2$perform$(Handler$h2$$instance, arg, k) { -//│ let tmp13, tmp14, tmp15, tmp16, tmp17, pc; -//│ if (runtime.resumePc === -1) { -//│ pc = 0; -//│ } else { -//│ let saveOffset; -//│ pc = runtime.resumePc; -//│ runtime.resumePc = -1; -//│ } -//│ main: while (true) { -//│ switch (pc) { -//│ case 0: -//│ tmp13 = str + "B"; -//│ tmp14 = str + tmp13; -//│ str = tmp14; -//│ runtime.resumeValue = runtime.safeCall(k(arg)); -//│ if (runtime.curEffect !== null) { -//│ return runtime.unwind(Handler$h2$perform$1, 1, "RecursiveHandlers.mls:152:7", null, null, 1, 3, Handler$h2$$instance, arg, k, 0) -//│ } -//│ pc = 1; -//│ continue main; -//│ break; -//│ case 1: -//│ tmp15 = runtime.resumeValue; -//│ tmp16 = str + "B"; -//│ tmp17 = str + tmp16; -//│ str = tmp17; -//│ return runtime.Unit; -//│ break; -//│ } -//│ break; -//│ } -//│ }; -//│ Handler$h2$perform1 = function Handler$h2$perform(Handler$h2$$instance, arg) { -//│ return (k) => { -//│ return Handler$h2$perform$1(Handler$h2$$instance, arg, k) -//│ } -//│ }; -//│ (class Handler$h2$2 extends Effect1 { -//│ static { -//│ Handler$h2$3 = this -//│ } -//│ constructor() { -//│ let tmp13; -//│ tmp13 = super(); -//│ if (runtime.curEffect !== null) { -//│ tmp13 = runtime.illegalEffect("in a constructor"); -//│ } -//│ tmp13; -//│ } -//│ perform(arg) { -//│ let Handler$h2$perform$here; -//│ Handler$h2$perform$here = runtime.safeCall(Handler$h2$perform1(this, arg)); -//│ return runtime.mkEffect(this, Handler$h2$perform$here) -//│ } -//│ toString() { return runtime.render(this); } -//│ static [definitionMetadata] = ["class", "Handler$h2$"]; -//│ }); -//│ handleBlock$$2 = function handleBlock$$(h22) { -//│ let tmp13, pc; -//│ if (runtime.resumePc === -1) { -//│ pc = 0; -//│ } else { -//│ let saveOffset; -//│ pc = runtime.resumePc; -//│ runtime.resumePc = -1; -//│ } -//│ main: while (true) { -//│ switch (pc) { -//│ case 0: -//│ runtime.resumeValue = runtime.safeCall(h22.perform(runtime.Unit)); -//│ if (runtime.curEffect !== null) { -//│ return runtime.unwind(handleBlock$$2, 1, "RecursiveHandlers.mls:154:5", null, null, 1, 1, h22, 0) -//│ } -//│ pc = 1; -//│ continue main; -//│ break; -//│ case 1: -//│ tmp13 = runtime.resumeValue; -//│ return runtime.safeCall(h11.perform(runtime.Unit)); -//│ break; -//│ } -//│ break; -//│ } -//│ }; -//│ handleBlock$9 = (undefined, function (h22) { -//│ return () => { -//│ return handleBlock$$2(h22) -//│ } -//│ }); -//│ handleBlock$10 = (undefined, function () { -//│ let h22, tmp13, handleBlock$$here, pc; -//│ if (runtime.resumePc === -1) { -//│ pc = 0; -//│ } else { -//│ let saveOffset; -//│ pc = runtime.resumePc; -//│ runtime.resumePc = -1; -//│ } -//│ main: while (true) { -//│ switch (pc) { -//│ case 0: -//│ h22 = new Handler$h2$3(); -//│ handleBlock$$here = runtime.safeCall(handleBlock$9(h22)); -//│ runtime.resumeValue = runtime.enterHandleBlock(h22, handleBlock$$here); -//│ if (runtime.curEffect !== null) { -//│ return runtime.unwind(handleBlock$10, 1, null, null, null, 1, 0, 0) -//│ } -//│ pc = 1; -//│ continue main; -//│ break; -//│ case 1: -//│ tmp13 = runtime.resumeValue; -//│ return tmp13; -//│ break; -//│ } -//│ break; -//│ } -//│ }); //│ tmp11 = runtime.enterHandleBlock(h11, handleBlock$10); //│ if (runtime.curEffect !== null) { //│ tmp11 = runtime.topLevelEffect(false); diff --git a/hkmc2/shared/src/test/mlscript/handlers/UserThreadsSafe.mls b/hkmc2/shared/src/test/mlscript/handlers/UserThreadsSafe.mls index c0290957d2..8681ec7f28 100644 --- a/hkmc2/shared/src/test/mlscript/handlers/UserThreadsSafe.mls +++ b/hkmc2/shared/src/test/mlscript/handlers/UserThreadsSafe.mls @@ -44,11 +44,11 @@ in //│ > main end //│ > f 2 //│ ═══[RUNTIME ERROR] Error: Unhandled effect Handler$h$ -//│ at f (UserThreadsSafe.mls:18:3) -//│ at while$ (UserThreadsSafe.mls:11:7) +//│ at f (UserThreadsSafe.mls:18:4) +//│ at while (pc=1) //│ at ThreadEffect#drain (pc=1) -//│ at Handler$h$fork$ (UserThreadsSafe.mls:36:5) -//│ at Handler$h$fork$ (UserThreadsSafe.mls:36:5) +//│ at Handler$h$fork (pc=1) +//│ at Handler$h$fork (pc=1) // FIFO @@ -67,10 +67,10 @@ in //│ > main end //│ > f 0 //│ ═══[RUNTIME ERROR] Error: Unhandled effect Handler$h$2 -//│ at f (UserThreadsSafe.mls:18:3) -//│ at while$ (UserThreadsSafe.mls:11:7) +//│ at f (UserThreadsSafe.mls:18:4) +//│ at while (pc=1) //│ at ThreadEffect#drain (pc=1) -//│ at Handler$h$fork$ (UserThreadsSafe.mls:59:5) -//│ at Handler$h$fork$ (UserThreadsSafe.mls:59:5) +//│ at Handler$h$fork (pc=1) +//│ at Handler$h$fork (pc=1) diff --git a/hkmc2/shared/src/test/mlscript/handlers/UserThreadsUnsafe.mls b/hkmc2/shared/src/test/mlscript/handlers/UserThreadsUnsafe.mls index 2b3e28bb4c..4abf8b990a 100644 --- a/hkmc2/shared/src/test/mlscript/handlers/UserThreadsUnsafe.mls +++ b/hkmc2/shared/src/test/mlscript/handlers/UserThreadsUnsafe.mls @@ -47,11 +47,11 @@ in //│ > main end //│ > f 2 //│ ═══[RUNTIME ERROR] Error: Unhandled effect Handler$h$ -//│ at f (UserThreadsUnsafe.mls:13:3) -//│ at while$ (UserThreadsUnsafe.mls:30:5) +//│ at f (UserThreadsUnsafe.mls:13:4) +//│ at while (pc=1) //│ at drain (pc=1) -//│ at Handler$h$fork$ (UserThreadsUnsafe.mls:39:5) -//│ at Handler$h$fork$ (UserThreadsUnsafe.mls:39:5) +//│ at Handler$h$fork (pc=1) +//│ at Handler$h$fork (pc=1) // FIFO @@ -70,10 +70,10 @@ in //│ > main end //│ > f 1 //│ ═══[RUNTIME ERROR] Error: Unhandled effect Handler$h$ -//│ at f (UserThreadsUnsafe.mls:13:3) -//│ at while$ (UserThreadsUnsafe.mls:30:5) +//│ at f (UserThreadsUnsafe.mls:13:4) +//│ at while (pc=1) //│ at drain (pc=1) -//│ at Handler$h$fork$ (UserThreadsUnsafe.mls:62:5) -//│ at Handler$h$fork$ (UserThreadsUnsafe.mls:62:5) +//│ at Handler$h$fork (pc=1) +//│ at Handler$h$fork (pc=1) diff --git a/hkmc2/shared/src/test/mlscript/lifter/ClassInFun.mls b/hkmc2/shared/src/test/mlscript/lifter/ClassInFun.mls index d0f005f9ce..514631c452 100644 --- a/hkmc2/shared/src/test/mlscript/lifter/ClassInFun.mls +++ b/hkmc2/shared/src/test/mlscript/lifter/ClassInFun.mls @@ -90,20 +90,39 @@ fun f() = Good() f().foo() //│ JS (unsanitized): -//│ let Bad1, Good1, f6, tmp5, Bad$, Good$, f$capture3; -//│ Good$ = function Good$(isMut, x, y, z, f$capture4) { -//│ let tmp6, tmp7; +//│ let Bad1, Good1, f6, tmp5, Capture$scope01, Bad$, Good$; +//│ (class Capture$scope0 { +//│ static { +//│ Capture$scope01 = this +//│ } +//│ constructor(z$0, w$1) { +//│ this.w$1 = w$1; +//│ this.z$0 = z$0; +//│ } +//│ toString() { return runtime.render(this); } +//│ static [definitionMetadata] = ["class", "Capture$scope0"]; +//│ }); +//│ Good$ = function Good$(isMut, scope0$cap, x, y) { +//│ let tmp6; //│ if (isMut === true) { //│ tmp6 = new Good1.class(); //│ } else { //│ tmp6 = globalThis.Object.freeze(new Good1.class()); //│ } -//│ tmp7 = tmp6(x, y, z, f$capture4); -//│ return tmp7 +//│ return tmp6(scope0$cap, x, y) +//│ }; +//│ Bad$ = function Bad$(isMut, scope0$cap) { +//│ let tmp6; +//│ if (isMut === true) { +//│ tmp6 = new Bad1.class(); +//│ } else { +//│ tmp6 = globalThis.Object.freeze(new Bad1.class()); +//│ } +//│ return tmp6(scope0$cap) //│ }; //│ Good1 = function Good() { -//│ return (x, y, z, f$capture4) => { -//│ return globalThis.Object.freeze(new Good.class()(x, y, z, f$capture4)); +//│ return (scope0$cap, x, y) => { +//│ return globalThis.Object.freeze(new Good.class()(scope0$cap, x, y)); //│ } //│ }; //│ (class Good { @@ -111,49 +130,27 @@ f().foo() //│ Good1.class = this //│ } //│ constructor() { -//│ return (x, y, z, f$capture4) => { -//│ this.f$capture = f$capture4; +//│ return (scope0$cap, x, y) => { +//│ this.scope0$cap = scope0$cap; //│ this.x = x; //│ this.y = y; -//│ this.z = z; //│ return this; //│ } //│ } -//│ #f$capture; -//│ #x; -//│ #y; -//│ #z; -//│ get f$capture() { return this.#f$capture; } -//│ set f$capture(value) { this.#f$capture = value; } -//│ get x() { return this.#x; } -//│ set x(value) { this.#x = value; } -//│ get y() { return this.#y; } -//│ set y(value) { this.#y = value; } -//│ get z() { return this.#z; } -//│ set z(value) { this.#z = value; } +//│ #Good$cap; //│ foo() { //│ let tmp6, tmp7; -//│ this.z = 100; +//│ this.scope0$cap.z$0 = 100; //│ tmp6 = this.x + this.y; -//│ tmp7 = tmp6 + this.z; -//│ return tmp7 + this.f$capture.w$capture$0 +//│ tmp7 = tmp6 + this.scope0$cap.z$0; +//│ return tmp7 + this.scope0$cap.w$1 //│ } //│ toString() { return runtime.render(this); } //│ static [definitionMetadata] = ["class", "Good", []]; //│ }); -//│ Bad$ = function Bad$(isMut, f$capture4) { -//│ let tmp6, tmp7; -//│ if (isMut === true) { -//│ tmp6 = new Bad1.class(); -//│ } else { -//│ tmp6 = globalThis.Object.freeze(new Bad1.class()); -//│ } -//│ tmp7 = tmp6(f$capture4); -//│ return tmp7 -//│ }; //│ Bad1 = function Bad() { -//│ return (f$capture4) => { -//│ return globalThis.Object.freeze(new Bad.class()(f$capture4)); +//│ return (scope0$cap) => { +//│ return globalThis.Object.freeze(new Bad.class()(scope0$cap)); //│ } //│ }; //│ (class Bad { @@ -161,41 +158,29 @@ f().foo() //│ Bad1.class = this //│ } //│ constructor() { -//│ return (f$capture4) => { -//│ this.f$capture = f$capture4; +//│ return (scope0$cap) => { +//│ this.scope0$cap = scope0$cap; //│ return this; //│ } //│ } -//│ #f$capture; -//│ get f$capture() { return this.#f$capture; } -//│ set f$capture(value) { this.#f$capture = value; } +//│ #Bad$cap; //│ foo() { -//│ this.f$capture.w$capture$0 = 10000; +//│ this.scope0$cap.w$1 = 10000; //│ return runtime.Unit //│ } //│ toString() { return runtime.render(this); } //│ static [definitionMetadata] = ["class", "Bad", []]; //│ }); -//│ (class f$capture2 { -//│ static { -//│ f$capture3 = this -//│ } -//│ constructor(w$capture$0) { -//│ this.w$capture$0 = w$capture$0; -//│ } -//│ toString() { return runtime.render(this); } -//│ static [definitionMetadata] = ["class", "f$capture"]; -//│ }); //│ f6 = function f() { -//│ let x, y, z, w, tmp6, tmp7, capture; -//│ capture = new f$capture3(null); +//│ let x, y, z, w, tmp6, tmp7, scope0$cap; +//│ scope0$cap = new Capture$scope01(z, w); //│ x = 1; //│ y = 10; -//│ z = 10; -//│ capture.w$capture$0 = 1000; -//│ tmp6 = Bad$(false, capture); +//│ scope0$cap.z$0 = 10; +//│ scope0$cap.w$1 = 1000; +//│ tmp6 = Bad$(false, scope0$cap); //│ tmp7 = tmp6.foo(); -//│ return Good$(false, x, y, z, capture) +//│ return Good$(false, scope0$cap, x, y) //│ }; //│ tmp5 = f6(); //│ runtime.safeCall(tmp5.foo()) @@ -244,6 +229,7 @@ f().get() :lift :w +:expect 2 fun f(x) = class Test() with fun get() = @@ -271,3 +257,11 @@ fun test(f) = f(x) A(1) +// Check that the extra params are not printed +:expect "A(2)" +fun f(x) = + data class A(y) with + fun hi = x + y + A(2).toString() +f(0) +//│ = "A(2)" diff --git a/hkmc2/shared/src/test/mlscript/lifter/ClassWithCompanion.mls b/hkmc2/shared/src/test/mlscript/lifter/ClassWithCompanion.mls index 4a38ce4a73..5e3a2396d1 100644 --- a/hkmc2/shared/src/test/mlscript/lifter/ClassWithCompanion.mls +++ b/hkmc2/shared/src/test/mlscript/lifter/ClassWithCompanion.mls @@ -9,7 +9,9 @@ fun foo(x) = module C with val empty = C(123) C.empty -//│ ═══[WARNING] Modules are not yet lifted. +//│ ╔══[WARNING] Modules are not yet lifted. +//│ ║ l.9: module C with +//│ ╙── ^ foo(10).get //│ = [10, 123] @@ -21,7 +23,9 @@ fun foo(x) = module C with val empty = new C C.empty -//│ ═══[WARNING] Modules are not yet lifted. +//│ ╔══[WARNING] Modules are not yet lifted. +//│ ║ l.23: module C with +//│ ╙── ^ foo(10).get //│ = 10 diff --git a/hkmc2/shared/src/test/mlscript/lifter/CompanionsInFun.mls b/hkmc2/shared/src/test/mlscript/lifter/CompanionsInFun.mls index cf57dc78d0..290bfbf624 100644 --- a/hkmc2/shared/src/test/mlscript/lifter/CompanionsInFun.mls +++ b/hkmc2/shared/src/test/mlscript/lifter/CompanionsInFun.mls @@ -9,11 +9,18 @@ fun f(x) = fun get = x module A fun g = new A -//│ ═══[WARNING] Modules are not yet lifted. +//│ ╔══[WARNING] Modules are not yet lifted. +//│ ║ l.10: module A +//│ ╙── ^ //│ JS (unsanitized): -//│ let f, g$; -//│ g$ = function g$(A$member, A1, x) { -//│ return globalThis.Object.freeze(new A$member()) +//│ let g, f, g$; +//│ g$ = function g$(A$) { +//│ return () => { +//│ return g(A$) +//│ } +//│ }; +//│ g = function g(A$) { +//│ return globalThis.Object.freeze(new A$()) //│ }; //│ f = function f(x) { //│ let A1; @@ -22,6 +29,8 @@ fun f(x) = //│ A1 = this //│ } //│ constructor() {} +//│ static #A_mod$cap; +//│ #A$cap; //│ get get() { //│ return x; //│ } @@ -30,6 +39,8 @@ fun f(x) = //│ }); //│ return runtime.Unit //│ }; -//│ ═══[WARNING] Modules are not yet lifted. +//│ ╔══[WARNING] Modules are not yet lifted. +//│ ║ l.10: module A +//│ ╙── ^ diff --git a/hkmc2/shared/src/test/mlscript/lifter/DefnsInClass.mls b/hkmc2/shared/src/test/mlscript/lifter/DefnsInClass.mls index 8adf8005d3..ce0fe5e508 100644 --- a/hkmc2/shared/src/test/mlscript/lifter/DefnsInClass.mls +++ b/hkmc2/shared/src/test/mlscript/lifter/DefnsInClass.mls @@ -1,71 +1,24 @@ :js :lift -:sjs data class A(x) with data class B(y) with fun getB() = x + y fun getA() = B(2).getB() -//│ JS (unsanitized): -//│ let B1, A1, B$; -//│ B$ = function B$(isMut, A$instance, y) { -//│ let tmp, tmp1; -//│ if (isMut === true) { -//│ tmp = new B1.class(y); -//│ } else { -//│ tmp = globalThis.Object.freeze(new B1.class(y)); -//│ } -//│ tmp1 = tmp(A$instance); -//│ return tmp1 -//│ }; -//│ B1 = function B(y) { -//│ return (A$instance) => { -//│ return globalThis.Object.freeze(new B.class(y)(A$instance)); -//│ } -//│ }; -//│ (class B { -//│ static { -//│ B1.class = this -//│ } -//│ constructor(y) { -//│ return (A$instance) => { -//│ this.A$instance = A$instance; -//│ this.y = y; -//│ return this; -//│ } -//│ } -//│ #A$instance; -//│ get A$instance() { return this.#A$instance; } -//│ set A$instance(value) { this.#A$instance = value; } -//│ getB() { -//│ return this.A$instance.x + this.y -//│ } -//│ toString() { return runtime.render(this); } -//│ static [definitionMetadata] = ["class", "B", ["y"]]; -//│ }); -//│ A1 = function A(x) { -//│ return globalThis.Object.freeze(new A.class(x)); -//│ }; -//│ (class A { -//│ static { -//│ A1.class = this -//│ } -//│ constructor(x) { -//│ this.x = x; -//│ } -//│ getA() { -//│ let tmp; -//│ tmp = B$(false, this, 2); -//│ return tmp.getB() -//│ } -//│ toString() { return runtime.render(this); } -//│ static [definitionMetadata] = ["class", "A", ["x"]]; -//│ }); :expect 3 A(1).getA() //│ = 3 +data class A1(x) with + data class B(y) with + fun getB() = x + y + fun getA() = (new B(2)).getB() + +:expect 3 +(new A1(1)).getA() +//│ = 3 + :sjs class A with val x = @@ -73,27 +26,45 @@ class A with g (new A).x() //│ JS (unsanitized): -//│ let g, A3, tmp1, g$; -//│ g$ = function g$(A$instance) { +//│ let g, A3, tmp2; +//│ g = function g() { //│ return 2 //│ }; -//│ g = function g(A$instance) { -//│ return () => { -//│ return g$(A$instance) -//│ } -//│ }; //│ (class A2 { //│ static { //│ A3 = this //│ } //│ constructor() { -//│ let g$here; -//│ g$here = runtime.safeCall(g(this)); -//│ this.x = g$here; +//│ this.x = g; //│ } +//│ #A$cap; //│ toString() { return runtime.render(this); } //│ static [definitionMetadata] = ["class", "A"]; //│ }); -//│ tmp1 = globalThis.Object.freeze(new A3()); -//│ runtime.safeCall(tmp1.x()) +//│ tmp2 = globalThis.Object.freeze(new A3()); +//│ runtime.safeCall(tmp2.x()) +//│ = 2 + +:expect 2 +class A(a) with + let x = 2 + fun f() = + fun g() = x + g() +A(2).f() //│ = 2 + +class A with + let x = 2 + val bruh = () => x + fun f = + fun g() = + set x += 2 + this + g +let a = new A +a.f() +a.f() +a.bruh() +//│ = 6 +//│ a = A { A$cap: Capture$A { x$0: 6 }, bruh: fun } diff --git a/hkmc2/shared/src/test/mlscript/lifter/FunInFun.mls b/hkmc2/shared/src/test/mlscript/lifter/FunInFun.mls index 1cb2e47424..381a2dcc93 100644 --- a/hkmc2/shared/src/test/mlscript/lifter/FunInFun.mls +++ b/hkmc2/shared/src/test/mlscript/lifter/FunInFun.mls @@ -62,22 +62,22 @@ fun f(used1, unused1) = foo(used2) + unused2 f(1, 2) //│ JS (unsanitized): -//│ let g1, f3, g$3; -//│ g$3 = function g$(used1, used2, g_arg) { +//│ let g3, f3, g$3; +//│ g$3 = function g$(used1, used2) { +//│ return (g_arg) => { +//│ return g3(used1, used2, g_arg) +//│ } +//│ }; +//│ g3 = function g(used1, used2, g_arg) { //│ let used3; //│ used3 = 2; //│ return used1 + used2 //│ }; -//│ g1 = function g(used1, used2) { -//│ return (g_arg) => { -//│ return g$3(used1, used2, g_arg) -//│ } -//│ }; //│ f3 = function f(used1, unused1) { //│ let used2, unused2, foo2, tmp1, g$here; //│ used2 = unused1; //│ unused2 = 2; -//│ g$here = runtime.safeCall(g1(used1, used2)); +//│ g$here = g$3(used1, used2); //│ foo2 = g$here; //│ tmp1 = runtime.safeCall(foo2(used2)); //│ return tmp1 + unused2 @@ -92,8 +92,13 @@ fun f(a1, a2, a3, a4, a5, a6) = g f(1,2,3,4,5,6) //│ JS (unsanitized): -//│ let f4, g$4; +//│ let g4, f4, g$4; //│ g$4 = function g$(a1, a2, a3, a4, a5, a6) { +//│ return () => { +//│ return g4(a1, a2, a3, a4, a5, a6) +//│ } +//│ }; +//│ g4 = function g(a1, a2, a3, a4, a5, a6) { //│ let tmp1, tmp2, tmp3, tmp4; //│ tmp1 = a1 + a2; //│ tmp2 = tmp1 + a3; @@ -103,7 +108,7 @@ f(1,2,3,4,5,6) //│ }; //│ f4 = function f(a1, a2, a3, a4, a5, a6) { //│ let tmp1; -//│ tmp1 = g$4(a1, a2, a3, a4, a5, a6); +//│ tmp1 = g4(a1, a2, a3, a4, a5, a6); //│ return tmp1 //│ }; //│ f4(1, 2, 3, 4, 5, 6) @@ -177,29 +182,39 @@ fun f(unused, immutable, mutated) = a + h() + unused f(1, 2, 1000) //│ JS (unsanitized): -//│ let f7, h$2, g$6, f$capture5; -//│ g$6 = function g$(immutable, f$capture6) { -//│ f$capture6.mutated$capture$0 = 2; -//│ return immutable + f$capture6.mutated$capture$0 +//│ let g6, h2, f7, Capture$f3, h$2, g$6; +//│ g$6 = function g$(f$cap, immutable) { +//│ return () => { +//│ return g6(f$cap, immutable) +//│ } //│ }; -//│ h$2 = function h$(f$capture6) { -//│ return f$capture6.mutated$capture$0 +//│ h$2 = function h$(f$cap) { +//│ return () => { +//│ return h2(f$cap) +//│ } +//│ }; +//│ g6 = function g(f$cap, immutable) { +//│ f$cap.mutated$0 = 2; +//│ return immutable + f$cap.mutated$0 +//│ }; +//│ h2 = function h(f$cap) { +//│ return f$cap.mutated$0 //│ }; -//│ (class f$capture4 { +//│ (class Capture$f2 { //│ static { -//│ f$capture5 = this +//│ Capture$f3 = this //│ } -//│ constructor(mutated$capture$0) { -//│ this.mutated$capture$0 = mutated$capture$0; +//│ constructor(mutated$0) { +//│ this.mutated$0 = mutated$0; //│ } //│ toString() { return runtime.render(this); } -//│ static [definitionMetadata] = ["class", "f$capture"]; +//│ static [definitionMetadata] = ["class", "Capture$f"]; //│ }); //│ f7 = function f(unused, immutable, mutated) { -//│ let a1, tmp8, tmp9, capture; -//│ capture = new f$capture5(mutated); -//│ a1 = g$6(immutable, capture); -//│ tmp8 = h$2(capture); +//│ let a1, tmp8, tmp9, f$cap; +//│ f$cap = new Capture$f3(mutated); +//│ a1 = g6(f$cap, immutable); +//│ tmp8 = h2(f$cap); //│ tmp9 = a1 + tmp8; //│ return tmp9 + unused //│ }; @@ -266,41 +281,46 @@ fun g() = f g()(1) //│ JS (unsanitized): -//│ let f14, g6, tmp8, f$1, h$4, g$capture1; -//│ h$4 = function h$(x1, k, g$capture2) { +//│ let h5, f15, g13, tmp8, Capture$scope07, f$1, h$4; +//│ (class Capture$scope06 { +//│ static { +//│ Capture$scope07 = this +//│ } +//│ constructor(y$0) { +//│ this.y$0 = y$0; +//│ } +//│ toString() { return runtime.render(this); } +//│ static [definitionMetadata] = ["class", "Capture$scope0"]; +//│ }); +//│ h$4 = function h$(scope0$cap, x1, k) { +//│ return () => { +//│ return h5(scope0$cap, x1, k) +//│ } +//│ }; +//│ h5 = function h(scope0$cap, x1, k) { //│ k = 5; //│ x1 = 4; -//│ return x1 + g$capture2.y$capture$0 +//│ return x1 + scope0$cap.y$0 //│ }; -//│ f$1 = function f$(g$capture2, x1) { +//│ f$1 = function f$(scope0$cap) { +//│ return (x1) => { +//│ return f15(scope0$cap, x1) +//│ } +//│ }; +//│ f15 = function f(scope0$cap, x1) { //│ let k; //│ k = 4; -//│ g$capture2.y$capture$0 = 2; +//│ scope0$cap.y$0 = 2; //│ return x1 //│ }; -//│ f14 = function f(g$capture2) { -//│ return (x1) => { -//│ return f$1(g$capture2, x1) -//│ } -//│ }; -//│ (class g$capture { -//│ static { -//│ g$capture1 = this -//│ } -//│ constructor(y$capture$0) { -//│ this.y$capture$0 = y$capture$0; -//│ } -//│ toString() { return runtime.render(this); } -//│ static [definitionMetadata] = ["class", "g$capture"]; -//│ }); -//│ g6 = function g() { -//│ let y1, capture, f$here; -//│ capture = new g$capture1(null); -//│ capture.y$capture$0 = 0; -//│ f$here = runtime.safeCall(f14(capture)); +//│ g13 = function g() { +//│ let y1, scope0$cap, f$here; +//│ scope0$cap = new Capture$scope07(y1); +//│ scope0$cap.y$0 = 0; +//│ f$here = f$1(scope0$cap); //│ return f$here //│ }; -//│ tmp8 = g6(); +//│ tmp8 = g13(); //│ runtime.safeCall(tmp8(1)) //│ = 1 @@ -385,30 +405,26 @@ fun f(x) = set y = 2 [g, g] //│ JS (unsanitized): -//│ let g12, f23, g$14; +//│ let g19, f25, g$14; //│ g$14 = function g$(y1) { -//│ return y1 -//│ }; -//│ g12 = function g(y1) { //│ return () => { -//│ return g$14(y1) +//│ return g19(y1) //│ } //│ }; -//│ f23 = function f(x1) { +//│ g19 = function g(y1) { +//│ return y1 +//│ }; +//│ f25 = function f(x1) { //│ let y1, scrut, g$here; //│ scrut = x1 < 0; //│ if (scrut === true) { //│ y1 = 1; -//│ g$here = runtime.safeCall(g12(y1)); +//│ g$here = g$14(y1); //│ return globalThis.Object.freeze([ //│ g$here, //│ g$here //│ ]) -//│ } else { -//│ y1 = 2; -//│ g$here = runtime.safeCall(g12(y1)); -//│ return globalThis.Object.freeze([ g$here, g$here ]) -//│ } +//│ } else { y1 = 2; g$here = g$14(y1); return globalThis.Object.freeze([ g$here, g$here ]) } //│ }; :sjs @@ -418,33 +434,33 @@ fun f(x) = set x += 1 [a, g] //│ JS (unsanitized): -//│ let g13, f24, g$15, f$capture17; -//│ g$15 = function g$(f$capture18) { -//│ return f$capture18.x$capture$0 -//│ }; -//│ g13 = function g(f$capture18) { +//│ let g20, f26, Capture$f7, g$15; +//│ g$15 = function g$(f$cap) { //│ return () => { -//│ return g$15(f$capture18) +//│ return g20(f$cap) //│ } //│ }; -//│ (class f$capture16 { +//│ g20 = function g(f$cap) { +//│ return f$cap.x$0 +//│ }; +//│ (class Capture$f6 { //│ static { -//│ f$capture17 = this +//│ Capture$f7 = this //│ } -//│ constructor(x$capture$0) { -//│ this.x$capture$0 = x$capture$0; +//│ constructor(x$0) { +//│ this.x$0 = x$0; //│ } //│ toString() { return runtime.render(this); } -//│ static [definitionMetadata] = ["class", "f$capture"]; +//│ static [definitionMetadata] = ["class", "Capture$f"]; //│ }); -//│ f24 = function f(x1) { -//│ let a1, tmp11, capture, g$here; -//│ capture = new f$capture17(x1); -//│ g$here = runtime.safeCall(g13(capture)); +//│ f26 = function f(x1) { +//│ let a1, tmp11, f$cap, g$here; +//│ f$cap = new Capture$f7(x1); +//│ g$here = g$15(f$cap); //│ a1 = g$here; -//│ tmp11 = capture.x$capture$0 + 1; -//│ capture.x$capture$0 = tmp11; -//│ g$here = runtime.safeCall(g13(capture)); +//│ tmp11 = f$cap.x$0 + 1; +//│ f$cap.x$0 = tmp11; +//│ g$here = g$15(f$cap); //│ return globalThis.Object.freeze([ a1, g$here ]) //│ }; @@ -465,3 +481,16 @@ fun f(x) = else g g + +// Spread arguments + +:expect [1] +fun f(x) = + fun g(...rest) = + print(x) + rest + let a = g + a(1) +f(2) +//│ > 2 +//│ = [1] diff --git a/hkmc2/shared/src/test/mlscript/lifter/Imports.mls b/hkmc2/shared/src/test/mlscript/lifter/Imports.mls index 9e11d7207d..76d46be07a 100644 --- a/hkmc2/shared/src/test/mlscript/lifter/Imports.mls +++ b/hkmc2/shared/src/test/mlscript/lifter/Imports.mls @@ -15,6 +15,7 @@ module A with //│ constructor() { //│ runtime.Unit; //│ } +//│ static #A_mod$cap; //│ static f(x) { //│ if (x instanceof Option.Some.class) { //│ return true @@ -22,6 +23,7 @@ module A with //│ return false //│ } //│ } +//│ #A$cap; //│ toString() { return runtime.render(this); } //│ static [definitionMetadata] = ["class", "A"]; //│ }); diff --git a/hkmc2/shared/src/test/mlscript/lifter/Loops.mls b/hkmc2/shared/src/test/mlscript/lifter/Loops.mls index cbffcbd5df..2514664105 100644 --- a/hkmc2/shared/src/test/mlscript/lifter/Loops.mls +++ b/hkmc2/shared/src/test/mlscript/lifter/Loops.mls @@ -11,27 +11,11 @@ fun foo() = while x < 5 do set x += 1 fs.push of () => x -//│ /!!!\ Uncaught error: java.lang.AssertionError: assertion failed: loops should be rewritten to functions before scope flattening :expect 5 foo() fs.0() -//│ ╔══[COMPILATION ERROR] No definition found in scope for member 'foo' -//│ ║ l.17: foo() -//│ ║ ^^^ -//│ ╟── which references the symbol introduced here -//│ ║ l.9: fun foo() = -//│ ║ ^^^^^^^^^^^ -//│ ║ l.10: let x = 1 -//│ ║ ^^^^^^^^^^^ -//│ ║ l.11: while x < 5 do -//│ ║ ^^^^^^^^^^^^^^^^ -//│ ║ l.12: set x += 1 -//│ ║ ^^^^^^^^^^^^^^ -//│ ║ l.13: ... (more lines omitted) ... -//│ ╙── ^^^^^^^^^^^^^^^^^ -//│ ═══[RUNTIME ERROR] ReferenceError: foo is not defined -//│ ═══[RUNTIME ERROR] Expected: '5', got: 'undefined' +//│ = 5 let fs = mut [] @@ -43,7 +27,6 @@ fun foo() = let x = i set i += 1 fs.push of () => x -//│ /!!!\ Uncaught error: java.lang.AssertionError: assertion failed: loops should be rewritten to functions before scope flattening // * Note that this works with while loop rewriting // * See [fixme:0] for cause of the issue @@ -51,22 +34,7 @@ fun foo() = :expect 1 foo() fs.0() -//│ ╔══[COMPILATION ERROR] No definition found in scope for member 'foo' -//│ ║ l.52: foo() -//│ ║ ^^^ -//│ ╟── which references the symbol introduced here -//│ ║ l.40: fun foo() = -//│ ║ ^^^^^^^^^^^ -//│ ║ l.41: let i = 1 -//│ ║ ^^^^^^^^^^^ -//│ ║ l.42: while i < 5 do -//│ ║ ^^^^^^^^^^^^^^^^ -//│ ║ l.43: let x = i -//│ ║ ^^^^^^^^^^^^^ -//│ ║ l.44: ... (more lines omitted) ... -//│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -//│ ═══[RUNTIME ERROR] ReferenceError: foo is not defined -//│ ═══[RUNTIME ERROR] Expected: '1', got: 'undefined' +//│ = 1 :sjs @@ -75,25 +43,47 @@ fun foo() = while true do set x += 1 return () => x -//│ /!!!\ Uncaught error: java.lang.AssertionError: assertion failed: loops should be rewritten to functions before scope flattening +//│ JS (unsanitized): +//│ let foo2, lambda2, Capture$scope03, lambda$2; +//│ lambda$2 = (undefined, function (scope0$cap) { +//│ return () => { +//│ return lambda2(scope0$cap) +//│ } +//│ }); +//│ lambda2 = (undefined, function (scope0$cap) { +//│ return scope0$cap.x$0 +//│ }); +//│ (class Capture$scope02 { +//│ static { +//│ Capture$scope03 = this +//│ } +//│ constructor(x$0) { +//│ this.x$0 = x$0; +//│ } +//│ toString() { return runtime.render(this); } +//│ static [definitionMetadata] = ["class", "Capture$scope0"]; +//│ }); +//│ foo2 = function foo() { +//│ let x, tmp2, scope0$cap; +//│ scope0$cap = new Capture$scope03(x); +//│ scope0$cap.x$0 = 1; +//│ lbl: while (true) { +//│ let scrut, tmp3; +//│ scrut = true; +//│ if (scrut === true) { +//│ let lambda$here; +//│ tmp3 = scope0$cap.x$0 + 1; +//│ scope0$cap.x$0 = tmp3; +//│ lambda$here = lambda$2(scope0$cap); +//│ return lambda$here +//│ } else { tmp2 = runtime.Unit; } +//│ break; +//│ } +//│ return tmp2 +//│ }; :expect 2 foo()() -//│ ╔══[COMPILATION ERROR] No definition found in scope for member 'foo' -//│ ║ l.81: foo()() -//│ ║ ^^^ -//│ ╟── which references the symbol introduced here -//│ ║ l.73: fun foo() = -//│ ║ ^^^^^^^^^^^ -//│ ║ l.74: let x = 1 -//│ ║ ^^^^^^^^^^^ -//│ ║ l.75: while true do -//│ ║ ^^^^^^^^^^^^^^^ -//│ ║ l.76: set x += 1 -//│ ║ ^^^^^^^^^^^^^^ -//│ ║ l.77: ... (more lines omitted) ... -//│ ╙── ^^^^^^^^^^^^^ -//│ ═══[RUNTIME ERROR] ReferenceError: foo is not defined -//│ ═══[RUNTIME ERROR] Expected: '2', got: 'undefined' +//│ = 2 diff --git a/hkmc2/shared/src/test/mlscript/lifter/ModulesObjects.mls b/hkmc2/shared/src/test/mlscript/lifter/ModulesObjects.mls index e0db79b6c1..1498bb6da0 100644 --- a/hkmc2/shared/src/test/mlscript/lifter/ModulesObjects.mls +++ b/hkmc2/shared/src/test/mlscript/lifter/ModulesObjects.mls @@ -44,9 +44,7 @@ foo(10) //│ constructor(y) { //│ this.y = y; //│ } -//│ #y; -//│ get y() { return this.#y; } -//│ set y(value) { this.#y = value; } +//│ #M$cap; //│ foo() { //│ this.y = 2; //│ return runtime.Unit @@ -56,6 +54,19 @@ foo(10) //│ }); //│ foo1 = function foo(y) { let tmp; tmp = M$(false, y); return tmp.foo() }; //│ foo1(10) +//│ FAILURE: Unexpected runtime error +//│ FAILURE LOCATION: mkQuery (JSBackendDiffMaker.scala:161) +//│ ═══[RUNTIME ERROR] TypeError: Cannot assign to read only property 'y' of object '[object Object]' +//│ at M2.foo (REPL16:1:540) +//│ at foo (REPL16:1:855) +//│ at REPL16:1:896 +//│ at ContextifyScript.runInThisContext (node:vm:137:12) +//│ at REPLServer.defaultEval (node:repl:562:24) +//│ at bound (node:domain:433:15) +//│ at REPLServer.runBound [as eval] (node:domain:444:12) +//│ at REPLServer.onLine (node:repl:886:12) +//│ at REPLServer.emit (node:events:508:28) +//│ at REPLServer.emit (node:domain:489:12) // * `new mut` (or lack thereof) should be preserved by the lifter @@ -70,9 +81,12 @@ let cc = foo(10) c0 = cc.0 c1 = cc.1 -//│ c0 = C -//│ c1 = C -//│ cc = [C, C] +//│ c0 = C { foo$cap: Capture$foo { y$0: 10 } } +//│ c1 = C { foo$cap: Capture$foo { y$0: 10 } } +//│ cc = [ +//│ C { foo$cap: Capture$foo { y$0: 10 } }, +//│ C { foo$cap: Capture$foo { y$0: 10 } } +//│ ] :re set c0.x = 1 @@ -93,12 +107,13 @@ fun foo(y) = O let o = foo(10) -//│ o = O +//│ o = O { y: 10, O: undefined } :re set o.x = 1 o.x -//│ ═══[RUNTIME ERROR] Error: Access to required field 'x' yielded 'undefined' +//│ = 1 +//│ FAILURE: Unexpected lack of runtime error :sjs fun foo(x, y) = @@ -109,48 +124,33 @@ fun foo(x, y) = fun foo3 = M.foo2() foo3 //│ JS (unsanitized): -//│ let M5, foo4, foo3$, foo$capture3; -//│ foo3$ = function foo3$(M6, x, foo$capture4) { +//│ let M5, foo31, foo4, foo3$; +//│ foo3$ = function foo3$(M6) { +//│ return () => { +//│ return foo31(M6) +//│ } +//│ }; +//│ foo31 = function foo3(M6) { //│ return M6.foo2() //│ }; //│ (class M4 { //│ static { //│ M5 = this //│ } -//│ constructor(x, foo$capture4) { -//│ this.foo$capture = foo$capture4; +//│ constructor(x, y, M6) { //│ this.x = x; +//│ this.y = y; +//│ this.M = M6; //│ } -//│ #foo$capture; -//│ #x; -//│ get foo$capture() { return this.#foo$capture; } -//│ set foo$capture(value) { this.#foo$capture = value; } -//│ get x() { return this.#x; } -//│ set x(value) { this.#x = value; } +//│ #M$cap; //│ foo2() { -//│ this.foo$capture.y$capture$0 = 2; -//│ return this.x + this.foo$capture.y$capture$0 +//│ this.y = 2; +//│ return this.x + this.y //│ } //│ toString() { return runtime.render(this); } //│ static [definitionMetadata] = ["class", "M"]; //│ }); -//│ (class foo$capture2 { -//│ static { -//│ foo$capture3 = this -//│ } -//│ constructor(y$capture$0) { -//│ this.y$capture$0 = y$capture$0; -//│ } -//│ toString() { return runtime.render(this); } -//│ static [definitionMetadata] = ["class", "foo$capture"]; -//│ }); -//│ foo4 = function foo(x, y) { -//│ let tmp, M$1, capture; -//│ capture = new foo$capture3(y); -//│ M$1 = globalThis.Object.freeze(new M5(x, capture)); -//│ tmp = foo3$(M$1, x, capture); -//│ return tmp -//│ }; +//│ foo4 = function foo(x, y) { let tmp, M$1; M$1 = new M5(x, y, M$1); tmp = foo31(M$1); return tmp }; :expect 12 foo(10, 0) @@ -160,7 +160,9 @@ foo(10, 0) data class A(x) with object M with fun getB() = x - fun getA() = M.getB() + fun getA() = + fun bar = M.getB() + bar :expect 2 A(2).getA() @@ -204,6 +206,20 @@ object M with M.f()() //│ = 2 +// unlifted objects +:lift +:w +:expect 2 +fun f = + object A with + fun a = 2 + 0 is A + fun foo = A + foo +f.a +//│ ═══[WARNING] Cannot yet lift class/module `A` as it is used in an instance check. +//│ = 2 + :sjs module M with class A() with @@ -211,7 +227,7 @@ module M with val x = if A() is A then 2 else 3 M.A().get //│ JS (unsanitized): -//│ let M17, tmp3; +//│ let M17, tmp4; //│ (class M16 { //│ static { //│ M17 = this @@ -219,16 +235,18 @@ M.A().get //│ constructor() { //│ runtime.Unit; //│ } +//│ static #M_mod$cap; //│ static { -//│ let scrut, tmp4; +//│ let scrut, tmp5; //│ this.A = function A() { //│ return globalThis.Object.freeze(new A.class()); //│ }; -//│ (class A5 { +//│ (class A6 { //│ static { //│ M16.A.class = this //│ } //│ constructor() {} +//│ #A$cap; //│ get get() { //│ return M17.x + M16.x; //│ } @@ -237,17 +255,18 @@ M.A().get //│ }); //│ scrut = M16.A(); //│ if (scrut instanceof M16.A.class) { -//│ tmp4 = 2; +//│ tmp5 = 2; //│ } else { -//│ tmp4 = 3; +//│ tmp5 = 3; //│ } -//│ this.x = tmp4; +//│ this.x = tmp5; //│ } +//│ #M$cap; //│ toString() { return runtime.render(this); } //│ static [definitionMetadata] = ["class", "M"]; //│ }); -//│ tmp3 = M17.A(); -//│ tmp3.get +//│ tmp4 = M17.A(); +//│ tmp4.get //│ = 4 module M with @@ -315,14 +334,12 @@ A.x //│ = 2 // private fields in modules -:fixme module A with let x = 1 fun mtd = fun nested() = x nested() A.mtd -//│ ═══[COMPILATION ERROR] Uses of private fields cannot yet be lifted. -//│ ═══[RUNTIME ERROR] Error: MLscript call unexpectedly returned `undefined`, the forbidden value. +//│ = 1 diff --git a/hkmc2/shared/src/test/mlscript/lifter/Mutation.mls b/hkmc2/shared/src/test/mlscript/lifter/Mutation.mls index e7eff12d2f..2bea555fc1 100644 --- a/hkmc2/shared/src/test/mlscript/lifter/Mutation.mls +++ b/hkmc2/shared/src/test/mlscript/lifter/Mutation.mls @@ -12,32 +12,32 @@ fun foo() = xs.push(bar) set x = 2 //│ JS (unsanitized): -//│ let bar, foo, bar$, foo$capture1; -//│ bar$ = function bar$(foo$capture2) { -//│ return foo$capture2.x$capture$0 -//│ }; -//│ bar = function bar(foo$capture2) { -//│ return () => { -//│ return bar$(foo$capture2) -//│ } -//│ }; -//│ (class foo$capture { +//│ let bar, foo, Capture$scope01, bar$; +//│ (class Capture$scope0 { //│ static { -//│ foo$capture1 = this +//│ Capture$scope01 = this //│ } -//│ constructor(x$capture$0) { -//│ this.x$capture$0 = x$capture$0; +//│ constructor(x$0) { +//│ this.x$0 = x$0; //│ } //│ toString() { return runtime.render(this); } -//│ static [definitionMetadata] = ["class", "foo$capture"]; +//│ static [definitionMetadata] = ["class", "Capture$scope0"]; //│ }); +//│ bar$ = function bar$(scope0$cap) { +//│ return () => { +//│ return bar(scope0$cap) +//│ } +//│ }; +//│ bar = function bar(scope0$cap) { +//│ return scope0$cap.x$0 +//│ }; //│ foo = function foo() { -//│ let x, tmp, capture, bar$here; -//│ capture = new foo$capture1(null); -//│ capture.x$capture$0 = 1; -//│ bar$here = runtime.safeCall(bar(capture)); +//│ let x, tmp, scope0$cap, bar$here; +//│ scope0$cap = new Capture$scope01(x); +//│ scope0$cap.x$0 = 1; +//│ bar$here = bar$(scope0$cap); //│ tmp = runtime.safeCall(xs.push(bar$here)); -//│ capture.x$capture$0 = 2; +//│ scope0$cap.x$0 = 2; //│ return runtime.Unit //│ }; diff --git a/hkmc2/shared/src/test/mlscript/staging/Syntax.mls b/hkmc2/shared/src/test/mlscript/staging/Syntax.mls index 94accabd06..aa471678e9 100644 --- a/hkmc2/shared/src/test/mlscript/staging/Syntax.mls +++ b/hkmc2/shared/src/test/mlscript/staging/Syntax.mls @@ -21,8 +21,6 @@ staged fun f() = 0 :staging :w staged module A -//│ ╔══[WARNING] `staged` keyword doesn't do anything currently. -//│ ║ l.23: staged module A -//│ ╙── ^^^^^^^^ //│ Pretty Lowered: -//│ define staged class A in set block$res = undefined in end +//│ define class A in set block$res = undefined in end +//│ FAILURE: Unexpected lack of warnings diff --git a/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls b/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls index c19e96a031..ebbdb9f985 100644 --- a/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls +++ b/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls @@ -165,12 +165,8 @@ module A with fun g(x) = if x < 0 then 0 else @tailcall f(x) @tailcall g(x - 1) A.f(10000) -//│ ╔══[ERROR] This tail call exits the current scope is not optimized. -//│ ║ l.165: fun g(x) = if x < 0 then 0 else @tailcall f(x) -//│ ╙── ^^^ -//│ ╔══[ERROR] This tail call exits the current scope is not optimized. -//│ ║ l.166: @tailcall g(x - 1) -//│ ╙── ^^^^^^^^ +//│ ═══[ERROR] This tail call exits the current scope is not optimized. +//│ ═══[ERROR] This tail call exits the current scope is not optimized. //│ ═══[RUNTIME ERROR] RangeError: Maximum call stack size exceeded // These calls are represented as field selections and don't yet have the explicitTailCall parameter. @@ -196,12 +192,12 @@ class A with fun g(x) = if x == 0 then 1 else @tailcall f(x - 1) (new A).f(10) //│ ╔══[ERROR] Calls from class methods cannot yet be marked @tailcall. -//│ ║ l.195: @tailrec fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) +//│ ║ l.191: @tailrec fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) //│ ╙── ^^^^^^^^ //│ ╔══[ERROR] Class methods may not yet be marked @tailrec. -//│ ║ l.195: @tailrec fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) +//│ ║ l.191: @tailrec fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) //│ ╙── ^ //│ ╔══[ERROR] Calls from class methods cannot yet be marked @tailcall. -//│ ║ l.196: fun g(x) = if x == 0 then 1 else @tailcall f(x - 1) +//│ ║ l.192: fun g(x) = if x == 0 then 1 else @tailcall f(x - 1) //│ ╙── ^^^^^^^^ //│ = 0