From d19c72f3f201cb438f41b5982e1362b397b1cbf9 Mon Sep 17 00:00:00 2001 From: Matvey Aksenov Date: Mon, 12 Jan 2026 21:54:25 +0000 Subject: [PATCH 01/13] very basic typing, extremely wip --- src/T/Embed.hs | 4 +- src/T/Error.hs | 2 +- src/T/Exp.hs | 13 +-- src/T/Parse.hs | 3 +- src/T/Parse/Macro.hs | 3 +- src/T/Prelude.hs | 32 +++++- src/T/Render.hs | 20 ++-- src/T/Stdlib/Macro.hs | 7 +- src/T/Stdlib/Op.hs | 4 +- src/T/Type.hs | 236 +++++++++++++++++++++++++++++++++++++++--- src/T/Value.hs | 8 +- t.cabal | 2 +- test/T/RenderSpec.hs | 10 +- 13 files changed, 285 insertions(+), 59 deletions(-) diff --git a/src/T/Embed.hs b/src/T/Embed.hs index 738099a..ecd666b 100644 --- a/src/T/Embed.hs +++ b/src/T/Embed.hs @@ -108,11 +108,11 @@ instance (k ~ Name, v ~ Value) => Eject (HashMap k v) where Record o -> pure o value -> - Left (TypeError (varE name) Type.Record (typeOf value) (sexp value)) + Left (TypeError (varE name) (Type.Record (error "fields")) (typeOf value) (sexp value)) instance Eject a => Eject [a] where eject name = \case Array xs -> map toList (traverse (eject name) xs) value -> - Left (TypeError (varE name) Type.Array (typeOf value) (sexp value)) + Left (TypeError (varE name) (Type.Array (error "element")) (typeOf value) (sexp value)) diff --git a/src/T/Error.hs b/src/T/Error.hs index 3a887e0..1e90499 100644 --- a/src/T/Error.hs +++ b/src/T/Error.hs @@ -12,7 +12,7 @@ import Prettyprinter.Render.Terminal qualified as PP (Color(..), color) import Text.Trifecta qualified as Tri import Text.Trifecta.Delta qualified as Tri -import T.Exp (Cofree((:<)), Exp, ExpF(..), (:+)(..), Ann) +import T.Exp (Exp, ExpF(..), (:+)(..), Ann) import T.Exp.Ann (emptyAnn) import T.Name (Name) import T.Prelude diff --git a/src/T/Exp.hs b/src/T/Exp.hs index aeef2bc..2204a96 100644 --- a/src/T/Exp.hs +++ b/src/T/Exp.hs @@ -1,4 +1,6 @@ +{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GeneralisedNewtypeDeriving #-} {-# LANGUAGE NamedFieldPuns #-} @@ -7,7 +9,6 @@ {-# LANGUAGE UndecidableInstances #-} module T.Exp ( Exp - , Cofree(..) , ExpF(..) , Literal(..) , (:+)(..) @@ -40,14 +41,6 @@ import T.SExp (sexp) import T.SExp qualified as SExp -data Cofree f a = a :< f (Cofree f a) - -deriving instance (Show a, Show (f (Cofree f a))) => Show (Cofree f a) - -instance Eq1 f => Eq (Cofree f a) where - (_ :< f) == (_ :< g) = - eq1 f g - type Exp = Cofree ExpF Ann data ExpF a @@ -63,7 +56,7 @@ data ExpF a -- ^ array index access: xs[0] | Key a (Ann :+ Name) -- ^ record property access: foo.bar - deriving (Show, Eq, Generic1) + deriving (Show, Eq, Generic1, Functor, Foldable, Traversable) instance SExp.To Exp where sexp = \case diff --git a/src/T/Parse.hs b/src/T/Parse.hs index b7fa150..59aaf7d 100644 --- a/src/T/Parse.hs +++ b/src/T/Parse.hs @@ -26,8 +26,7 @@ import Text.Parser.Token.Style (emptyOps) import Text.Regex.PCRE.Light qualified as Pcre import T.Exp - ( Cofree(..) - , Exp + ( Exp , ExpF(..) , Literal(..) , (:+)(..) diff --git a/src/T/Parse/Macro.hs b/src/T/Parse/Macro.hs index 50f007c..649d0fc 100644 --- a/src/T/Parse/Macro.hs +++ b/src/T/Parse/Macro.hs @@ -11,8 +11,7 @@ module T.Parse.Macro import Data.List qualified as List import T.Exp - ( Cofree(..) - , Exp + ( Exp , ExpF(..) , Ann , (:+)(..) diff --git a/src/T/Prelude.hs b/src/T/Prelude.hs index e57951a..b210a65 100644 --- a/src/T/Prelude.hs +++ b/src/T/Prelude.hs @@ -1,10 +1,15 @@ +{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE NoImplicitPrelude #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE UndecidableInstances #-} module T.Prelude ( Alternative(..) , Applicative(..) , Eq(..) - , Fractional , Eq1(..) + , Foldable(..) + , Fractional + , Functor , Generic1 , Hashable , Integral @@ -17,14 +22,17 @@ module T.Prelude , Ord(..) , Semigroup(..) , Show(..) + , Traversable , Bool(..) , ByteString , Char + , Cofree(..) , Double , Either(..) , FilePath , HashMap + , Identity , Int , IO , Maybe(..) @@ -38,10 +46,12 @@ module T.Prelude , ($) , (<$) , (&&) + , (||) , (+) , (-) , (*) , (/) + , any , asum , bool , concatMap @@ -67,8 +77,10 @@ module T.Prelude , notElem , otherwise , reverse + , runIdentity , second , seq + , sequence , toList , traverse , traverse_ @@ -80,18 +92,22 @@ module T.Prelude ) where import Control.Applicative (Alternative(..)) -import Control.Monad ((<=<), foldM_, when, unless) +import Control.Monad ((<=<), foldM_, sequence, when, unless) import Control.Monad.IO.Class (MonadIO(..)) import Data.Bool (bool) import Data.Bifunctor (first, second) import Data.ByteString (ByteString) import Data.Foldable - ( asum + ( Foldable + , any + , asum , for_ , toList , traverse_ ) import Data.Functor.Classes (Eq1(..), eq1) +import Data.Functor.Identity (Identity, runIdentity) +import Data.Traversable (Traversable) import Data.HashMap.Strict (HashMap) import Data.Hashable (Hashable) import Data.List (foldl') @@ -129,6 +145,7 @@ import Prelude , ($) , (<$) , (&&) + , (||) , (+) , (-) , (*) @@ -155,6 +172,15 @@ import Prelude , zipWith ) +data Cofree f a = a :< f (Cofree f a) + deriving (Functor) + +deriving instance (Show a, Show (f (Cofree f a))) => Show (Cofree f a) + +instance Eq1 f => Eq (Cofree f a) where + (_ :< f) == (_ :< g) = + eq1 f g + map :: Functor f => (a -> b) -> f a -> f b map = fmap diff --git a/src/T/Render.hs b/src/T/Render.hs index c756529..5b389db 100644 --- a/src/T/Render.hs +++ b/src/T/Render.hs @@ -26,7 +26,7 @@ import Data.Text.Lazy.Builder qualified as Builder import Data.HashMap.Strict qualified as HashMap import Data.Vector ((!?), (//)) -import T.Exp (Cofree(..), Exp, ExpF(..), Literal(..), (:+)(..), Ann) +import T.Exp (Exp, ExpF(..), Literal(..), (:+)(..), Ann) import T.Exp.Ann (emptyAnn) import T.Error (Error(..), Warning(..)) import T.Name (Name(..)) @@ -131,7 +131,7 @@ renderTmpl = \case (HashMap.toList o)) pure (bool (Just xs) Nothing (List.null xs)) _ -> - throwError (TypeError exp Type.Iterable (Value.typeOf value) (sexp value)) + throwError (TypeError exp (error "Type.Iterable") (Value.typeOf value) (sexp value)) case itemsQ of Nothing -> maybe (pure ()) renderTmpl elseTmpl @@ -169,7 +169,7 @@ renderExp exp = do Value.String str -> pure str _ -> - throwError (TypeError exp Type.Renderable (Value.typeOf value) (sexp value)) + throwError (TypeError exp (error "Type.Renderable") (Value.typeOf value) (sexp value)) evalExp :: (Ctx m, MonadError Error m) => Exp -> m Value evalExp = \case @@ -243,7 +243,7 @@ enforceArray exp = do Value.Array xs -> pure xs _ -> - throwError (TypeError exp Type.Array (Value.typeOf v) (sexp v)) + throwError (TypeError exp (error "Type.Array") (Value.typeOf v) (sexp v)) enforceRecord :: (Ctx m, MonadError Error m) => Exp -> m (HashMap Name Value) enforceRecord exp = do @@ -252,7 +252,7 @@ enforceRecord exp = do Value.Record r -> pure r _ -> - throwError (TypeError exp Type.Record (Value.typeOf v) (sexp v)) + throwError (TypeError exp (Type.Record mempty) (Value.typeOf v) (sexp v)) evalApp :: (Ctx m, MonadError Error m) @@ -273,7 +273,7 @@ evalApp name@(ann0 :+ _) = go g exps -- in every other case something went wrong :-( go v _ = - throwError (TypeError (ann0 :< Var name) Type.Fun (Value.typeOf v) (sexp v)) + throwError (TypeError (ann0 :< Var name) (error "Type.Fun") (Value.typeOf v) (sexp v)) data Path = Path { var :: Ann :+ Name @@ -320,7 +320,7 @@ evalLValue = Just v -> (Right v, path) Right v -> - throwError (TypeError exp Type.Record (Value.typeOf v) (sexp v)) + throwError (TypeError exp (error "Type.Array") (Value.typeOf v) (sexp v)) Left name -> throwError (NotInScope name) _ :< Key exp0 key@(_ :+ key0) -> do @@ -336,7 +336,7 @@ evalLValue = Just v -> (Right v, path) Right v -> - throwError (TypeError exp Type.Record (Value.typeOf v) (sexp v)) + throwError (TypeError exp (Type.Record mempty) (Value.typeOf v) (sexp v)) Left name -> throwError (NotInScope name) @@ -395,7 +395,7 @@ insertVar Path {var = (ann :+ name), lookups} v = do _ -> throwError (MissingProperty (ann0 :< Lit Null) (sexp r) (sexp key)) go v0 (K (ann0 :+ _key) : _path) = - throwError (TypeError (ann0 :< Lit Null) Type.Record (Value.typeOf v0) (sexp v0)) + throwError (TypeError (ann0 :< Lit Null) (Type.Record mempty) (Value.typeOf v0) (sexp v0)) go (Value.Array xs) (I (ann0 :+ idx) : path) = -- this is pretty similar to records except the lack of the aforementioned special -- treatment. @@ -406,7 +406,7 @@ insertVar Path {var = (ann :+ name), lookups} v = do Nothing -> throwError (OutOfBounds (ann0 :< Lit Null) (sexp xs) (sexp idx)) go v0 (I (ann0 :+ _idx) : _path) = - throwError (TypeError (ann0 :< Lit Null) Type.Record (Value.typeOf v0) (sexp v0)) + throwError (TypeError (ann0 :< Lit Null) (Type.Record mempty) (Value.typeOf v0) (sexp v0)) go _v0 [] = pure v diff --git a/src/T/Stdlib/Macro.hs b/src/T/Stdlib/Macro.hs index 190c372..0fc64c0 100644 --- a/src/T/Stdlib/Macro.hs +++ b/src/T/Stdlib/Macro.hs @@ -19,8 +19,7 @@ import Data.List qualified as List import Data.Map.Strict qualified as Map import T.Exp - ( Cofree(..) - , ExpF(..) + ( ExpF(..) , appE , appE_ , ifE @@ -93,8 +92,8 @@ or _ann args = -- function application macro: -- --- {{ x || f }} -> {{ f(x) }} --- {{ y || f(x) }} -> {{ f(x, y) }} +-- {{ x | f }} -> {{ f(x) }} +-- {{ y | f(x) }} -> {{ f(x, y) }} legacyApp :: Expansion legacyApp _ann = \case [expl, annf :< Var name] -> diff --git a/src/T/Stdlib/Op.hs b/src/T/Stdlib/Op.hs index 77ec7d7..e254d24 100644 --- a/src/T/Stdlib/Op.hs +++ b/src/T/Stdlib/Op.hs @@ -90,7 +90,7 @@ combineNumbers intOp doubleOp name = ann :+ n -> Left (TypeError (varE (ann :+ name)) Type.Double (typeOf n) (sexp n)) ann :+ n -> - Left (TypeError (varE (ann :+ name)) Type.Number (typeOf n) (sexp n)) + Left (TypeError (varE (ann :+ name)) (error "Type.Number") (typeOf n) (sexp n)) add :: Name -> Value add = @@ -124,7 +124,7 @@ predicateNumbers intOp doubleOp name = ann :+ n -> Left (TypeError (varE (ann :+ name)) Type.Double (typeOf n) (sexp n)) ann :+ n -> - Left (TypeError (varE (ann :+ name)) Type.Number (typeOf n) (sexp n)) + Left (TypeError (varE (ann :+ name)) (error "Type.Number") (typeOf n) (sexp n)) lt :: Name -> Value lt = diff --git a/src/T/Type.hs b/src/T/Type.hs index c077d0f..553b713 100644 --- a/src/T/Type.hs +++ b/src/T/Type.hs @@ -1,27 +1,237 @@ +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedRecordDot #-} module T.Type ( Type(..) + , infer + , TypeError ) where +import Control.Monad (foldM) +import Control.Monad.Reader (ReaderT, MonadReader, runReaderT, ask) +import Control.Monad.State (StateT, MonadState, runStateT, gets, modify) +import Control.Monad.Except (ExceptT, MonadError, runExceptT, throwError) +import Data.List.NonEmpty qualified as NonEmpty +import Data.HashMap.Strict qualified as HashMap + +import T.Exp (Exp, (:+)(..)) +import T.Exp qualified as Exp +import T.Name (Name) import T.Prelude data Type - -- real types first - = Null + = Unit | Bool | Int | Double | String | Regexp - | Array - | Record - | Fun - -- then pseudo-types - -- used in numeric operators - | Number - -- used in `for` and stdlib functions, polymorphic over - -- all containers - | Iterable - -- used in rendering values into text - | Renderable + | Array Type + | Record (HashMap Name Type) + | Fun (NonEmpty Type) Type + | Var Int deriving (Show, Eq) + +type TypedExp = Cofree Exp.ExpF (Exp.Ann, Type) + +newtype InferenceT m a = InferenceT (ReaderT Γ (StateT Σ (ExceptT TypeError m)) a) + deriving (Functor, Applicative, Monad, MonadReader Γ, MonadState Σ, MonadError TypeError) + +type Γ = HashMap Name Type + +data Σ = Σ + { subst :: Subst + , counter :: Int + } deriving (Show, Eq) + +type Subst = HashMap Int Type + +emptyΣ :: Σ +emptyΣ = Σ + { subst = mempty + , counter = 0 + } + +data TypeError + = MissingKey Name + | MissingVar Name + | NotARecord Type + | TypeMismatch Type Type + | OccursCheck Int Type + +runInferenceT :: Γ -> Σ -> InferenceT m a -> m (Either TypeError (a, Σ)) +runInferenceT ctx subst (InferenceT m) = + runExceptT (runStateT (runReaderT m ctx) subst) + +infer :: Exp -> Either TypeError TypedExp +infer exp = do + (te, Σ {subst}) <- runIdentity (runInferenceT mempty emptyΣ (inferExp exp)) + pure (finalize subst te) + +inferExp :: Monad m => Exp -> InferenceT m TypedExp +inferExp (ann :< e) = do + te <- traverse inferExp e + inferredType <- case te of + Exp.Lit l -> + inferLiteral l + Exp.Var (_ann :+ name) -> + lookupCtx name + Exp.If b t f -> do + _ <- unify (extractType b) Bool + unify (extractType t) (extractType f) + Exp.App (_ann :+ name) args -> + checkApp name args + Exp.Idx arr idx -> + checkIdx arr idx + Exp.Key r (_ann :+ name) -> + checkKey r name + pure ((ann, inferredType) :< te) + +lookupCtx :: Monad m => Name -> InferenceT m Type +lookupCtx name = do + ctx <- ask + maybe (throwError (MissingVar name)) pure (HashMap.lookup name ctx) + +extractType :: TypedExp -> Type +extractType ((_ann, t) :< _e) = t + +unify :: Monad m => Type -> Type -> InferenceT m Type +unify t1 t2 = do + s <- gets (.subst) + let + t1' = + applySubst s t1 + t2' = + applySubst s t2 + case (t1', t2') of + (a, b) + | a == b -> + pure a + (Var n, t) -> do + when (occurs n t s) $ + throwError (OccursCheck n t) + extendSubst n t + pure t + (t, Var n) -> do + when (occurs n t s) $ + throwError (OccursCheck n t) + extendSubst n t + pure t + (Array a, Array b) -> + map Array (unify a b) + (Record m1, Record m2) -> do + map Record (sequence (HashMap.intersectionWith unify m1 m2)) + (Fun args1 ret1, Fun args2 ret2) -> + liftA2 Fun (sequence (NonEmpty.zipWith unify args1 args2)) (unify ret1 ret2) + _ -> + throwError (TypeMismatch t1' t2') + +occurs :: Int -> Type -> Subst -> Bool +occurs n t subst = + case applySubst subst t of + Var m -> + n == m + Array arr -> + occurs n arr subst + Record r -> + any (\t' -> occurs n t' subst) r + Fun args r -> + any (\a -> occurs n a subst) args || occurs n r subst + _ -> + False + +applySubst :: Subst -> Type -> Type +applySubst subst t = + case t of + Var n -> + case HashMap.lookup n subst of + Just nt -> + applySubst subst nt + Nothing -> + Var n + Array arr -> + Array (applySubst subst arr) + Record r -> + Record (map (applySubst subst) r) + Fun args r -> + Fun (map (applySubst subst) args) (applySubst subst r) + Unit -> + Unit + Bool -> + Bool + Int -> + Int + Double -> + Double + String -> + String + Regexp -> + Regexp + +extendSubst :: Monad m => Int -> Type -> InferenceT m () +extendSubst n t = + modify (\s -> s { subst = HashMap.insert n t s.subst }) + +checkApp :: Monad m => Name -> NonEmpty TypedExp -> InferenceT m Type +checkApp name args = do + ft <- lookupCtx name + r <- freshVar + _ <- unify ft (Fun (map extractType args) r) + pure r + +checkIdx :: Monad m => TypedExp -> TypedExp -> InferenceT m Type +checkIdx arr idx = do + e <- freshVar + _ <- unify (extractType arr) (Array e) + _ <- unify (extractType idx) Int + pure e + +freshVar :: Monad m => InferenceT m Type +freshVar = do + n <- gets (.counter) + modify (\s -> s { counter = s.counter + 1 }) + pure (Var n) + +checkKey :: Monad m => TypedExp -> Name -> InferenceT m Type +checkKey r name = do + case extractType r of + Record fields -> + case HashMap.lookup name fields of + Just t -> + pure t + Nothing -> + throwError (MissingKey name) + Var n -> do + v <- freshVar + _ <- unify (Var n) (Record (HashMap.singleton name v)) + pure v + _ -> + throwError (NotARecord (extractType r)) + +inferLiteral :: Monad m => Exp.Literal -> InferenceT m Type +inferLiteral = \case + Exp.Null -> pure Unit + Exp.Bool _ -> pure Bool + Exp.Int _ -> pure Int + Exp.Double _ -> pure Double + Exp.String _ -> pure String + Exp.Regexp _ -> pure Regexp + Exp.Array xs -> do + ys <- traverse inferExp xs + y <- generalize (toList (map extractType ys)) + pure (Array y) + Exp.Record r -> do + ts <- traverse inferExp r + pure (Record (map extractType ts)) + +generalize :: Monad m => [Type] -> InferenceT m Type +generalize = \case + [] -> freshVar + t : ts -> + foldM unify t ts + +finalize :: Subst -> TypedExp -> TypedExp +finalize subst cofree = + map (\(ann, t) -> (ann, applySubst subst t)) cofree diff --git a/src/T/Value.hs b/src/T/Value.hs index ff13f5e..b9cd4ea 100644 --- a/src/T/Value.hs +++ b/src/T/Value.hs @@ -95,15 +95,15 @@ displayWith f = typeOf :: Value -> Type typeOf = \case - Null -> Type.Null + Null -> Type.Unit Bool _ -> Type.Bool Int _ -> Type.Int Double _ -> Type.Double String _ -> Type.String Regexp _ -> Type.Regexp - Array _ -> Type.Array - Record _ -> Type.Record - Lam _ -> Type.Fun + Array _ -> Type.Array (error "element") + Record _ -> Type.Record (error "fields") + Lam _ -> Type.Fun (error "args") (error "result") embedAeson :: Aeson.Value -> Value embedAeson = \case diff --git a/t.cabal b/t.cabal index 1332462..3dda618 100644 --- a/t.cabal +++ b/t.cabal @@ -1,6 +1,6 @@ cabal-version: 2.0 --- This file has been generated from package.yaml by hpack version 0.36.0. +-- This file has been generated from package.yaml by hpack version 0.38.0. -- -- see: https://github.com/sol/hpack diff --git a/test/T/RenderSpec.hs b/test/T/RenderSpec.hs index 1219d12..21a2c70 100644 --- a/test/T/RenderSpec.hs +++ b/test/T/RenderSpec.hs @@ -64,7 +64,7 @@ spec = r_ "{{ [1,2,3][0] }}" `shouldRender` "1" r_ "{{ [[1,2],[3]][0][1] }}" `shouldRender` "2" r_ "{{ 4[0] }}" `shouldRaise` - TypeError (litE_ (Int 4)) Type.Array Type.Int "4" + error "TypeError" (litE_ (Int 4)) Type.Array Type.Int "4" r_ "{{ [1,2,3][\"foo\"] }}" `shouldRaise` TypeError (litE_ (String "foo")) Type.Int Type.String "\"foo\"" r_ "{{ [1,2,3][-1] }}" `shouldRaise` OutOfBounds (int (-1)) (sexp (array [int 1, int 2, int 3])) "-1" @@ -75,7 +75,7 @@ spec = r_ "{{ {foo: [1,2,3]}.foo[0] }}" `shouldRender` "1" r_ "{{ {foo: [1,{bar: 7},3]}.foo[1].bar }}" `shouldRender` "7" r_ "{{ 4.foo }}" `shouldRaise` - TypeError (litE_ (Int 4)) Type.Record Type.Int "4" + error "TypeError" (litE_ (Int 4)) "Type.Record" Type.Int "4" r_ "{{ {}.foo }}" `shouldRaise` MissingProperty (record mempty) (sexp (record mempty)) "foo" context "line blocks" $ @@ -266,13 +266,13 @@ spec = r_ "{{ \"Foo\" =~ /foo/i }}" `shouldRender` "true" it "not-iterable" $ - r_ "{% for x in 4 %}{% endfor %}" `shouldRaise` TypeError (litE_ (Int 4)) Type.Iterable Type.Int "4" + r_ "{% for x in 4 %}{% endfor %}" `shouldRaise` error "TypeError" (litE_ (Int 4)) "Type.Iterable" Type.Int "4" it "not-renderable" $ - r_ "{{ [] }}" `shouldRaise` TypeError (array []) Type.Renderable Type.Array (sexp (array [])) + r_ "{{ [] }}" `shouldRaise` error "TypeError" (array []) error "Type.Renderable" "Type.Array" (sexp (array [])) it "not-a-function" $ - rWith [aesonQQ|{f: "foo"}|] "{{ f(4) }}" `shouldRaise` TypeError (varE "f") Type.Fun Type.String "\"foo\"" + rWith [aesonQQ|{f: "foo"}|] "{{ f(4) }}" `shouldRaise` error "TypeError" (varE "f") "Type.Fun" Type.String "\"foo\"" it "type errors" $ r_ "{{ bool01(\"foo\") }}" `shouldRaise` TypeError (varE "bool01") Type.Bool Type.String "\"foo\"" From b5be6da148455d45bdacab6fe225a5b81170f110 Mon Sep 17 00:00:00 2001 From: Matvey Aksenov Date: Tue, 13 Jan 2026 19:23:59 +0000 Subject: [PATCH 02/13] :infer-exp in repl --- src/T.hs | 7 +++++-- src/T/App/IO.hs | 4 ++++ src/T/App/Repl.hs | 21 +++++++++++++++++++++ src/T/Type.hs | 35 +++++++++++++++++++++++++++++++---- 4 files changed, 61 insertions(+), 6 deletions(-) diff --git a/src/T.hs b/src/T.hs index c573f43..e1bc949 100644 --- a/src/T.hs +++ b/src/T.hs @@ -1,5 +1,5 @@ module T - ( Tmpl + ( Tmpl(..) , Scope(..) , Exp , Value @@ -21,6 +21,8 @@ module T , stdlib , emptyScope + , TypeError + , Embed(..) , Eject(..) ) where @@ -34,7 +36,7 @@ import T.Error ) import T.Exp (Exp) import T.Name (Name(..)) -import T.Tmpl (Tmpl) +import T.Tmpl (Tmpl(..)) import T.Parse ( ParseError(..) , parseFile @@ -44,6 +46,7 @@ import T.Parse import T.Prelude import T.Render (Rendered(..), Scope(..), render) import T.Stdlib (Stdlib, def) +import T.Type (TypeError) import T.Value (Value, embedAeson) diff --git a/src/T/App/IO.hs b/src/T/App/IO.hs index 80ee120..7e32310 100644 --- a/src/T/App/IO.hs +++ b/src/T/App/IO.hs @@ -29,6 +29,10 @@ instance Warn T.ParseError where warn (T.ParseError err) = ppWarn err +instance Warn T.TypeError where + warn = + ppWarn . fromString . show + die :: Warn t => t -> IO () die err = do warn err diff --git a/src/T/App/Repl.hs b/src/T/App/Repl.hs index 40dc2b4..13aa5c8 100644 --- a/src/T/App/Repl.hs +++ b/src/T/App/Repl.hs @@ -18,6 +18,7 @@ import T qualified import T.SExp (sexp) import T.SExp qualified as SExp import T.Stdlib qualified as Stdlib +import T.Type qualified as Type run :: IO () @@ -46,6 +47,9 @@ run = do ParseTmpl str -> do parseTmpl str loop scope0 + InferExp str -> do + inferExp str + loop scope0 header :: IO () header = @@ -80,10 +84,25 @@ parseTmpl str = liftIO $ Right tmpl -> Text.Lazy.putStrLn (SExp.renderLazyText (sexp tmpl)) +inferExp :: MonadIO m => Text -> m () +inferExp str = liftIO $ + case T.parseText Stdlib.def str of + Left err -> + warn err + Right (T.Exp exp) -> + case Type.infer exp of + Left te -> + warn te + Right texp -> + Text.Lazy.putStrLn (SExp.renderLazyText (sexp (Type.extractType texp))) + Right _ -> + error "wow" + data Cmd = Quit | EvalTmpl Text | ParseTmpl Text + | InferExp Text deriving (Show, Eq) prompt :: IO Text @@ -100,5 +119,7 @@ parseInput input EvalTmpl rest | Just rest <- Text.stripPrefix ":parse-tmpl " input = ParseTmpl rest + | Just rest <- Text.stripPrefix ":infer-exp " input = + InferExp rest | otherwise = EvalTmpl input diff --git a/src/T/Type.hs b/src/T/Type.hs index 553b713..1d14d5b 100644 --- a/src/T/Type.hs +++ b/src/T/Type.hs @@ -1,11 +1,11 @@ {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedRecordDot #-} module T.Type ( Type(..) , infer , TypeError + , extractType ) where import Control.Monad (foldM) @@ -19,6 +19,8 @@ import T.Exp (Exp, (:+)(..)) import T.Exp qualified as Exp import T.Name (Name) import T.Prelude +import T.SExp (sexp) +import T.SExp qualified as SExp data Type @@ -34,6 +36,29 @@ data Type | Var Int deriving (Show, Eq) +instance SExp.To Type where + sexp = \case + Unit -> + "unit" + Bool -> + "bool" + Int -> + "int" + Double -> + "double" + String -> + "string" + Regexp -> + "regexp" + Array t -> + SExp.square [sexp t] + Record fs -> + SExp.curly (concatMap (\(k, v) -> [sexp k, sexp v]) (HashMap.toList fs)) + Fun args r -> + SExp.round ["->", SExp.square (toList (map sexp args)), sexp r] + Var n -> + fromString ('#' : show n) + type TypedExp = Cofree Exp.ExpF (Exp.Ann, Type) newtype InferenceT m a = InferenceT (ReaderT Γ (StateT Σ (ExceptT TypeError m)) a) @@ -60,6 +85,7 @@ data TypeError | NotARecord Type | TypeMismatch Type Type | OccursCheck Int Type + deriving (Show, Eq) runInferenceT :: Γ -> Σ -> InferenceT m a -> m (Either TypeError (a, Σ)) runInferenceT ctx subst (InferenceT m) = @@ -67,8 +93,8 @@ runInferenceT ctx subst (InferenceT m) = infer :: Exp -> Either TypeError TypedExp infer exp = do - (te, Σ {subst}) <- runIdentity (runInferenceT mempty emptyΣ (inferExp exp)) - pure (finalize subst te) + (te, finalΣ) <- runIdentity (runInferenceT mempty emptyΣ (inferExp exp)) + pure (finalize finalΣ.subst te) inferExp :: Monad m => Exp -> InferenceT m TypedExp inferExp (ann :< e) = do @@ -228,7 +254,8 @@ inferLiteral = \case generalize :: Monad m => [Type] -> InferenceT m Type generalize = \case - [] -> freshVar + [] -> + freshVar t : ts -> foldM unify t ts From 859d8a60d0969c96e250291eeaf202d81042babb Mon Sep 17 00:00:00 2001 From: Matvey Aksenov Date: Tue, 13 Jan 2026 21:47:37 +0000 Subject: [PATCH 03/13] prepare generalization --- src/T/Prelude.hs | 5 ++ src/T/Type.hs | 117 +++++++++++++++++++++++++++++++++++------------ 2 files changed, 93 insertions(+), 29 deletions(-) diff --git a/src/T/Prelude.hs b/src/T/Prelude.hs index b210a65..490ec40 100644 --- a/src/T/Prelude.hs +++ b/src/T/Prelude.hs @@ -67,6 +67,7 @@ module T.Prelude , foldr , foldM_ , foldr1 + , for , for_ , fromIntegral , impossible @@ -114,6 +115,9 @@ import Data.List (foldl') import Data.List.NonEmpty (NonEmpty(..)) import Data.Maybe (mapMaybe) import Data.Text (Text) +import Data.Traversable + ( for + ) import Data.String (IsString(..)) import Data.Vector (Vector) import Debug.Trace (traceShow) @@ -157,6 +161,7 @@ import Prelude , error , filter , flip + , foldMap , foldr , foldr1 , fromIntegral diff --git a/src/T/Type.hs b/src/T/Type.hs index 1d14d5b..9a3de67 100644 --- a/src/T/Type.hs +++ b/src/T/Type.hs @@ -13,6 +13,8 @@ import Control.Monad.Reader (ReaderT, MonadReader, runReaderT, ask) import Control.Monad.State (StateT, MonadState, runStateT, gets, modify) import Control.Monad.Except (ExceptT, MonadError, runExceptT, throwError) import Data.List.NonEmpty qualified as NonEmpty +import Data.Set (Set) +import Data.Set qualified as Set import Data.HashMap.Strict qualified as HashMap import T.Exp (Exp, (:+)(..)) @@ -21,6 +23,7 @@ import T.Name (Name) import T.Prelude import T.SExp (sexp) import T.SExp qualified as SExp +import T.Tmpl (Tmpl(..)) data Type @@ -59,12 +62,15 @@ instance SExp.To Type where Var n -> fromString ('#' : show n) +data Scheme = Forall (Set Int) Type + deriving (Show, Eq) + type TypedExp = Cofree Exp.ExpF (Exp.Ann, Type) newtype InferenceT m a = InferenceT (ReaderT Γ (StateT Σ (ExceptT TypeError m)) a) deriving (Functor, Applicative, Monad, MonadReader Γ, MonadState Σ, MonadError TypeError) -type Γ = HashMap Name Type +type Γ = HashMap Name Scheme data Σ = Σ { subst :: Subst @@ -96,6 +102,16 @@ infer exp = do (te, finalΣ) <- runIdentity (runInferenceT mempty emptyΣ (inferExp exp)) pure (finalize finalΣ.subst te) +inferTmpl :: Monad m => Tmpl -> InferenceT m () +inferTmpl = \case + Raw _text -> + pure () + Comment _text -> + pure () + Exp exp -> do + _texp <- inferExp exp + pure () + inferExp :: Monad m => Exp -> InferenceT m TypedExp inferExp (ann :< e) = do te <- traverse inferExp e @@ -118,7 +134,42 @@ inferExp (ann :< e) = do lookupCtx :: Monad m => Name -> InferenceT m Type lookupCtx name = do ctx <- ask - maybe (throwError (MissingVar name)) pure (HashMap.lookup name ctx) + maybe (throwError (MissingVar name)) instantiate (HashMap.lookup name ctx) + +generalize :: Γ -> Type -> Scheme +generalize ctx t = do + let + fvs = + freeVarsType t + ctxvs = + foldMap freeVarsScheme ctx + qs = + Set.difference fvs ctxvs + Forall qs t + +freeVarsType :: Type -> Set Int +freeVarsType = \case + Var n -> + Set.singleton n + Array t -> + freeVarsType t + Record r -> + foldMap freeVarsType r + Fun args r -> + foldMap freeVarsType args <> freeVarsType r + _ -> + Set.empty + +freeVarsScheme :: Scheme -> Set Int +freeVarsScheme (Forall qs t) = + Set.difference (freeVarsType t) qs + +instantiate :: Monad m => Scheme -> InferenceT m Type +instantiate (Forall vars t) = do + fvs <- for (toList vars) $ \v -> do + fv <- freshVar + pure (v, fv) + pure (substitute (HashMap.fromList fvs) t) extractType :: TypedExp -> Type extractType ((_ann, t) :< _e) = t @@ -128,9 +179,9 @@ unify t1 t2 = do s <- gets (.subst) let t1' = - applySubst s t1 + substitute s t1 t2' = - applySubst s t2 + substitute s t2 case (t1', t2') of (a, b) | a == b -> @@ -156,7 +207,7 @@ unify t1 t2 = do occurs :: Int -> Type -> Subst -> Bool occurs n t subst = - case applySubst subst t of + case substitute subst t of Var m -> n == m Array arr -> @@ -168,21 +219,9 @@ occurs n t subst = _ -> False -applySubst :: Subst -> Type -> Type -applySubst subst t = +substitute :: Subst -> Type -> Type +substitute subst t = case t of - Var n -> - case HashMap.lookup n subst of - Just nt -> - applySubst subst nt - Nothing -> - Var n - Array arr -> - Array (applySubst subst arr) - Record r -> - Record (map (applySubst subst) r) - Fun args r -> - Fun (map (applySubst subst) args) (applySubst subst r) Unit -> Unit Bool -> @@ -195,6 +234,18 @@ applySubst subst t = String Regexp -> Regexp + Array arr -> + Array (substitute subst arr) + Record r -> + Record (map (substitute subst) r) + Fun args r -> + Fun (map (substitute subst) args) (substitute subst r) + Var n -> + case HashMap.lookup n subst of + Nothing -> + Var n + Just nt -> + substitute subst nt extendSubst :: Monad m => Int -> Type -> InferenceT m () extendSubst n t = @@ -246,19 +297,27 @@ inferLiteral = \case Exp.Regexp _ -> pure Regexp Exp.Array xs -> do ys <- traverse inferExp xs - y <- generalize (toList (map extractType ys)) - pure (Array y) + t <- case toList (map extractType ys) of + [] -> + freshVar + z : zs -> + foldM unify z zs + pure (Array t) Exp.Record r -> do ts <- traverse inferExp r pure (Record (map extractType ts)) -generalize :: Monad m => [Type] -> InferenceT m Type -generalize = \case - [] -> - freshVar - t : ts -> - foldM unify t ts - finalize :: Subst -> TypedExp -> TypedExp finalize subst cofree = - map (\(ann, t) -> (ann, applySubst subst t)) cofree + map (\(ann, t) -> (ann, defaultType (substitute subst t))) cofree + +defaultType :: Type -> Type +defaultType = \case + Var _ -> + Unit + Array t -> + Array (defaultType t) + Record r -> + Record (map defaultType r) + t -> + t From d86bd628ac7484027e80907dd90601e3df8769d6 Mon Sep 17 00:00:00 2001 From: Matvey Aksenov Date: Wed, 14 Jan 2026 19:49:27 +0000 Subject: [PATCH 04/13] ascribe types to stdlib definitions --- src/T/App/Repl.hs | 2 +- src/T/Prelude.hs | 3 +- src/T/Stdlib.hs | 6 ++++ src/T/Stdlib/Fun.hs | 86 ++++++++++++++++++++++++++++++++------------ src/T/Stdlib/Op.hs | 67 +++++++++++++++++++++++++--------- src/T/Type.hs | 59 ++++++++++++++++++++++++++---- test/T/RenderSpec.hs | 13 +++++-- 7 files changed, 184 insertions(+), 52 deletions(-) diff --git a/src/T/App/Repl.hs b/src/T/App/Repl.hs index 13aa5c8..cae3fa9 100644 --- a/src/T/App/Repl.hs +++ b/src/T/App/Repl.hs @@ -90,7 +90,7 @@ inferExp str = liftIO $ Left err -> warn err Right (T.Exp exp) -> - case Type.infer exp of + case Type.infer (Stdlib.typingCtx Stdlib.def) exp of Left te -> warn te Right texp -> diff --git a/src/T/Prelude.hs b/src/T/Prelude.hs index 490ec40..1bc2d0c 100644 --- a/src/T/Prelude.hs +++ b/src/T/Prelude.hs @@ -70,6 +70,7 @@ module T.Prelude , for , for_ , fromIntegral + , fromMaybe , impossible , map , mapMaybe @@ -113,7 +114,7 @@ import Data.HashMap.Strict (HashMap) import Data.Hashable (Hashable) import Data.List (foldl') import Data.List.NonEmpty (NonEmpty(..)) -import Data.Maybe (mapMaybe) +import Data.Maybe (fromMaybe, mapMaybe) import Data.Text (Text) import Data.Traversable ( for diff --git a/src/T/Stdlib.hs b/src/T/Stdlib.hs index e730438..f5374ad 100644 --- a/src/T/Stdlib.hs +++ b/src/T/Stdlib.hs @@ -13,6 +13,7 @@ module T.Stdlib , with , def , bindings + , typingCtx , Macro.macroFun , Macro.macroOp ) where @@ -27,6 +28,7 @@ import T.Stdlib.Macro (Macro(..)) import T.Stdlib.Macro qualified as Macro import T.Stdlib.Op (Op(..)) import T.Stdlib.Op qualified as Op +import T.Type (Γ) import T.Value (Value) @@ -46,3 +48,7 @@ with ops funs macros = Stdlib {..} bindings :: Stdlib -> HashMap Name Value bindings stdlib = HashMap.fromList (Op.bindings stdlib.ops <> Fun.bindings stdlib.funs) + +typingCtx :: Stdlib -> Γ +typingCtx stdlib = + Op.typingCtx stdlib.ops <> Fun.typingCtx stdlib.funs diff --git a/src/T/Stdlib/Fun.hs b/src/T/Stdlib/Fun.hs index ebcd4f4..7635bd4 100644 --- a/src/T/Stdlib/Fun.hs +++ b/src/T/Stdlib/Fun.hs @@ -4,6 +4,7 @@ module T.Stdlib.Fun ( Fun(..) , bindings + , typingCtx , functions ) where @@ -18,41 +19,80 @@ import T.Error (Error(..)) import T.Exp.Ann ((:+)(..), unann) import T.Name (Name) import T.Prelude +import T.Type (Γ, forall, fun1, fun2) +import T.Type qualified as Type import T.Value (Value(..), display, displayWith) data Fun = Fun - { name :: Name - , binding :: Name -> Value + { name :: Name + , ascribed :: Type.Scheme + , binding :: Name -> Value } bindings :: [Fun] -> [(Name, Value)] bindings = map (\fun -> (fun.name, fun.binding fun.name)) +typingCtx :: [Fun] -> Γ +typingCtx = + HashMap.fromList . map (\fun -> (fun.name, fun.ascribed)) + functions :: [Fun] functions = - [ Fun "empty?" nullB - , Fun "length" lengthB - - , Fun "floor" (flip embed0 (floor @Double @Int)) - , Fun "ceiling" (flip embed0 (ceiling @Double @Int)) - , Fun "round" (flip embed0 (round @Double @Int)) - , Fun "int->double" (flip embed0 (fromIntegral @Int @Double)) - - , Fun "upper-case" (flip embed0 Text.toUpper) - , Fun "lower-case" (flip embed0 Text.toLower) - , Fun "title-case" (flip embed0 Text.toTitle) - - , Fun "split" (flip embed0 Text.splitOn) - , Fun "join" (flip embed0 Text.intercalate) - , Fun "concat" (flip embed0 Text.concat) - , Fun "chunks-of" (flip embed0 Text.chunksOf) - - , Fun "die" dieB - - , Fun "show" (\_ -> showB) - , Fun "pp" (\_ -> ppB) + [ Fun "empty?" + (forall [0] (Type.Array (Type.Var 0) `fun1` Type.Bool)) -- less polymorphic than we'd like + nullB + , Fun "length" + (forall [0] (fun1 (Type.Array (Type.Var 0)) Type.Int)) -- less polymorphic than we'd like + lengthB + + , Fun "floor" + (forall [] (Type.Double `fun1` Type.Int)) + (flip embed0 (floor @Double @Int)) + , Fun "ceiling" + (forall [] (Type.Double `fun1` Type.Int)) + (flip embed0 (ceiling @Double @Int)) + , Fun "round" + (forall [] (Type.Double `fun1` Type.Int)) + (flip embed0 (round @Double @Int)) + , Fun "int->double" + (forall [] (Type.Int `fun1` Type.Double)) + (flip embed0 (fromIntegral @Int @Double)) + + , Fun "upper-case" + (forall [] (Type.String `fun1` Type.String)) + (flip embed0 Text.toUpper) + , Fun "lower-case" + (forall [] (Type.String `fun1` Type.String)) + (flip embed0 Text.toLower) + , Fun "title-case" + (forall [] (Type.String `fun1` Type.String)) + (flip embed0 Text.toTitle) + + , Fun "split" + (forall [] ((Type.String, Type.String) `fun2` Type.Array Type.String)) + (flip embed0 Text.splitOn) + , Fun "join" + (forall [] ((Type.String, Type.Array Type.String) `fun2` Type.String)) + (flip embed0 Text.intercalate) + , Fun "concat" + (forall [] (Type.Array Type.String `fun1` Type.String)) + (flip embed0 Text.concat) + , Fun "chunks-of" + (forall [] ((Type.Int, Type.String) `fun2` Type.Array Type.String)) + (flip embed0 Text.chunksOf) + + , Fun "die" + (forall [0] (Type.String `fun1` Type.Var 0)) + dieB + + , Fun "show" + (forall [0] (Type.Var 0 `fun1` Type.String)) + (flip embed0 showB) + , Fun "pp" + (forall [0] (Type.Var 0 `fun1` Type.String)) + (flip embed0 ppB) ] nullB :: Name -> Value diff --git a/src/T/Stdlib/Op.hs b/src/T/Stdlib/Op.hs index e254d24..fa193c8 100644 --- a/src/T/Stdlib/Op.hs +++ b/src/T/Stdlib/Op.hs @@ -6,6 +6,7 @@ module T.Stdlib.Op , PriorityMap , Fixity(..) , bindings + , typingCtx , priorities , operators ) where @@ -25,12 +26,14 @@ import T.Exp.Ann ((:+)(..)) import T.Name (Name) import T.Prelude import T.SExp (sexp) +import T.Type (Γ, forall, fun1, fun2) import T.Type qualified as Type import T.Value (Value(..), typeOf) data Op = Op { name :: Name + , ascribed :: Type.Scheme , binding :: Name -> Value , fixity :: Fixity , priority :: Int @@ -49,29 +52,59 @@ bindings :: [Op] -> [(Name, Value)] bindings = map (\op -> (op.name, op.binding op.name)) +typingCtx :: [Op] -> Γ +typingCtx = + HashMap.fromList . map (\op -> (op.name, op.ascribed)) + priorities :: [Op] -> PriorityMap priorities = Map.fromListWith (<>) . map (\op -> (op.priority, [(op.name, op.fixity)])) operators :: [Op] operators = - [ Op "!" (flip embed0 not) Prefix 8 - - , Op "==" (flip embed0 eq) Infix 4 - , Op "!=" (flip embed0 neq) Infix 4 - , Op "=~" (flip embed0 match) Infix 4 - - , Op "+" add Infixl 6 - , Op "-" subtract Infixl 6 - , Op "*" multiply Infixl 7 - , Op "/" divide Infixl 7 - - , Op "<" lt Infix 4 - , Op "<=" le Infix 4 - , Op ">" gt Infix 4 - , Op ">=" ge Infix 4 - - , Op "<>" (flip embed0 ((<>) @Text)) Infixr 6 + [ Op "!" + (forall [] (Type.Bool `fun1` Type.Bool)) + (flip embed0 not) Prefix 8 + + , Op "==" + (forall [0] ((Type.Var 0, Type.Var 0) `fun2` Type.Bool)) -- more polymorphic than we'd like + (flip embed0 eq) Infix 4 + , Op "!=" + (forall [0] ((Type.Var 0, Type.Var 0) `fun2` Type.Bool)) -- more polymorphic than we'd like + (flip embed0 neq) Infix 4 + , Op "=~" + (forall [0] ((Type.String, Type.Regexp) `fun2` Type.Bool)) + (flip embed0 match) Infix 4 + + , Op "+" + (forall [] ((Type.Int, Type.Int) `fun2` Type.Int)) -- less polymorphic than we'd like + add Infixl 6 + , Op "-" + (forall [] ((Type.Int, Type.Int) `fun2` Type.Int)) -- less polymorphic than we'd like + subtract Infixl 6 + , Op "*" + (forall [] ((Type.Int, Type.Int) `fun2` Type.Int)) -- less polymorphic than we'd like + multiply Infixl 7 + , Op "/" + (forall [] ((Type.Int, Type.Int) `fun2` Type.Int)) -- less polymorphic than we'd like + divide Infixl 7 + + , Op "<" + (forall [] ((Type.Int, Type.Int) `fun2` Type.Bool)) -- less polymorphic than we'd like + lt Infix 4 + , Op "<=" + (forall [] ((Type.Int, Type.Int) `fun2` Type.Bool)) -- less polymorphic than we'd like + le Infix 4 + , Op ">" + (forall [] ((Type.Int, Type.Int) `fun2` Type.Bool)) -- less polymorphic than we'd like + gt Infix 4 + , Op ">=" + (forall [] ((Type.Int, Type.Int) `fun2` Type.Bool)) -- less polymorphic than we'd like + ge Infix 4 + + , Op "<>" + (forall [] ((Type.String, Type.String) `fun2` Type.String)) + (flip embed0 ((<>) @Text)) Infixr 6 ] combineNumbers :: (Int -> Int -> Int) -> (Double -> Double -> Double) -> Name -> Value diff --git a/src/T/Type.hs b/src/T/Type.hs index 9a3de67..8b72183 100644 --- a/src/T/Type.hs +++ b/src/T/Type.hs @@ -2,7 +2,12 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE OverloadedRecordDot #-} module T.Type - ( Type(..) + ( Γ + , Type(..) + , Scheme(..) + , forall + , fun1 + , fun2 , infer , TypeError , extractType @@ -65,6 +70,18 @@ instance SExp.To Type where data Scheme = Forall (Set Int) Type deriving (Show, Eq) +forall :: [Int] -> Type -> Scheme +forall qs t = + Forall (Set.fromList qs) t + +fun1 :: Type -> Type -> Type +fun1 a1 r = + Fun (a1 :| []) r + +fun2 :: (Type, Type) -> Type -> Type +fun2 (a1, a2) r = + Fun (a1 :| a2 : []) r + type TypedExp = Cofree Exp.ExpF (Exp.Ann, Type) newtype InferenceT m a = InferenceT (ReaderT Γ (StateT Σ (ExceptT TypeError m)) a) @@ -97,9 +114,9 @@ runInferenceT :: Γ -> Σ -> InferenceT m a -> m (Either TypeError (a, Σ)) runInferenceT ctx subst (InferenceT m) = runExceptT (runStateT (runReaderT m ctx) subst) -infer :: Exp -> Either TypeError TypedExp -infer exp = do - (te, finalΣ) <- runIdentity (runInferenceT mempty emptyΣ (inferExp exp)) +infer :: Γ -> Exp -> Either TypeError TypedExp +infer ctx exp = do + (te, finalΣ) <- runIdentity (runInferenceT ctx emptyΣ (inferExp exp)) pure (finalize finalΣ.subst te) inferTmpl :: Monad m => Tmpl -> InferenceT m () @@ -169,7 +186,35 @@ instantiate (Forall vars t) = do fvs <- for (toList vars) $ \v -> do fv <- freshVar pure (v, fv) - pure (substitute (HashMap.fromList fvs) t) + pure (minisubstitute (HashMap.fromList fvs) t) + +-- | 'minisubstitute' is a variant of 'substitute' that doesn't +-- do deep substitution; this is necessary separate the namespaces +-- of quantified variables and unitification variables which is +-- useful for e.g. stdlib definitions +minisubstitute :: Subst -> Type -> Type +minisubstitute subst t = + case t of + Unit -> + Unit + Bool -> + Bool + Int -> + Int + Double -> + Double + String -> + String + Regexp -> + Regexp + Array arr -> + Array (minisubstitute subst arr) + Record r -> + Record (map (minisubstitute subst) r) + Fun args r -> + Fun (map (minisubstitute subst) args) (minisubstitute subst r) + Var n -> + fromMaybe (Var n) (HashMap.lookup n subst) extractType :: TypedExp -> Type extractType ((_ann, t) :< _e) = t @@ -308,8 +353,8 @@ inferLiteral = \case pure (Record (map extractType ts)) finalize :: Subst -> TypedExp -> TypedExp -finalize subst cofree = - map (\(ann, t) -> (ann, defaultType (substitute subst t))) cofree +finalize subst = + map (\(ann, t) -> (ann, defaultType (substitute subst t))) defaultType :: Type -> Type defaultType = \case diff --git a/test/T/RenderSpec.hs b/test/T/RenderSpec.hs index 21a2c70..dbd966b 100644 --- a/test/T/RenderSpec.hs +++ b/test/T/RenderSpec.hs @@ -24,6 +24,7 @@ import T.Stdlib (def) import T.Stdlib qualified as Stdlib import T.Stdlib.Op qualified as Op import T.SExp (sexp) +import T.Type (forall, fun1, fun2) import T.Type qualified as Type import T.Value (embedAeson) @@ -406,13 +407,19 @@ rWith json tmplStr = do opExt :: [Stdlib.Op] opExt = - [ Stdlib.Op "<+>" (flip embed0 (\str0 str1 -> str0 <> "+" <> str1 :: Text)) Stdlib.Infixr 6 + [ Stdlib.Op "<+>" + (forall [] ((Type.String, Type.String) `fun2` Type.String)) + (flip embed0 (\str0 str1 -> str0 <> "+" <> str1 :: Text)) Stdlib.Infixr 6 ] funExt :: [Stdlib.Fun] funExt = - [ Stdlib.Fun "bool01" (flip embed0 (bool @Int 0 1)) - , Stdlib.Fun "const" (flip embed0 (const @Bool @Text)) + [ Stdlib.Fun "bool01" + (forall [] (Type.Bool `fun1` Type.Int)) + (flip embed0 (bool @Int 0 1)) + , Stdlib.Fun "const" + (forall [] (Type.Bool `fun1` Type.String)) + (flip embed0 (const @Bool @Text)) ] macroExt :: [Stdlib.Macro] From 60d2a412463ed307bd2001315c8b177f61c7da9d Mon Sep 17 00:00:00 2001 From: Matvey Aksenov Date: Wed, 14 Jan 2026 20:00:49 +0000 Subject: [PATCH 05/13] embed0 is only used flipped --- src/T/Embed.hs | 4 ++-- src/T/Stdlib.hs | 4 +--- src/T/Stdlib/Fun.hs | 30 +++++++++++++++--------------- src/T/Stdlib/Op.hs | 14 +++++++------- test/T/RenderSpec.hs | 6 +++--- 5 files changed, 28 insertions(+), 30 deletions(-) diff --git a/src/T/Embed.hs b/src/T/Embed.hs index ecd666b..3ceeab8 100644 --- a/src/T/Embed.hs +++ b/src/T/Embed.hs @@ -58,8 +58,8 @@ instance (Eject a, Embed b) => Embed (a -> b) where -- Some embeddings do not have a useful annotation to attach to, such as -- stdlib definitions. This is a helper for them. -embed0 :: Embed t => Name -> t -> Value -embed0 name t = +embed0 :: Embed t => t -> Name -> Value +embed0 t name = embed (emptyAnn :+ name) t class Eject t where diff --git a/src/T/Stdlib.hs b/src/T/Stdlib.hs index f5374ad..070c56b 100644 --- a/src/T/Stdlib.hs +++ b/src/T/Stdlib.hs @@ -18,8 +18,6 @@ module T.Stdlib , Macro.macroOp ) where -import Data.HashMap.Strict qualified as HashMap - import T.Name (Name) import T.Prelude import T.Stdlib.Fun (Fun(..)) @@ -47,7 +45,7 @@ with ops funs macros = Stdlib {..} bindings :: Stdlib -> HashMap Name Value bindings stdlib = - HashMap.fromList (Op.bindings stdlib.ops <> Fun.bindings stdlib.funs) + Op.bindings stdlib.ops <> Fun.bindings stdlib.funs typingCtx :: Stdlib -> Γ typingCtx stdlib = diff --git a/src/T/Stdlib/Fun.hs b/src/T/Stdlib/Fun.hs index 7635bd4..5192df5 100644 --- a/src/T/Stdlib/Fun.hs +++ b/src/T/Stdlib/Fun.hs @@ -30,9 +30,9 @@ data Fun = Fun , binding :: Name -> Value } -bindings :: [Fun] -> [(Name, Value)] +bindings :: [Fun] -> HashMap Name Value bindings = - map (\fun -> (fun.name, fun.binding fun.name)) + HashMap.fromList . map (\fun -> (fun.name, fun.binding fun.name)) typingCtx :: [Fun] -> Γ typingCtx = @@ -49,39 +49,39 @@ functions = , Fun "floor" (forall [] (Type.Double `fun1` Type.Int)) - (flip embed0 (floor @Double @Int)) + (embed0 (floor @Double @Int)) , Fun "ceiling" (forall [] (Type.Double `fun1` Type.Int)) - (flip embed0 (ceiling @Double @Int)) + (embed0 (ceiling @Double @Int)) , Fun "round" (forall [] (Type.Double `fun1` Type.Int)) - (flip embed0 (round @Double @Int)) + (embed0 (round @Double @Int)) , Fun "int->double" (forall [] (Type.Int `fun1` Type.Double)) - (flip embed0 (fromIntegral @Int @Double)) + (embed0 (fromIntegral @Int @Double)) , Fun "upper-case" (forall [] (Type.String `fun1` Type.String)) - (flip embed0 Text.toUpper) + (embed0 Text.toUpper) , Fun "lower-case" (forall [] (Type.String `fun1` Type.String)) - (flip embed0 Text.toLower) + (embed0 Text.toLower) , Fun "title-case" (forall [] (Type.String `fun1` Type.String)) - (flip embed0 Text.toTitle) + (embed0 Text.toTitle) , Fun "split" (forall [] ((Type.String, Type.String) `fun2` Type.Array Type.String)) - (flip embed0 Text.splitOn) + (embed0 Text.splitOn) , Fun "join" (forall [] ((Type.String, Type.Array Type.String) `fun2` Type.String)) - (flip embed0 Text.intercalate) + (embed0 Text.intercalate) , Fun "concat" (forall [] (Type.Array Type.String `fun1` Type.String)) - (flip embed0 Text.concat) + (embed0 Text.concat) , Fun "chunks-of" (forall [] ((Type.Int, Type.String) `fun2` Type.Array Type.String)) - (flip embed0 Text.chunksOf) + (embed0 Text.chunksOf) , Fun "die" (forall [0] (Type.String `fun1` Type.Var 0)) @@ -89,10 +89,10 @@ functions = , Fun "show" (forall [0] (Type.Var 0 `fun1` Type.String)) - (flip embed0 showB) + (embed0 showB) , Fun "pp" (forall [0] (Type.Var 0 `fun1` Type.String)) - (flip embed0 ppB) + (embed0 ppB) ] nullB :: Name -> Value diff --git a/src/T/Stdlib/Op.hs b/src/T/Stdlib/Op.hs index fa193c8..5cb7d4f 100644 --- a/src/T/Stdlib/Op.hs +++ b/src/T/Stdlib/Op.hs @@ -48,9 +48,9 @@ data Fixity | Infixr deriving (Show, Eq) -bindings :: [Op] -> [(Name, Value)] +bindings :: [Op] -> HashMap Name Value bindings = - map (\op -> (op.name, op.binding op.name)) + HashMap.fromList . map (\op -> (op.name, op.binding op.name)) typingCtx :: [Op] -> Γ typingCtx = @@ -64,17 +64,17 @@ operators :: [Op] operators = [ Op "!" (forall [] (Type.Bool `fun1` Type.Bool)) - (flip embed0 not) Prefix 8 + (embed0 not) Prefix 8 , Op "==" (forall [0] ((Type.Var 0, Type.Var 0) `fun2` Type.Bool)) -- more polymorphic than we'd like - (flip embed0 eq) Infix 4 + (embed0 eq) Infix 4 , Op "!=" (forall [0] ((Type.Var 0, Type.Var 0) `fun2` Type.Bool)) -- more polymorphic than we'd like - (flip embed0 neq) Infix 4 + (embed0 neq) Infix 4 , Op "=~" (forall [0] ((Type.String, Type.Regexp) `fun2` Type.Bool)) - (flip embed0 match) Infix 4 + (embed0 match) Infix 4 , Op "+" (forall [] ((Type.Int, Type.Int) `fun2` Type.Int)) -- less polymorphic than we'd like @@ -104,7 +104,7 @@ operators = , Op "<>" (forall [] ((Type.String, Type.String) `fun2` Type.String)) - (flip embed0 ((<>) @Text)) Infixr 6 + (embed0 ((<>) @Text)) Infixr 6 ] combineNumbers :: (Int -> Int -> Int) -> (Double -> Double -> Double) -> Name -> Value diff --git a/test/T/RenderSpec.hs b/test/T/RenderSpec.hs index dbd966b..5de5822 100644 --- a/test/T/RenderSpec.hs +++ b/test/T/RenderSpec.hs @@ -409,17 +409,17 @@ opExt :: [Stdlib.Op] opExt = [ Stdlib.Op "<+>" (forall [] ((Type.String, Type.String) `fun2` Type.String)) - (flip embed0 (\str0 str1 -> str0 <> "+" <> str1 :: Text)) Stdlib.Infixr 6 + (embed0 (\str0 str1 -> str0 <> "+" <> str1 :: Text)) Stdlib.Infixr 6 ] funExt :: [Stdlib.Fun] funExt = [ Stdlib.Fun "bool01" (forall [] (Type.Bool `fun1` Type.Int)) - (flip embed0 (bool @Int 0 1)) + (embed0 (bool @Int 0 1)) , Stdlib.Fun "const" (forall [] (Type.Bool `fun1` Type.String)) - (flip embed0 (const @Bool @Text)) + (embed0 (const @Bool @Text)) ] macroExt :: [Stdlib.Macro] From fc64200e421e862eeb1cf9d503ab9aca765e92a4 Mon Sep 17 00:00:00 2001 From: Matvey Aksenov Date: Wed, 14 Jan 2026 20:54:59 +0000 Subject: [PATCH 06/13] check function arity in unify --- src/T/Type.hs | 74 +++++++++++++++++---------------------------------- 1 file changed, 25 insertions(+), 49 deletions(-) diff --git a/src/T/Type.hs b/src/T/Type.hs index 8b72183..8aba3c0 100644 --- a/src/T/Type.hs +++ b/src/T/Type.hs @@ -186,35 +186,25 @@ instantiate (Forall vars t) = do fvs <- for (toList vars) $ \v -> do fv <- freshVar pure (v, fv) - pure (minisubstitute (HashMap.fromList fvs) t) + pure (replaceOnce (HashMap.fromList fvs) t) --- | 'minisubstitute' is a variant of 'substitute' that doesn't +-- | 'replaceOnce' is a variant of 'replace' that doesn't -- do deep substitution; this is necessary separate the namespaces -- of quantified variables and unitification variables which is -- useful for e.g. stdlib definitions -minisubstitute :: Subst -> Type -> Type -minisubstitute subst t = +replaceOnce :: Subst -> Type -> Type +replaceOnce subst t = case t of - Unit -> - Unit - Bool -> - Bool - Int -> - Int - Double -> - Double - String -> - String - Regexp -> - Regexp Array arr -> - Array (minisubstitute subst arr) + Array (replaceOnce subst arr) Record r -> - Record (map (minisubstitute subst) r) + Record (map (replaceOnce subst) r) Fun args r -> - Fun (map (minisubstitute subst) args) (minisubstitute subst r) + Fun (map (replaceOnce subst) args) (replaceOnce subst r) Var n -> fromMaybe (Var n) (HashMap.lookup n subst) + _ -> + t extractType :: TypedExp -> Type extractType ((_ann, t) :< _e) = t @@ -222,12 +212,7 @@ extractType ((_ann, t) :< _e) = t unify :: Monad m => Type -> Type -> InferenceT m Type unify t1 t2 = do s <- gets (.subst) - let - t1' = - substitute s t1 - t2' = - substitute s t2 - case (t1', t2') of + case (replace s t1, replace s t2) of (a, b) | a == b -> pure a @@ -245,14 +230,15 @@ unify t1 t2 = do map Array (unify a b) (Record m1, Record m2) -> do map Record (sequence (HashMap.intersectionWith unify m1 m2)) - (Fun args1 ret1, Fun args2 ret2) -> - liftA2 Fun (sequence (NonEmpty.zipWith unify args1 args2)) (unify ret1 ret2) - _ -> - throwError (TypeMismatch t1' t2') + (Fun args1 ret1, Fun args2 ret2) + | NonEmpty.length args1 == NonEmpty.length args2 -> + liftA2 Fun (sequence (NonEmpty.zipWith unify args1 args2)) (unify ret1 ret2) + (a, b) -> + throwError (TypeMismatch a b) occurs :: Int -> Type -> Subst -> Bool occurs n t subst = - case substitute subst t of + case replace subst t of Var m -> n == m Array arr -> @@ -264,33 +250,23 @@ occurs n t subst = _ -> False -substitute :: Subst -> Type -> Type -substitute subst t = +replace :: Subst -> Type -> Type +replace subst t = case t of - Unit -> - Unit - Bool -> - Bool - Int -> - Int - Double -> - Double - String -> - String - Regexp -> - Regexp Array arr -> - Array (substitute subst arr) + Array (replace subst arr) Record r -> - Record (map (substitute subst) r) + Record (map (replace subst) r) Fun args r -> - Fun (map (substitute subst) args) (substitute subst r) + Fun (map (replace subst) args) (replace subst r) Var n -> case HashMap.lookup n subst of Nothing -> Var n Just nt -> - substitute subst nt + replace subst nt + _ -> + t extendSubst :: Monad m => Int -> Type -> InferenceT m () extendSubst n t = @@ -354,7 +330,7 @@ inferLiteral = \case finalize :: Subst -> TypedExp -> TypedExp finalize subst = - map (\(ann, t) -> (ann, defaultType (substitute subst t))) + map (\(ann, t) -> (ann, defaultType (replace subst t))) defaultType :: Type -> Type defaultType = \case From c65aec9f540bf27aff00539b72afaa68735cf5dc Mon Sep 17 00:00:00 2001 From: Matvey Aksenov Date: Wed, 14 Jan 2026 21:41:01 +0000 Subject: [PATCH 07/13] prepare constraints --- src/T/Prelude.hs | 5 +-- src/T/Type.hs | 86 +++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 78 insertions(+), 13 deletions(-) diff --git a/src/T/Prelude.hs b/src/T/Prelude.hs index 1bc2d0c..98a8bb3 100644 --- a/src/T/Prelude.hs +++ b/src/T/Prelude.hs @@ -51,6 +51,7 @@ module T.Prelude , (-) , (*) , (/) + , all , any , asum , bool @@ -70,7 +71,6 @@ module T.Prelude , for , for_ , fromIntegral - , fromMaybe , impossible , map , mapMaybe @@ -101,6 +101,7 @@ import Data.Bifunctor (first, second) import Data.ByteString (ByteString) import Data.Foldable ( Foldable + , all , any , asum , for_ @@ -114,7 +115,7 @@ import Data.HashMap.Strict (HashMap) import Data.Hashable (Hashable) import Data.List (foldl') import Data.List.NonEmpty (NonEmpty(..)) -import Data.Maybe (fromMaybe, mapMaybe) +import Data.Maybe (mapMaybe) import Data.Text (Text) import Data.Traversable ( for diff --git a/src/T/Type.hs b/src/T/Type.hs index 8aba3c0..462208a 100644 --- a/src/T/Type.hs +++ b/src/T/Type.hs @@ -82,6 +82,55 @@ fun2 :: (Type, Type) -> Type -> Type fun2 (a1, a2) r = Fun (a1 :| a2 : []) r +data Constraint + = Num + | Eq + | Show + | Sizeable + | Iterable + deriving (Show, Eq, Ord) + +satisfies :: Constraint -> Type -> Bool +satisfies Num = \case + Int -> True + Double -> True + _ -> False +satisfies Eq = \case + Unit -> + True + Bool -> + True + Int -> + True + Double -> + True + String -> + True + Regexp -> + True + Array t -> + satisfies Eq t + Record fs -> + all (satisfies Eq) fs + _ -> + False +satisfies Show = \case + Unit -> True + Bool -> True + Int -> True + Double -> True + String -> True + _ -> False +satisfies Sizeable = \case + String -> True + Array _ -> True + Record _ -> True + _ -> False +satisfies Iterable = \case + Array _ -> True + Record _ -> True + _ -> False + type TypedExp = Cofree Exp.ExpF (Exp.Ann, Type) newtype InferenceT m a = InferenceT (ReaderT Γ (StateT Σ (ExceptT TypeError m)) a) @@ -90,8 +139,9 @@ newtype InferenceT m a = InferenceT (ReaderT Γ (StateT Σ (ExceptT TypeError m) type Γ = HashMap Name Scheme data Σ = Σ - { subst :: Subst - , counter :: Int + { subst :: Subst + , constraints :: HashMap Int (Set Constraint) + , counter :: Int } deriving (Show, Eq) type Subst = HashMap Int Type @@ -99,6 +149,7 @@ type Subst = HashMap Int Type emptyΣ :: Σ emptyΣ = Σ { subst = mempty + , constraints = mempty , counter = 0 } @@ -107,6 +158,7 @@ data TypeError | MissingVar Name | NotARecord Type | TypeMismatch Type Type + | ConstraintViolation Constraint Type | OccursCheck Int Type deriving (Show, Eq) @@ -202,7 +254,7 @@ replaceOnce subst t = Fun args r -> Fun (map (replaceOnce subst) args) (replaceOnce subst r) Var n -> - fromMaybe (Var n) (HashMap.lookup n subst) + HashMap.findWithDefault (Var n) n subst _ -> t @@ -217,15 +269,9 @@ unify t1 t2 = do | a == b -> pure a (Var n, t) -> do - when (occurs n t s) $ - throwError (OccursCheck n t) - extendSubst n t - pure t + unifyVar s n t (t, Var n) -> do - when (occurs n t s) $ - throwError (OccursCheck n t) - extendSubst n t - pure t + unifyVar s n t (Array a, Array b) -> map Array (unify a b) (Record m1, Record m2) -> do @@ -236,6 +282,24 @@ unify t1 t2 = do (a, b) -> throwError (TypeMismatch a b) +unifyVar :: Monad m => Subst -> Int -> Type -> InferenceT m Type +unifyVar s n t = do + when (occurs n t s) (throwError (OccursCheck n t)) + cs <- gets (HashMap.findWithDefault Set.empty n . (.constraints)) + checkConstraints cs t + extendSubst n t + pure t + +checkConstraints :: Monad m => Set Constraint -> Type -> InferenceT m () +checkConstraints cs t0 = do + subst <- gets (.subst) + case replace subst t0 of + Var m -> + modify (\s -> s {constraints = HashMap.insertWith Set.union m cs s.constraints}) + t -> + for_ cs $ \c -> + unless (satisfies c t) (throwError (ConstraintViolation c t)) + occurs :: Int -> Type -> Subst -> Bool occurs n t subst = case replace subst t of From e29e46c8cb7f657da2367700733ce6145a189d98 Mon Sep 17 00:00:00 2001 From: Matvey Aksenov Date: Thu, 15 Jan 2026 21:04:02 +0000 Subject: [PATCH 08/13] constraintify stdlib --- src/T/Stdlib/Fun.hs | 34 ++++++------ src/T/Stdlib/Op.hs | 28 +++++----- src/T/Type.hs | 120 +++++++++++++++++++++++++------------------ test/T/RenderSpec.hs | 8 +-- 4 files changed, 104 insertions(+), 86 deletions(-) diff --git a/src/T/Stdlib/Fun.hs b/src/T/Stdlib/Fun.hs index 5192df5..26cda7b 100644 --- a/src/T/Stdlib/Fun.hs +++ b/src/T/Stdlib/Fun.hs @@ -19,7 +19,7 @@ import T.Error (Error(..)) import T.Exp.Ann ((:+)(..), unann) import T.Name (Name) import T.Prelude -import T.Type (Γ, forall, fun1, fun2) +import T.Type (Γ, forall, forall_, fun1, fun2) import T.Type qualified as Type import T.Value (Value(..), display, displayWith) @@ -41,57 +41,57 @@ typingCtx = functions :: [Fun] functions = [ Fun "empty?" - (forall [0] (Type.Array (Type.Var 0) `fun1` Type.Bool)) -- less polymorphic than we'd like + (forall [0] [(0, Type.Sizeable)] (Type.Var 0 `fun1` Type.Bool)) nullB , Fun "length" - (forall [0] (fun1 (Type.Array (Type.Var 0)) Type.Int)) -- less polymorphic than we'd like + (forall [0] [(0, Type.Sizeable)] (Type.Var 0 `fun1` Type.Int)) lengthB , Fun "floor" - (forall [] (Type.Double `fun1` Type.Int)) + (forall_ (Type.Double `fun1` Type.Int)) (embed0 (floor @Double @Int)) , Fun "ceiling" - (forall [] (Type.Double `fun1` Type.Int)) + (forall_ (Type.Double `fun1` Type.Int)) (embed0 (ceiling @Double @Int)) , Fun "round" - (forall [] (Type.Double `fun1` Type.Int)) + (forall_ (Type.Double `fun1` Type.Int)) (embed0 (round @Double @Int)) , Fun "int->double" - (forall [] (Type.Int `fun1` Type.Double)) + (forall_ (Type.Int `fun1` Type.Double)) (embed0 (fromIntegral @Int @Double)) , Fun "upper-case" - (forall [] (Type.String `fun1` Type.String)) + (forall_ (Type.String `fun1` Type.String)) (embed0 Text.toUpper) , Fun "lower-case" - (forall [] (Type.String `fun1` Type.String)) + (forall_ (Type.String `fun1` Type.String)) (embed0 Text.toLower) , Fun "title-case" - (forall [] (Type.String `fun1` Type.String)) + (forall_ (Type.String `fun1` Type.String)) (embed0 Text.toTitle) , Fun "split" - (forall [] ((Type.String, Type.String) `fun2` Type.Array Type.String)) + (forall_ ((Type.String, Type.String) `fun2` Type.Array Type.String)) (embed0 Text.splitOn) , Fun "join" - (forall [] ((Type.String, Type.Array Type.String) `fun2` Type.String)) + (forall_ ((Type.String, Type.Array Type.String) `fun2` Type.String)) (embed0 Text.intercalate) , Fun "concat" - (forall [] (Type.Array Type.String `fun1` Type.String)) + (forall_ (Type.Array Type.String `fun1` Type.String)) (embed0 Text.concat) , Fun "chunks-of" - (forall [] ((Type.Int, Type.String) `fun2` Type.Array Type.String)) + (forall_ ((Type.Int, Type.String) `fun2` Type.Array Type.String)) (embed0 Text.chunksOf) , Fun "die" - (forall [0] (Type.String `fun1` Type.Var 0)) + (forall [0] [] (Type.String `fun1` Type.Var 0)) dieB , Fun "show" - (forall [0] (Type.Var 0 `fun1` Type.String)) + (forall [0] [(0, Type.Display)] (Type.Var 0 `fun1` Type.String)) (embed0 showB) , Fun "pp" - (forall [0] (Type.Var 0 `fun1` Type.String)) + (forall [0] [(0, Type.Display)] (Type.Var 0 `fun1` Type.String)) (embed0 ppB) ] diff --git a/src/T/Stdlib/Op.hs b/src/T/Stdlib/Op.hs index 5cb7d4f..196562a 100644 --- a/src/T/Stdlib/Op.hs +++ b/src/T/Stdlib/Op.hs @@ -26,7 +26,7 @@ import T.Exp.Ann ((:+)(..)) import T.Name (Name) import T.Prelude import T.SExp (sexp) -import T.Type (Γ, forall, fun1, fun2) +import T.Type (Γ, forall, forall_, fun1, fun2) import T.Type qualified as Type import T.Value (Value(..), typeOf) @@ -63,47 +63,47 @@ priorities = operators :: [Op] operators = [ Op "!" - (forall [] (Type.Bool `fun1` Type.Bool)) + (forall_ (Type.Bool `fun1` Type.Bool)) (embed0 not) Prefix 8 , Op "==" - (forall [0] ((Type.Var 0, Type.Var 0) `fun2` Type.Bool)) -- more polymorphic than we'd like + (forall [0] [(0, Type.Eq)] ((Type.Var 0, Type.Var 0) `fun2` Type.Bool)) (embed0 eq) Infix 4 , Op "!=" - (forall [0] ((Type.Var 0, Type.Var 0) `fun2` Type.Bool)) -- more polymorphic than we'd like + (forall [0] [(0, Type.Eq)] ((Type.Var 0, Type.Var 0) `fun2` Type.Bool)) (embed0 neq) Infix 4 , Op "=~" - (forall [0] ((Type.String, Type.Regexp) `fun2` Type.Bool)) + (forall_ ((Type.String, Type.Regexp) `fun2` Type.Bool)) (embed0 match) Infix 4 , Op "+" - (forall [] ((Type.Int, Type.Int) `fun2` Type.Int)) -- less polymorphic than we'd like + (forall [0] [(0, Type.Num)] ((Type.Var 0, Type.Var 0) `fun2` Type.Var 0)) add Infixl 6 , Op "-" - (forall [] ((Type.Int, Type.Int) `fun2` Type.Int)) -- less polymorphic than we'd like + (forall [0] [(0, Type.Num)] ((Type.Var 0, Type.Var 0) `fun2` Type.Var 0)) subtract Infixl 6 , Op "*" - (forall [] ((Type.Int, Type.Int) `fun2` Type.Int)) -- less polymorphic than we'd like + (forall [0] [(0, Type.Num)] ((Type.Var 0, Type.Var 0) `fun2` Type.Var 0)) multiply Infixl 7 , Op "/" - (forall [] ((Type.Int, Type.Int) `fun2` Type.Int)) -- less polymorphic than we'd like + (forall [0] [(0, Type.Num)] ((Type.Var 0, Type.Var 0) `fun2` Type.Var 0)) divide Infixl 7 , Op "<" - (forall [] ((Type.Int, Type.Int) `fun2` Type.Bool)) -- less polymorphic than we'd like + (forall [0] [(0, Type.Num)] ((Type.Var 0, Type.Var 0) `fun2` Type.Bool)) lt Infix 4 , Op "<=" - (forall [] ((Type.Int, Type.Int) `fun2` Type.Bool)) -- less polymorphic than we'd like + (forall [0] [(0, Type.Num)] ((Type.Var 0, Type.Var 0) `fun2` Type.Bool)) le Infix 4 , Op ">" - (forall [] ((Type.Int, Type.Int) `fun2` Type.Bool)) -- less polymorphic than we'd like + (forall [0] [(0, Type.Num)] ((Type.Var 0, Type.Var 0) `fun2` Type.Bool)) gt Infix 4 , Op ">=" - (forall [] ((Type.Int, Type.Int) `fun2` Type.Bool)) -- less polymorphic than we'd like + (forall [0] [(0, Type.Num)] ((Type.Var 0, Type.Var 0) `fun2` Type.Bool)) ge Infix 4 , Op "<>" - (forall [] ((Type.String, Type.String) `fun2` Type.String)) + (forall_ ((Type.String, Type.String) `fun2` Type.String)) (embed0 ((<>) @Text)) Infixr 6 ] diff --git a/src/T/Type.hs b/src/T/Type.hs index 462208a..d9393c1 100644 --- a/src/T/Type.hs +++ b/src/T/Type.hs @@ -5,7 +5,9 @@ module T.Type ( Γ , Type(..) , Scheme(..) + , Constraint(..) , forall + , forall_ , fun1 , fun2 , infer @@ -67,12 +69,19 @@ instance SExp.To Type where Var n -> fromString ('#' : show n) -data Scheme = Forall (Set Int) Type +data Scheme = Forall (Set Int) (HashMap Int (Set Constraint)) Type deriving (Show, Eq) -forall :: [Int] -> Type -> Scheme -forall qs t = - Forall (Set.fromList qs) t +forall :: [Int] -> [(Int, Constraint)] -> Type -> Scheme +forall qs cs t = do + let + cm = + HashMap.fromListWith Set.union (map (\(n, c) -> (n, Set.singleton c)) cs) + Forall (Set.fromList qs) cm t + +forall_ :: Type -> Scheme +forall_ = + forall [] [] fun1 :: Type -> Type -> Type fun1 a1 r = @@ -85,51 +94,49 @@ fun2 (a1, a2) r = data Constraint = Num | Eq - | Show + | Display | Sizeable | Iterable deriving (Show, Eq, Ord) satisfies :: Constraint -> Type -> Bool -satisfies Num = \case - Int -> True - Double -> True - _ -> False -satisfies Eq = \case - Unit -> - True - Bool -> - True - Int -> - True - Double -> - True - String -> - True - Regexp -> - True - Array t -> - satisfies Eq t - Record fs -> - all (satisfies Eq) fs - _ -> - False -satisfies Show = \case - Unit -> True - Bool -> True - Int -> True - Double -> True - String -> True - _ -> False -satisfies Sizeable = \case - String -> True - Array _ -> True - Record _ -> True - _ -> False -satisfies Iterable = \case - Array _ -> True - Record _ -> True - _ -> False +satisfies = \case + Num -> \case + Int -> True + Double -> True + _ -> False + Eq -> \case + Unit -> True + Bool -> True + Int -> True + Double -> True + String -> True + Regexp -> True + Array t -> + satisfies Eq t + Record fs -> + all (satisfies Eq) fs + _ -> False + Display -> \case + Unit -> True + Bool -> True + Int -> True + Double -> True + String -> True + Array t -> + satisfies Display t + Record fs -> + all (satisfies Display) fs + _ -> False + Sizeable -> \case + String -> True + Array _ -> True + Record _ -> True + _ -> False + Iterable -> \case + Array _ -> True + Record _ -> True + _ -> False type TypedExp = Cofree Exp.ExpF (Exp.Ann, Type) @@ -205,8 +212,9 @@ lookupCtx name = do ctx <- ask maybe (throwError (MissingVar name)) instantiate (HashMap.lookup name ctx) -generalize :: Γ -> Type -> Scheme +generalize :: Monad m => Γ -> Type -> InferenceT m Scheme generalize ctx t = do + constraints <- gets (.constraints) let fvs = freeVarsType t @@ -214,7 +222,9 @@ generalize ctx t = do foldMap freeVarsScheme ctx qs = Set.difference fvs ctxvs - Forall qs t + cs = + HashMap.filterWithKey (\k _ -> k `Set.member` qs) constraints + pure (Forall qs cs t) freeVarsType :: Type -> Set Int freeVarsType = \case @@ -230,15 +240,23 @@ freeVarsType = \case Set.empty freeVarsScheme :: Scheme -> Set Int -freeVarsScheme (Forall qs t) = +freeVarsScheme (Forall qs _cs t) = Set.difference (freeVarsType t) qs instantiate :: Monad m => Scheme -> InferenceT m Type -instantiate (Forall vars t) = do - fvs <- for (toList vars) $ \v -> do +instantiate (Forall qs cs t) = do + fvs <- for (toList qs) $ \q -> do fv <- freshVar - pure (v, fv) - pure (replaceOnce (HashMap.fromList fvs) t) + pure (q, fv) + let + localSubst = HashMap.fromList fvs + for_ (HashMap.toList cs) $ \(q, cs) -> + case HashMap.lookup q localSubst of + Just (Var n) -> + modify (\s -> s {constraints = HashMap.insertWith Set.union n cs s.constraints}) + _ -> + pure () + pure (replaceOnce localSubst t) -- | 'replaceOnce' is a variant of 'replace' that doesn't -- do deep substitution; this is necessary separate the namespaces diff --git a/test/T/RenderSpec.hs b/test/T/RenderSpec.hs index 5de5822..68dcff1 100644 --- a/test/T/RenderSpec.hs +++ b/test/T/RenderSpec.hs @@ -24,7 +24,7 @@ import T.Stdlib (def) import T.Stdlib qualified as Stdlib import T.Stdlib.Op qualified as Op import T.SExp (sexp) -import T.Type (forall, fun1, fun2) +import T.Type (forall_, fun1, fun2) import T.Type qualified as Type import T.Value (embedAeson) @@ -408,17 +408,17 @@ rWith json tmplStr = do opExt :: [Stdlib.Op] opExt = [ Stdlib.Op "<+>" - (forall [] ((Type.String, Type.String) `fun2` Type.String)) + (forall_ ((Type.String, Type.String) `fun2` Type.String)) (embed0 (\str0 str1 -> str0 <> "+" <> str1 :: Text)) Stdlib.Infixr 6 ] funExt :: [Stdlib.Fun] funExt = [ Stdlib.Fun "bool01" - (forall [] (Type.Bool `fun1` Type.Int)) + (forall_ (Type.Bool `fun1` Type.Int)) (embed0 (bool @Int 0 1)) , Stdlib.Fun "const" - (forall [] (Type.Bool `fun1` Type.String)) + (forall_ (Type.Bool `fun1` Type.String)) (embed0 (const @Bool @Text)) ] From 20070f21443dc3dadf9ba32c46ef075b55f97500 Mon Sep 17 00:00:00 2001 From: Matvey Aksenov Date: Fri, 16 Jan 2026 10:40:37 +0000 Subject: [PATCH 09/13] re-do constraints --- src/T/Stdlib/Fun.hs | 34 +++++----- src/T/Stdlib/Op.hs | 28 ++++----- src/T/Type.hs | 144 +++++++++++++++++++++++++------------------ test/T/RenderSpec.hs | 8 +-- 4 files changed, 119 insertions(+), 95 deletions(-) diff --git a/src/T/Stdlib/Fun.hs b/src/T/Stdlib/Fun.hs index 26cda7b..5029b53 100644 --- a/src/T/Stdlib/Fun.hs +++ b/src/T/Stdlib/Fun.hs @@ -19,7 +19,7 @@ import T.Error (Error(..)) import T.Exp.Ann ((:+)(..), unann) import T.Name (Name) import T.Prelude -import T.Type (Γ, forall, forall_, fun1, fun2) +import T.Type (Γ, forAll, forAll_, fun1, fun2, tyVar) import T.Type qualified as Type import T.Value (Value(..), display, displayWith) @@ -41,57 +41,57 @@ typingCtx = functions :: [Fun] functions = [ Fun "empty?" - (forall [0] [(0, Type.Sizeable)] (Type.Var 0 `fun1` Type.Bool)) + (forAll [0] [(0, Type.Sizeable)] (tyVar 0 `fun1` Type.Bool)) nullB , Fun "length" - (forall [0] [(0, Type.Sizeable)] (Type.Var 0 `fun1` Type.Int)) + (forAll [0] [(0, Type.Sizeable)] (tyVar 0 `fun1` Type.Int)) lengthB , Fun "floor" - (forall_ (Type.Double `fun1` Type.Int)) + (forAll_ (Type.Double `fun1` Type.Int)) (embed0 (floor @Double @Int)) , Fun "ceiling" - (forall_ (Type.Double `fun1` Type.Int)) + (forAll_ (Type.Double `fun1` Type.Int)) (embed0 (ceiling @Double @Int)) , Fun "round" - (forall_ (Type.Double `fun1` Type.Int)) + (forAll_ (Type.Double `fun1` Type.Int)) (embed0 (round @Double @Int)) , Fun "int->double" - (forall_ (Type.Int `fun1` Type.Double)) + (forAll_ (Type.Int `fun1` Type.Double)) (embed0 (fromIntegral @Int @Double)) , Fun "upper-case" - (forall_ (Type.String `fun1` Type.String)) + (forAll_ (Type.String `fun1` Type.String)) (embed0 Text.toUpper) , Fun "lower-case" - (forall_ (Type.String `fun1` Type.String)) + (forAll_ (Type.String `fun1` Type.String)) (embed0 Text.toLower) , Fun "title-case" - (forall_ (Type.String `fun1` Type.String)) + (forAll_ (Type.String `fun1` Type.String)) (embed0 Text.toTitle) , Fun "split" - (forall_ ((Type.String, Type.String) `fun2` Type.Array Type.String)) + (forAll_ ((Type.String, Type.String) `fun2` Type.Array Type.String)) (embed0 Text.splitOn) , Fun "join" - (forall_ ((Type.String, Type.Array Type.String) `fun2` Type.String)) + (forAll_ ((Type.String, Type.Array Type.String) `fun2` Type.String)) (embed0 Text.intercalate) , Fun "concat" - (forall_ (Type.Array Type.String `fun1` Type.String)) + (forAll_ (Type.Array Type.String `fun1` Type.String)) (embed0 Text.concat) , Fun "chunks-of" - (forall_ ((Type.Int, Type.String) `fun2` Type.Array Type.String)) + (forAll_ ((Type.Int, Type.String) `fun2` Type.Array Type.String)) (embed0 Text.chunksOf) , Fun "die" - (forall [0] [] (Type.String `fun1` Type.Var 0)) + (forAll [0] [] (Type.String `fun1` tyVar 0)) dieB , Fun "show" - (forall [0] [(0, Type.Display)] (Type.Var 0 `fun1` Type.String)) + (forAll [0] [(0, Type.Display)] (tyVar 0 `fun1` Type.String)) (embed0 showB) , Fun "pp" - (forall [0] [(0, Type.Display)] (Type.Var 0 `fun1` Type.String)) + (forAll [0] [(0, Type.Display)] (tyVar 0 `fun1` Type.String)) (embed0 ppB) ] diff --git a/src/T/Stdlib/Op.hs b/src/T/Stdlib/Op.hs index 196562a..a8639c8 100644 --- a/src/T/Stdlib/Op.hs +++ b/src/T/Stdlib/Op.hs @@ -26,7 +26,7 @@ import T.Exp.Ann ((:+)(..)) import T.Name (Name) import T.Prelude import T.SExp (sexp) -import T.Type (Γ, forall, forall_, fun1, fun2) +import T.Type (Γ, forAll, forAll_, fun1, fun2, tyVar) import T.Type qualified as Type import T.Value (Value(..), typeOf) @@ -63,47 +63,47 @@ priorities = operators :: [Op] operators = [ Op "!" - (forall_ (Type.Bool `fun1` Type.Bool)) + (forAll_ (Type.Bool `fun1` Type.Bool)) (embed0 not) Prefix 8 , Op "==" - (forall [0] [(0, Type.Eq)] ((Type.Var 0, Type.Var 0) `fun2` Type.Bool)) + (forAll [0] [(0, Type.Eq)] ((tyVar 0, tyVar 0) `fun2` Type.Bool)) (embed0 eq) Infix 4 , Op "!=" - (forall [0] [(0, Type.Eq)] ((Type.Var 0, Type.Var 0) `fun2` Type.Bool)) + (forAll [0] [(0, Type.Eq)] ((tyVar 0, tyVar 0) `fun2` Type.Bool)) (embed0 neq) Infix 4 , Op "=~" - (forall_ ((Type.String, Type.Regexp) `fun2` Type.Bool)) + (forAll_ ((Type.String, Type.Regexp) `fun2` Type.Bool)) (embed0 match) Infix 4 , Op "+" - (forall [0] [(0, Type.Num)] ((Type.Var 0, Type.Var 0) `fun2` Type.Var 0)) + (forAll [0] [(0, Type.Num)] ((tyVar 0, tyVar 0) `fun2` tyVar 0)) add Infixl 6 , Op "-" - (forall [0] [(0, Type.Num)] ((Type.Var 0, Type.Var 0) `fun2` Type.Var 0)) + (forAll [0] [(0, Type.Num)] ((tyVar 0, tyVar 0) `fun2` tyVar 0)) subtract Infixl 6 , Op "*" - (forall [0] [(0, Type.Num)] ((Type.Var 0, Type.Var 0) `fun2` Type.Var 0)) + (forAll [0] [(0, Type.Num)] ((tyVar 0, tyVar 0) `fun2` tyVar 0)) multiply Infixl 7 , Op "/" - (forall [0] [(0, Type.Num)] ((Type.Var 0, Type.Var 0) `fun2` Type.Var 0)) + (forAll [0] [(0, Type.Num)] ((tyVar 0, tyVar 0) `fun2` tyVar 0)) divide Infixl 7 , Op "<" - (forall [0] [(0, Type.Num)] ((Type.Var 0, Type.Var 0) `fun2` Type.Bool)) + (forAll [0] [(0, Type.Num)] ((tyVar 0, tyVar 0) `fun2` Type.Bool)) lt Infix 4 , Op "<=" - (forall [0] [(0, Type.Num)] ((Type.Var 0, Type.Var 0) `fun2` Type.Bool)) + (forAll [0] [(0, Type.Num)] ((tyVar 0, tyVar 0) `fun2` Type.Bool)) le Infix 4 , Op ">" - (forall [0] [(0, Type.Num)] ((Type.Var 0, Type.Var 0) `fun2` Type.Bool)) + (forAll [0] [(0, Type.Num)] ((tyVar 0, tyVar 0) `fun2` Type.Bool)) gt Infix 4 , Op ">=" - (forall [0] [(0, Type.Num)] ((Type.Var 0, Type.Var 0) `fun2` Type.Bool)) + (forAll [0] [(0, Type.Num)] ((tyVar 0, tyVar 0) `fun2` Type.Bool)) ge Infix 4 , Op "<>" - (forall_ ((Type.String, Type.String) `fun2` Type.String)) + (forAll_ ((Type.String, Type.String) `fun2` Type.String)) (embed0 ((<>) @Text)) Infixr 6 ] diff --git a/src/T/Type.hs b/src/T/Type.hs index d9393c1..571c373 100644 --- a/src/T/Type.hs +++ b/src/T/Type.hs @@ -6,10 +6,11 @@ module T.Type , Type(..) , Scheme(..) , Constraint(..) - , forall - , forall_ + , forAll + , forAll_ , fun1 , fun2 + , tyVar , infer , TypeError , extractType @@ -19,10 +20,12 @@ import Control.Monad (foldM) import Control.Monad.Reader (ReaderT, MonadReader, runReaderT, ask) import Control.Monad.State (StateT, MonadState, runStateT, gets, modify) import Control.Monad.Except (ExceptT, MonadError, runExceptT, throwError) +import Data.List qualified as List import Data.List.NonEmpty qualified as NonEmpty import Data.Set (Set) import Data.Set qualified as Set import Data.HashMap.Strict qualified as HashMap +import Text.Printf (printf) import T.Exp (Exp, (:+)(..)) import T.Exp qualified as Exp @@ -43,7 +46,7 @@ data Type | Array Type | Record (HashMap Name Type) | Fun (NonEmpty Type) Type - | Var Int + | Var Int (Set Constraint) deriving (Show, Eq) instance SExp.To Type where @@ -66,22 +69,37 @@ instance SExp.To Type where SExp.curly (concatMap (\(k, v) -> [sexp k, sexp v]) (HashMap.toList fs)) Fun args r -> SExp.round ["->", SExp.square (toList (map sexp args)), sexp r] - Var n -> - fromString ('#' : show n) + Var n cs -> + fromString (printf "#%d{%s}" n (List.intercalate ", " (map show (toList cs)))) -data Scheme = Forall (Set Int) (HashMap Int (Set Constraint)) Type +data Scheme = Forall (Set Int) Type deriving (Show, Eq) -forall :: [Int] -> [(Int, Constraint)] -> Type -> Scheme -forall qs cs t = do +forAll :: [Int] -> [(Int, Constraint)] -> Type -> Scheme +forAll qs cs t = do let cm = - HashMap.fromListWith Set.union (map (\(n, c) -> (n, Set.singleton c)) cs) - Forall (Set.fromList qs) cm t + HashMap.fromListWith Set.union (map (\(i, c) -> (i, Set.singleton c)) cs) + Forall (Set.fromList qs) (injectConstraints cm t) + +injectConstraints :: HashMap Int (Set Constraint) -> Type -> Type +injectConstraints cm t = case t of + Var n cs -> do + let + extra = HashMap.findWithDefault Set.empty n cm + Var n (cs <> extra) + Array arr -> + Array (injectConstraints cm arr) + Record r -> + Record (map (injectConstraints cm) r) + Fun args r -> + Fun (map (injectConstraints cm) args) (injectConstraints cm r) + _ -> + t -forall_ :: Type -> Scheme -forall_ = - forall [] [] +forAll_ :: Type -> Scheme +forAll_ = + forAll [] [] fun1 :: Type -> Type -> Type fun1 a1 r = @@ -91,6 +109,10 @@ fun2 :: (Type, Type) -> Type -> Type fun2 (a1, a2) r = Fun (a1 :| a2 : []) r +tyVar :: Int -> Type +tyVar n = + Var n mempty + data Constraint = Num | Eq @@ -104,6 +126,7 @@ satisfies = \case Num -> \case Int -> True Double -> True + Var _n cs -> Num `Set.member` cs _ -> False Eq -> \case Unit -> True @@ -116,6 +139,7 @@ satisfies = \case satisfies Eq t Record fs -> all (satisfies Eq) fs + Var _n cs -> Eq `Set.member` cs _ -> False Display -> \case Unit -> True @@ -127,15 +151,18 @@ satisfies = \case satisfies Display t Record fs -> all (satisfies Display) fs + Var _n cs -> Display `Set.member` cs _ -> False Sizeable -> \case String -> True Array _ -> True Record _ -> True + Var _n cs -> Sizeable `Set.member` cs _ -> False Iterable -> \case Array _ -> True Record _ -> True + Var _n cs -> Iterable `Set.member` cs _ -> False type TypedExp = Cofree Exp.ExpF (Exp.Ann, Type) @@ -146,9 +173,8 @@ newtype InferenceT m a = InferenceT (ReaderT Γ (StateT Σ (ExceptT TypeError m) type Γ = HashMap Name Scheme data Σ = Σ - { subst :: Subst - , constraints :: HashMap Int (Set Constraint) - , counter :: Int + { subst :: Subst + , counter :: Int } deriving (Show, Eq) type Subst = HashMap Int Type @@ -156,7 +182,6 @@ type Subst = HashMap Int Type emptyΣ :: Σ emptyΣ = Σ { subst = mempty - , constraints = mempty , counter = 0 } @@ -214,7 +239,6 @@ lookupCtx name = do generalize :: Monad m => Γ -> Type -> InferenceT m Scheme generalize ctx t = do - constraints <- gets (.constraints) let fvs = freeVarsType t @@ -222,13 +246,11 @@ generalize ctx t = do foldMap freeVarsScheme ctx qs = Set.difference fvs ctxvs - cs = - HashMap.filterWithKey (\k _ -> k `Set.member` qs) constraints - pure (Forall qs cs t) + pure (Forall qs t) freeVarsType :: Type -> Set Int freeVarsType = \case - Var n -> + Var n _cs -> Set.singleton n Array t -> freeVarsType t @@ -240,23 +262,15 @@ freeVarsType = \case Set.empty freeVarsScheme :: Scheme -> Set Int -freeVarsScheme (Forall qs _cs t) = +freeVarsScheme (Forall qs t) = Set.difference (freeVarsType t) qs instantiate :: Monad m => Scheme -> InferenceT m Type -instantiate (Forall qs cs t) = do +instantiate (Forall qs t) = do fvs <- for (toList qs) $ \q -> do fv <- freshVar pure (q, fv) - let - localSubst = HashMap.fromList fvs - for_ (HashMap.toList cs) $ \(q, cs) -> - case HashMap.lookup q localSubst of - Just (Var n) -> - modify (\s -> s {constraints = HashMap.insertWith Set.union n cs s.constraints}) - _ -> - pure () - pure (replaceOnce localSubst t) + pure (replaceOnce (HashMap.fromList fvs) t) -- | 'replaceOnce' is a variant of 'replace' that doesn't -- do deep substitution; this is necessary separate the namespaces @@ -271,8 +285,14 @@ replaceOnce subst t = Record (map (replaceOnce subst) r) Fun args r -> Fun (map (replaceOnce subst) args) (replaceOnce subst r) - Var n -> - HashMap.findWithDefault (Var n) n subst + Var n cs -> + case HashMap.lookup n subst of + Nothing -> + Var n cs + Just (Var m ds) -> + Var m (Set.union cs ds) + Just x -> + applyConstraints cs x _ -> t @@ -286,10 +306,10 @@ unify t1 t2 = do (a, b) | a == b -> pure a - (Var n, t) -> do - unifyVar s n t - (t, Var n) -> do - unifyVar s n t + (Var n cs, t) -> do + unifyVar s n cs t + (t, Var n cs) -> do + unifyVar s n cs t (Array a, Array b) -> map Array (unify a b) (Record m1, Record m2) -> do @@ -300,28 +320,25 @@ unify t1 t2 = do (a, b) -> throwError (TypeMismatch a b) -unifyVar :: Monad m => Subst -> Int -> Type -> InferenceT m Type -unifyVar s n t = do - when (occurs n t s) (throwError (OccursCheck n t)) - cs <- gets (HashMap.findWithDefault Set.empty n . (.constraints)) - checkConstraints cs t +unifyVar :: Monad m => Subst -> Int -> Set Constraint -> Type -> InferenceT m Type +unifyVar s n cs t0 = do + when (occurs n t0 s) (throwError (OccursCheck n t0)) + checkConstraints cs t0 + let + t = + applyConstraints cs t0 extendSubst n t pure t checkConstraints :: Monad m => Set Constraint -> Type -> InferenceT m () -checkConstraints cs t0 = do - subst <- gets (.subst) - case replace subst t0 of - Var m -> - modify (\s -> s {constraints = HashMap.insertWith Set.union m cs s.constraints}) - t -> - for_ cs $ \c -> - unless (satisfies c t) (throwError (ConstraintViolation c t)) +checkConstraints cs t = do + for_ cs $ \c -> + unless (satisfies c t) (throwError (ConstraintViolation c t)) occurs :: Int -> Type -> Subst -> Bool occurs n t subst = case replace subst t of - Var m -> + Var m _cs -> n == m Array arr -> occurs n arr subst @@ -341,15 +358,22 @@ replace subst t = Record (map (replace subst) r) Fun args r -> Fun (map (replace subst) args) (replace subst r) - Var n -> + Var n cs -> case HashMap.lookup n subst of Nothing -> - Var n + Var n cs Just nt -> - replace subst nt + applyConstraints cs (replace subst nt) _ -> t +applyConstraints :: Set Constraint -> Type -> Type +applyConstraints cs = \case + Var n vcs -> + Var n (Set.union vcs cs) + t -> + t + extendSubst :: Monad m => Int -> Type -> InferenceT m () extendSubst n t = modify (\s -> s { subst = HashMap.insert n t s.subst }) @@ -372,7 +396,7 @@ freshVar :: Monad m => InferenceT m Type freshVar = do n <- gets (.counter) modify (\s -> s { counter = s.counter + 1 }) - pure (Var n) + pure (Var n mempty) checkKey :: Monad m => TypedExp -> Name -> InferenceT m Type checkKey r name = do @@ -383,9 +407,9 @@ checkKey r name = do pure t Nothing -> throwError (MissingKey name) - Var n -> do + var@(Var _n _cs) -> do v <- freshVar - _ <- unify (Var n) (Record (HashMap.singleton name v)) + _ <- unify var (Record (HashMap.singleton name v)) pure v _ -> throwError (NotARecord (extractType r)) @@ -416,7 +440,7 @@ finalize subst = defaultType :: Type -> Type defaultType = \case - Var _ -> + Var {} -> Unit Array t -> Array (defaultType t) diff --git a/test/T/RenderSpec.hs b/test/T/RenderSpec.hs index 68dcff1..b1ae1e7 100644 --- a/test/T/RenderSpec.hs +++ b/test/T/RenderSpec.hs @@ -24,7 +24,7 @@ import T.Stdlib (def) import T.Stdlib qualified as Stdlib import T.Stdlib.Op qualified as Op import T.SExp (sexp) -import T.Type (forall_, fun1, fun2) +import T.Type (forAll_, fun1, fun2) import T.Type qualified as Type import T.Value (embedAeson) @@ -408,17 +408,17 @@ rWith json tmplStr = do opExt :: [Stdlib.Op] opExt = [ Stdlib.Op "<+>" - (forall_ ((Type.String, Type.String) `fun2` Type.String)) + (forAll_ ((Type.String, Type.String) `fun2` Type.String)) (embed0 (\str0 str1 -> str0 <> "+" <> str1 :: Text)) Stdlib.Infixr 6 ] funExt :: [Stdlib.Fun] funExt = [ Stdlib.Fun "bool01" - (forall_ (Type.Bool `fun1` Type.Int)) + (forAll_ (Type.Bool `fun1` Type.Int)) (embed0 (bool @Int 0 1)) , Stdlib.Fun "const" - (forall_ (Type.Bool `fun1` Type.String)) + (forAll_ (Type.Bool `fun1` Type.String)) (embed0 (const @Bool @Text)) ] From f52a180833ab7d10c7945448889ebfd0f0dd7a56 Mon Sep 17 00:00:00 2001 From: Matvey Aksenov Date: Mon, 19 Jan 2026 21:27:16 +0000 Subject: [PATCH 10/13] no god modules --- src/T/Prelude.hs | 4 - src/T/Render.hs | 4 +- src/T/Type.hs | 402 ++------------------------------------------ src/T/Type/Infer.hs | 132 +++++++++++++++ src/T/Type/Unify.hs | 145 ++++++++++++++++ src/T/Type/Vocab.hs | 182 ++++++++++++++++++++ src/T/Value.hs | 9 +- 7 files changed, 484 insertions(+), 394 deletions(-) create mode 100644 src/T/Type/Infer.hs create mode 100644 src/T/Type/Unify.hs create mode 100644 src/T/Type/Vocab.hs diff --git a/src/T/Prelude.hs b/src/T/Prelude.hs index 98a8bb3..7eed0a7 100644 --- a/src/T/Prelude.hs +++ b/src/T/Prelude.hs @@ -64,10 +64,7 @@ module T.Prelude , filter , first , flip - , foldl' - , foldr , foldM_ - , foldr1 , for , for_ , fromIntegral @@ -83,7 +80,6 @@ module T.Prelude , second , seq , sequence - , toList , traverse , traverse_ , T.Prelude.traceShow diff --git a/src/T/Render.hs b/src/T/Render.hs index 5b389db..e9c9625 100644 --- a/src/T/Render.hs +++ b/src/T/Render.hs @@ -243,7 +243,7 @@ enforceArray exp = do Value.Array xs -> pure xs _ -> - throwError (TypeError exp (error "Type.Array") (Value.typeOf v) (sexp v)) + throwError (TypeError exp (Type.Array (Type.tyVar 0)) (Value.typeOf v) (sexp v)) enforceRecord :: (Ctx m, MonadError Error m) => Exp -> m (HashMap Name Value) enforceRecord exp = do @@ -320,7 +320,7 @@ evalLValue = Just v -> (Right v, path) Right v -> - throwError (TypeError exp (error "Type.Array") (Value.typeOf v) (sexp v)) + throwError (TypeError exp (Type.Array (Type.tyVar 0)) (Value.typeOf v) (sexp v)) Left name -> throwError (NotInScope name) _ :< Key exp0 key@(_ :+ key0) -> do diff --git a/src/T/Type.hs b/src/T/Type.hs index 571c373..c12bb2c 100644 --- a/src/T/Type.hs +++ b/src/T/Type.hs @@ -6,74 +6,33 @@ module T.Type , Type(..) , Scheme(..) , Constraint(..) + , infer + , TypeError(..) + -- * stdlib helpers , forAll , forAll_ , fun1 , fun2 , tyVar - , infer - , TypeError - , extractType ) where -import Control.Monad (foldM) -import Control.Monad.Reader (ReaderT, MonadReader, runReaderT, ask) -import Control.Monad.State (StateT, MonadState, runStateT, gets, modify) -import Control.Monad.Except (ExceptT, MonadError, runExceptT, throwError) -import Data.List qualified as List -import Data.List.NonEmpty qualified as NonEmpty +import Data.HashMap.Strict qualified as HashMap import Data.Set (Set) import Data.Set qualified as Set -import Data.HashMap.Strict qualified as HashMap -import Text.Printf (printf) -import T.Exp (Exp, (:+)(..)) -import T.Exp qualified as Exp -import T.Name (Name) import T.Prelude -import T.SExp (sexp) -import T.SExp qualified as SExp -import T.Tmpl (Tmpl(..)) - - -data Type - = Unit - | Bool - | Int - | Double - | String - | Regexp - | Array Type - | Record (HashMap Name Type) - | Fun (NonEmpty Type) Type - | Var Int (Set Constraint) - deriving (Show, Eq) - -instance SExp.To Type where - sexp = \case - Unit -> - "unit" - Bool -> - "bool" - Int -> - "int" - Double -> - "double" - String -> - "string" - Regexp -> - "regexp" - Array t -> - SExp.square [sexp t] - Record fs -> - SExp.curly (concatMap (\(k, v) -> [sexp k, sexp v]) (HashMap.toList fs)) - Fun args r -> - SExp.round ["->", SExp.square (toList (map sexp args)), sexp r] - Var n cs -> - fromString (printf "#%d{%s}" n (List.intercalate ", " (map show (toList cs)))) +import T.Type.Vocab + ( InferenceT(..) + , Γ + , TypeError(..) + , Type(..) + , freeVarsType + , Scheme(..) + , freeVarsScheme + , Constraint(..) + ) +import T.Type.Infer (infer) -data Scheme = Forall (Set Int) Type - deriving (Show, Eq) forAll :: [Int] -> [(Int, Constraint)] -> Type -> Scheme forAll qs cs t = do @@ -84,10 +43,8 @@ forAll qs cs t = do injectConstraints :: HashMap Int (Set Constraint) -> Type -> Type injectConstraints cm t = case t of - Var n cs -> do - let - extra = HashMap.findWithDefault Set.empty n cm - Var n (cs <> extra) + Var n cs -> + Var n (cs <> HashMap.findWithDefault Set.empty n cm) Array arr -> Array (injectConstraints cm arr) Record r -> @@ -113,130 +70,6 @@ tyVar :: Int -> Type tyVar n = Var n mempty -data Constraint - = Num - | Eq - | Display - | Sizeable - | Iterable - deriving (Show, Eq, Ord) - -satisfies :: Constraint -> Type -> Bool -satisfies = \case - Num -> \case - Int -> True - Double -> True - Var _n cs -> Num `Set.member` cs - _ -> False - Eq -> \case - Unit -> True - Bool -> True - Int -> True - Double -> True - String -> True - Regexp -> True - Array t -> - satisfies Eq t - Record fs -> - all (satisfies Eq) fs - Var _n cs -> Eq `Set.member` cs - _ -> False - Display -> \case - Unit -> True - Bool -> True - Int -> True - Double -> True - String -> True - Array t -> - satisfies Display t - Record fs -> - all (satisfies Display) fs - Var _n cs -> Display `Set.member` cs - _ -> False - Sizeable -> \case - String -> True - Array _ -> True - Record _ -> True - Var _n cs -> Sizeable `Set.member` cs - _ -> False - Iterable -> \case - Array _ -> True - Record _ -> True - Var _n cs -> Iterable `Set.member` cs - _ -> False - -type TypedExp = Cofree Exp.ExpF (Exp.Ann, Type) - -newtype InferenceT m a = InferenceT (ReaderT Γ (StateT Σ (ExceptT TypeError m)) a) - deriving (Functor, Applicative, Monad, MonadReader Γ, MonadState Σ, MonadError TypeError) - -type Γ = HashMap Name Scheme - -data Σ = Σ - { subst :: Subst - , counter :: Int - } deriving (Show, Eq) - -type Subst = HashMap Int Type - -emptyΣ :: Σ -emptyΣ = Σ - { subst = mempty - , counter = 0 - } - -data TypeError - = MissingKey Name - | MissingVar Name - | NotARecord Type - | TypeMismatch Type Type - | ConstraintViolation Constraint Type - | OccursCheck Int Type - deriving (Show, Eq) - -runInferenceT :: Γ -> Σ -> InferenceT m a -> m (Either TypeError (a, Σ)) -runInferenceT ctx subst (InferenceT m) = - runExceptT (runStateT (runReaderT m ctx) subst) - -infer :: Γ -> Exp -> Either TypeError TypedExp -infer ctx exp = do - (te, finalΣ) <- runIdentity (runInferenceT ctx emptyΣ (inferExp exp)) - pure (finalize finalΣ.subst te) - -inferTmpl :: Monad m => Tmpl -> InferenceT m () -inferTmpl = \case - Raw _text -> - pure () - Comment _text -> - pure () - Exp exp -> do - _texp <- inferExp exp - pure () - -inferExp :: Monad m => Exp -> InferenceT m TypedExp -inferExp (ann :< e) = do - te <- traverse inferExp e - inferredType <- case te of - Exp.Lit l -> - inferLiteral l - Exp.Var (_ann :+ name) -> - lookupCtx name - Exp.If b t f -> do - _ <- unify (extractType b) Bool - unify (extractType t) (extractType f) - Exp.App (_ann :+ name) args -> - checkApp name args - Exp.Idx arr idx -> - checkIdx arr idx - Exp.Key r (_ann :+ name) -> - checkKey r name - pure ((ann, inferredType) :< te) - -lookupCtx :: Monad m => Name -> InferenceT m Type -lookupCtx name = do - ctx <- ask - maybe (throwError (MissingVar name)) instantiate (HashMap.lookup name ctx) - generalize :: Monad m => Γ -> Type -> InferenceT m Scheme generalize ctx t = do let @@ -247,204 +80,3 @@ generalize ctx t = do qs = Set.difference fvs ctxvs pure (Forall qs t) - -freeVarsType :: Type -> Set Int -freeVarsType = \case - Var n _cs -> - Set.singleton n - Array t -> - freeVarsType t - Record r -> - foldMap freeVarsType r - Fun args r -> - foldMap freeVarsType args <> freeVarsType r - _ -> - Set.empty - -freeVarsScheme :: Scheme -> Set Int -freeVarsScheme (Forall qs t) = - Set.difference (freeVarsType t) qs - -instantiate :: Monad m => Scheme -> InferenceT m Type -instantiate (Forall qs t) = do - fvs <- for (toList qs) $ \q -> do - fv <- freshVar - pure (q, fv) - pure (replaceOnce (HashMap.fromList fvs) t) - --- | 'replaceOnce' is a variant of 'replace' that doesn't --- do deep substitution; this is necessary separate the namespaces --- of quantified variables and unitification variables which is --- useful for e.g. stdlib definitions -replaceOnce :: Subst -> Type -> Type -replaceOnce subst t = - case t of - Array arr -> - Array (replaceOnce subst arr) - Record r -> - Record (map (replaceOnce subst) r) - Fun args r -> - Fun (map (replaceOnce subst) args) (replaceOnce subst r) - Var n cs -> - case HashMap.lookup n subst of - Nothing -> - Var n cs - Just (Var m ds) -> - Var m (Set.union cs ds) - Just x -> - applyConstraints cs x - _ -> - t - -extractType :: TypedExp -> Type -extractType ((_ann, t) :< _e) = t - -unify :: Monad m => Type -> Type -> InferenceT m Type -unify t1 t2 = do - s <- gets (.subst) - case (replace s t1, replace s t2) of - (a, b) - | a == b -> - pure a - (Var n cs, t) -> do - unifyVar s n cs t - (t, Var n cs) -> do - unifyVar s n cs t - (Array a, Array b) -> - map Array (unify a b) - (Record m1, Record m2) -> do - map Record (sequence (HashMap.intersectionWith unify m1 m2)) - (Fun args1 ret1, Fun args2 ret2) - | NonEmpty.length args1 == NonEmpty.length args2 -> - liftA2 Fun (sequence (NonEmpty.zipWith unify args1 args2)) (unify ret1 ret2) - (a, b) -> - throwError (TypeMismatch a b) - -unifyVar :: Monad m => Subst -> Int -> Set Constraint -> Type -> InferenceT m Type -unifyVar s n cs t0 = do - when (occurs n t0 s) (throwError (OccursCheck n t0)) - checkConstraints cs t0 - let - t = - applyConstraints cs t0 - extendSubst n t - pure t - -checkConstraints :: Monad m => Set Constraint -> Type -> InferenceT m () -checkConstraints cs t = do - for_ cs $ \c -> - unless (satisfies c t) (throwError (ConstraintViolation c t)) - -occurs :: Int -> Type -> Subst -> Bool -occurs n t subst = - case replace subst t of - Var m _cs -> - n == m - Array arr -> - occurs n arr subst - Record r -> - any (\t' -> occurs n t' subst) r - Fun args r -> - any (\a -> occurs n a subst) args || occurs n r subst - _ -> - False - -replace :: Subst -> Type -> Type -replace subst t = - case t of - Array arr -> - Array (replace subst arr) - Record r -> - Record (map (replace subst) r) - Fun args r -> - Fun (map (replace subst) args) (replace subst r) - Var n cs -> - case HashMap.lookup n subst of - Nothing -> - Var n cs - Just nt -> - applyConstraints cs (replace subst nt) - _ -> - t - -applyConstraints :: Set Constraint -> Type -> Type -applyConstraints cs = \case - Var n vcs -> - Var n (Set.union vcs cs) - t -> - t - -extendSubst :: Monad m => Int -> Type -> InferenceT m () -extendSubst n t = - modify (\s -> s { subst = HashMap.insert n t s.subst }) - -checkApp :: Monad m => Name -> NonEmpty TypedExp -> InferenceT m Type -checkApp name args = do - ft <- lookupCtx name - r <- freshVar - _ <- unify ft (Fun (map extractType args) r) - pure r - -checkIdx :: Monad m => TypedExp -> TypedExp -> InferenceT m Type -checkIdx arr idx = do - e <- freshVar - _ <- unify (extractType arr) (Array e) - _ <- unify (extractType idx) Int - pure e - -freshVar :: Monad m => InferenceT m Type -freshVar = do - n <- gets (.counter) - modify (\s -> s { counter = s.counter + 1 }) - pure (Var n mempty) - -checkKey :: Monad m => TypedExp -> Name -> InferenceT m Type -checkKey r name = do - case extractType r of - Record fields -> - case HashMap.lookup name fields of - Just t -> - pure t - Nothing -> - throwError (MissingKey name) - var@(Var _n _cs) -> do - v <- freshVar - _ <- unify var (Record (HashMap.singleton name v)) - pure v - _ -> - throwError (NotARecord (extractType r)) - -inferLiteral :: Monad m => Exp.Literal -> InferenceT m Type -inferLiteral = \case - Exp.Null -> pure Unit - Exp.Bool _ -> pure Bool - Exp.Int _ -> pure Int - Exp.Double _ -> pure Double - Exp.String _ -> pure String - Exp.Regexp _ -> pure Regexp - Exp.Array xs -> do - ys <- traverse inferExp xs - t <- case toList (map extractType ys) of - [] -> - freshVar - z : zs -> - foldM unify z zs - pure (Array t) - Exp.Record r -> do - ts <- traverse inferExp r - pure (Record (map extractType ts)) - -finalize :: Subst -> TypedExp -> TypedExp -finalize subst = - map (\(ann, t) -> (ann, defaultType (replace subst t))) - -defaultType :: Type -> Type -defaultType = \case - Var {} -> - Unit - Array t -> - Array (defaultType t) - Record r -> - Record (map defaultType r) - t -> - t diff --git a/src/T/Type/Infer.hs b/src/T/Type/Infer.hs new file mode 100644 index 0000000..1a62535 --- /dev/null +++ b/src/T/Type/Infer.hs @@ -0,0 +1,132 @@ +{-# LANGUAGE OverloadedRecordDot #-} +module T.Type.Infer + ( infer + ) where + +import Control.Monad (foldM) +import Control.Monad.Reader (ask) +import Control.Monad.Except (throwError) +import Data.HashMap.Strict qualified as HashMap + +import T.Exp (Exp, (:+)(..)) +import T.Exp qualified as Exp +import T.Name (Name) +import T.Prelude +import T.Tmpl (Tmpl(..)) +import T.Type.Vocab + ( InferenceT + , runInferenceT + , Γ + , Σ(..) + , emptyΣ + , freshVar + , TypeError(..) + , TypedExp + , Type(..) + , Scheme(..) + ) +import T.Type.Unify + ( unify + , replaceOnce + , finalize + ) + + +infer :: Γ -> Exp -> Either TypeError TypedExp +infer ctx exp = do + (te, finalΣ) <- runIdentity (runInferenceT ctx emptyΣ (inferExp exp)) + pure (finalize finalΣ.subst te) + +inferTmpl :: Monad m => Tmpl -> InferenceT m () +inferTmpl = \case + Raw _text -> + pure () + Comment _text -> + pure () + Exp exp -> do + _texp <- inferExp exp + pure () + +inferExp :: Monad m => Exp -> InferenceT m TypedExp +inferExp (ann :< e) = do + te <- traverse inferExp e + inferredType <- case te of + Exp.Lit l -> + inferLiteral l + Exp.Var (_ann :+ name) -> + lookupCtx name + Exp.If b t f -> do + _ <- unify (extractType b) Bool + unify (extractType t) (extractType f) + Exp.App (_ann :+ name) args -> + checkApp name args + Exp.Idx arr idx -> + checkIdx arr idx + Exp.Key r (_ann :+ name) -> + checkKey r name + pure ((ann, inferredType) :< te) + +checkApp :: Monad m => Name -> NonEmpty TypedExp -> InferenceT m Type +checkApp name args = do + ft <- lookupCtx name + r <- freshVar + _ <- unify ft (Fun (map extractType args) r) + pure r + +checkIdx :: Monad m => TypedExp -> TypedExp -> InferenceT m Type +checkIdx arr idx = do + e <- freshVar + _ <- unify (extractType arr) (Array e) + _ <- unify (extractType idx) Int + pure e + +checkKey :: Monad m => TypedExp -> Name -> InferenceT m Type +checkKey r name = do + case extractType r of + Record fields -> + case HashMap.lookup name fields of + Just t -> + pure t + Nothing -> + throwError (MissingKey name) + var@(Var _n _cs) -> do + v <- freshVar + _ <- unify var (Record (HashMap.singleton name v)) + pure v + _ -> + throwError (NotARecord (extractType r)) + +inferLiteral :: Monad m => Exp.Literal -> InferenceT m Type +inferLiteral = \case + Exp.Null -> pure Unit + Exp.Bool _ -> pure Bool + Exp.Int _ -> pure Int + Exp.Double _ -> pure Double + Exp.String _ -> pure String + Exp.Regexp _ -> pure Regexp + Exp.Array xs -> do + ys <- traverse inferExp xs + t <- case toList (map extractType ys) of + [] -> + freshVar + z : zs -> + foldM unify z zs + pure (Array t) + Exp.Record r -> do + ts <- traverse inferExp r + pure (Record (map extractType ts)) + +lookupCtx :: Monad m => Name -> InferenceT m Type +lookupCtx name = do + ctx <- ask + maybe (throwError (MissingVar name)) instantiate (HashMap.lookup name ctx) + +instantiate :: Monad m => Scheme -> InferenceT m Type +instantiate (Forall qs t) = do + fvs <- for (toList qs) $ \q -> do + fv <- freshVar + pure (q, fv) + pure (replaceOnce (HashMap.fromList fvs) t) + +extractType :: TypedExp -> Type +extractType ((_ann, t) :< _e) = t diff --git a/src/T/Type/Unify.hs b/src/T/Type/Unify.hs new file mode 100644 index 0000000..e2710dc --- /dev/null +++ b/src/T/Type/Unify.hs @@ -0,0 +1,145 @@ +{-# LANGUAGE OverloadedRecordDot #-} +module T.Type.Unify + ( unify + , replace + , replaceOnce + , finalize + ) where + +import Control.Monad.Except (throwError) +import Control.Monad.State (gets, modify) +import Data.HashMap.Strict qualified as HashMap +import Data.List.NonEmpty qualified as NonEmpty +import Data.Set (Set) +import Data.Set qualified as Set + +import T.Prelude +import T.Type.Vocab + ( InferenceT + , Σ(..) + , Subst + , TypedExp + , Type(..) + , TypeError(..) + , Constraint + , satisfies + ) + + +unify :: Monad m => Type -> Type -> InferenceT m Type +unify t1 t2 = do + s <- gets (.subst) + case (replace s t1, replace s t2) of + (a, b) + | a == b -> + pure a + (Var n cs, t) -> do + unifyVar s n cs t + (t, Var n cs) -> do + unifyVar s n cs t + (Array a, Array b) -> + map Array (unify a b) + (Record m1, Record m2) -> do + map Record (sequence (HashMap.intersectionWith unify m1 m2)) + (Fun args1 ret1, Fun args2 ret2) + | NonEmpty.length args1 == NonEmpty.length args2 -> + liftA2 Fun (sequence (NonEmpty.zipWith unify args1 args2)) (unify ret1 ret2) + (a, b) -> + throwError (TypeMismatch a b) + +unifyVar :: Monad m => Subst -> Int -> Set Constraint -> Type -> InferenceT m Type +unifyVar s n cs t0 = do + when (occurs n t0 s) (throwError (OccursCheck n t0)) + checkConstraints cs t0 + let + t = + applyConstraints cs t0 + extendSubst n t + pure t + +replace :: Subst -> Type -> Type +replace subst t = + case t of + Array arr -> + Array (replace subst arr) + Record r -> + Record (map (replace subst) r) + Fun args r -> + Fun (map (replace subst) args) (replace subst r) + Var n cs -> + case HashMap.lookup n subst of + Nothing -> + Var n cs + Just nt -> + applyConstraints cs (replace subst nt) + _ -> + t + +-- | 'replaceOnce' is a variant of 'replace' that doesn't +-- do deep substitution; this is necessary separate the namespaces +-- of quantified variables and unitification variables which is +-- useful for e.g. stdlib definitions +replaceOnce :: Subst -> Type -> Type +replaceOnce subst t = + case t of + Array arr -> + Array (replaceOnce subst arr) + Record r -> + Record (map (replaceOnce subst) r) + Fun args r -> + Fun (map (replaceOnce subst) args) (replaceOnce subst r) + Var n cs -> + case HashMap.lookup n subst of + Nothing -> + Var n cs + Just (Var m ds) -> + Var m (Set.union cs ds) + Just x -> + applyConstraints cs x + _ -> + t + +occurs :: Int -> Type -> Subst -> Bool +occurs n t subst = + case replace subst t of + Var m _cs -> + n == m + Array arr -> + occurs n arr subst + Record r -> + any (\t' -> occurs n t' subst) r + Fun args r -> + any (\a -> occurs n a subst) args || occurs n r subst + _ -> + False + +checkConstraints :: Monad m => Set Constraint -> Type -> InferenceT m () +checkConstraints cs t = do + for_ cs $ \c -> + unless (satisfies c t) (throwError (ConstraintViolation c t)) + +applyConstraints :: Set Constraint -> Type -> Type +applyConstraints cs = \case + Var n vcs -> + Var n (Set.union vcs cs) + t -> + t + +extendSubst :: Monad m => Int -> Type -> InferenceT m () +extendSubst n t = + modify (\s -> s { subst = HashMap.insert n t s.subst }) + +finalize :: Subst -> TypedExp -> TypedExp +finalize subst = + map (\(ann, t) -> (ann, defaultType (replace subst t))) + +defaultType :: Type -> Type +defaultType = \case + Var {} -> + Unit + Array t -> + Array (defaultType t) + Record r -> + Record (map defaultType r) + t -> + t diff --git a/src/T/Type/Vocab.hs b/src/T/Type/Vocab.hs new file mode 100644 index 0000000..9b13ad0 --- /dev/null +++ b/src/T/Type/Vocab.hs @@ -0,0 +1,182 @@ +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE OverloadedRecordDot #-} +module T.Type.Vocab + ( InferenceT(..) + , runInferenceT + , Γ + , Σ(..) + , Subst + , emptyΣ + , freshVar + , TypeError(..) + , TypedExp + , Type(..) + , freeVarsType + , Scheme(..) + , freeVarsScheme + , Constraint(..) + , satisfies + ) where + +import Control.Monad.Reader (ReaderT, MonadReader, runReaderT) +import Control.Monad.State (StateT, MonadState, runStateT, gets, modify) +import Control.Monad.Except (ExceptT, MonadError, runExceptT) +import Data.HashMap.Strict qualified as HashMap +import Data.List qualified as List +import Data.Set (Set) +import Data.Set qualified as Set +import Text.Printf (printf) + +import T.Exp qualified as Exp +import T.Name (Name) +import T.Prelude +import T.SExp (sexp) +import T.SExp qualified as SExp + + +newtype InferenceT m a = InferenceT (ReaderT Γ (StateT Σ (ExceptT TypeError m)) a) + deriving (Functor, Applicative, Monad, MonadReader Γ, MonadState Σ, MonadError TypeError) + +runInferenceT :: Γ -> Σ -> InferenceT m a -> m (Either TypeError (a, Σ)) +runInferenceT ctx subst (InferenceT m) = + runExceptT (runStateT (runReaderT m ctx) subst) + +type Γ = HashMap Name Scheme + +data Σ = Σ + { subst :: Subst + , counter :: Int + } deriving (Show, Eq) + +type Subst = HashMap Int Type + +emptyΣ :: Σ +emptyΣ = Σ + { subst = mempty + , counter = 0 + } + +freshVar :: Monad m => InferenceT m Type +freshVar = do + n <- gets (.counter) + modify (\s -> s { counter = s.counter + 1 }) + pure (Var n mempty) + +data TypeError + = MissingKey Name + | MissingVar Name + | NotARecord Type + | TypeMismatch Type Type + | ConstraintViolation Constraint Type + | OccursCheck Int Type + deriving (Show, Eq) + +type TypedExp = Cofree Exp.ExpF (Exp.Ann, Type) + +data Type + = Unit + | Bool + | Int + | Double + | String + | Regexp + | Array Type + | Record (HashMap Name Type) + | Fun (NonEmpty Type) Type + | Var Int (Set Constraint) + deriving (Show, Eq) + +instance SExp.To Type where + sexp = \case + Unit -> + "unit" + Bool -> + "bool" + Int -> + "int" + Double -> + "double" + String -> + "string" + Regexp -> + "regexp" + Array t -> + SExp.square [sexp t] + Record fs -> + SExp.curly (concatMap (\(k, v) -> [sexp k, sexp v]) (HashMap.toList fs)) + Fun args r -> + SExp.round ["->", SExp.square (toList (map sexp args)), sexp r] + Var n cs -> + fromString (printf "#%d{%s}" n (List.intercalate ", " (map show (toList cs)))) + +data Scheme = Forall (Set Int) Type + deriving (Show, Eq) + +freeVarsType :: Type -> Set Int +freeVarsType = \case + Var n _cs -> + Set.singleton n + Array t -> + freeVarsType t + Record r -> + foldMap freeVarsType r + Fun args r -> + foldMap freeVarsType args <> freeVarsType r + _ -> + Set.empty + +freeVarsScheme :: Scheme -> Set Int +freeVarsScheme (Forall qs t) = + Set.difference (freeVarsType t) qs + +data Constraint + = Num + | Eq + | Display + | Sizeable + | Iterable + deriving (Show, Eq, Ord) + +satisfies :: Constraint -> Type -> Bool +satisfies = \case + Num -> \case + Int -> True + Double -> True + Var _n cs -> Num `Set.member` cs + _ -> False + Eq -> \case + Unit -> True + Bool -> True + Int -> True + Double -> True + String -> True + Regexp -> True + Array t -> + satisfies Eq t + Record fs -> + all (satisfies Eq) fs + Var _n cs -> Eq `Set.member` cs + _ -> False + Display -> \case + Unit -> True + Bool -> True + Int -> True + Double -> True + String -> True + Array t -> + satisfies Display t + Record fs -> + all (satisfies Display) fs + Var _n cs -> Display `Set.member` cs + _ -> False + Sizeable -> \case + String -> True + Array _ -> True + Record _ -> True + Var _n cs -> Sizeable `Set.member` cs + _ -> False + Iterable -> \case + Array _ -> True + Record _ -> True + Var _n cs -> Iterable `Set.member` cs + _ -> False diff --git a/src/T/Value.hs b/src/T/Value.hs index b9cd4ea..1d2c7f8 100644 --- a/src/T/Value.hs +++ b/src/T/Value.hs @@ -101,9 +101,12 @@ typeOf = \case Double _ -> Type.Double String _ -> Type.String Regexp _ -> Type.Regexp - Array _ -> Type.Array (error "element") - Record _ -> Type.Record (error "fields") - Lam _ -> Type.Fun (error "args") (error "result") + Array _ -> + Type.Array (error "element") + Record fields -> + Type.Record (map typeOf fields) + Lam _ -> + Type.Fun (error "args") (error "result") embedAeson :: Aeson.Value -> Value embedAeson = \case From c7d7668e761dc9ecfc1b37c68908ff985a8b64cf Mon Sep 17 00:00:00 2001 From: Matvey Aksenov Date: Tue, 20 Jan 2026 10:57:51 +0000 Subject: [PATCH 11/13] separate Ann for TypedExp --- src/T/Exp.hs | 12 +++++----- src/T/Type/Infer.hs | 56 ++++++++++++++++++++++++++++++--------------- src/T/Type/Unify.hs | 5 ++-- src/T/Type/Vocab.hs | 9 +++++++- t.cabal | 3 +++ 5 files changed, 58 insertions(+), 27 deletions(-) diff --git a/src/T/Exp.hs b/src/T/Exp.hs index 2204a96..ab81ea2 100644 --- a/src/T/Exp.hs +++ b/src/T/Exp.hs @@ -41,20 +41,20 @@ import T.SExp (sexp) import T.SExp qualified as SExp -type Exp = Cofree ExpF Ann +type Exp = Cofree (ExpF Ann) Ann -data ExpF a +data ExpF ann a = Lit Literal -- ^ literals: 4, [1,2,3], {foo: 4} - | Var (Ann :+ Name) + | Var (ann :+ Name) -- ^ variable lookup: foo | If a a a -- ^ if-expression: if ... then ... else ... - | App (Ann :+ Name) (NonEmpty a) + | App (ann :+ Name) (NonEmpty a) -- ^ application: f(x) | Idx a a -- ^ array index access: xs[0] - | Key a (Ann :+ Name) + | Key a (ann :+ Name) -- ^ record property access: foo.bar deriving (Show, Eq, Generic1, Functor, Foldable, Traversable) @@ -73,7 +73,7 @@ instance SExp.To Exp where _ :< Key exp key -> SExp.round ["at-key", sexp key, sexp exp] -instance Eq1 ExpF where +instance Eq1 (ExpF ann) where liftEq _ (Lit l0) (Lit l1) = l0 == l1 liftEq _ (Var v0) (Var v1) = diff --git a/src/T/Type/Infer.hs b/src/T/Type/Infer.hs index 1a62535..d2b4efc 100644 --- a/src/T/Type/Infer.hs +++ b/src/T/Type/Infer.hs @@ -22,6 +22,7 @@ import T.Type.Vocab , freshVar , TypeError(..) , TypedExp + , Ann(..) , Type(..) , Scheme(..) ) @@ -48,23 +49,42 @@ inferTmpl = \case pure () inferExp :: Monad m => Exp -> InferenceT m TypedExp -inferExp (ann :< e) = do - te <- traverse inferExp e - inferredType <- case te of - Exp.Lit l -> - inferLiteral l - Exp.Var (_ann :+ name) -> - lookupCtx name - Exp.If b t f -> do - _ <- unify (extractType b) Bool - unify (extractType t) (extractType f) - Exp.App (_ann :+ name) args -> - checkApp name args - Exp.Idx arr idx -> - checkIdx arr idx - Exp.Key r (_ann :+ name) -> - checkKey r name - pure ((ann, inferredType) :< te) +inferExp (ann0 :< e) = + case e of + Exp.Lit l -> do + t <- inferLiteral l + pure (annOf t :< Exp.Lit l) + + Exp.Var (ann :+ name) -> do + t <- lookupCtx name + pure (annOf t :< Exp.Var (Ann {spanned = ann, typed = t} :+ name)) + + Exp.If b0 t0 f0 -> do + b1@(tb :< _) <- inferExp b0 + _ <- unify tb.typed Bool + t1@(tt :< _) <- inferExp t0 + f1@(tf :< _) <- inferExp f0 + tu <- unify tt.typed tf.typed + pure (annOf tu :< Exp.If b1 t1 f1) + + Exp.App (ann :+ name) args0 -> do + args1 <- traverse inferExp args0 + t <- checkApp name args1 + pure (annOf t :< Exp.App (Ann {spanned = ann, typed = t} :+ name) args1) + + Exp.Idx arr0 idx0 -> do + arr1 <- inferExp arr0 + idx1 <- inferExp idx0 + t <- checkIdx arr1 idx1 + pure (annOf t :< Exp.Idx arr1 idx1) + + Exp.Key r0 (ann :+ name) -> do + r1 <- inferExp r0 + t <- checkKey r1 name + pure (annOf t :< Exp.Key r1 (Ann {spanned = ann, typed = t} :+ name)) + where + annOf inferred = + Ann {spanned = ann0, typed = inferred} checkApp :: Monad m => Name -> NonEmpty TypedExp -> InferenceT m Type checkApp name args = do @@ -129,4 +149,4 @@ instantiate (Forall qs t) = do pure (replaceOnce (HashMap.fromList fvs) t) extractType :: TypedExp -> Type -extractType ((_ann, t) :< _e) = t +extractType (ann :< _e) = ann.typed diff --git a/src/T/Type/Unify.hs b/src/T/Type/Unify.hs index e2710dc..ec491fb 100644 --- a/src/T/Type/Unify.hs +++ b/src/T/Type/Unify.hs @@ -19,6 +19,7 @@ import T.Type.Vocab , Σ(..) , Subst , TypedExp + , Ann(..) , Type(..) , TypeError(..) , Constraint @@ -127,11 +128,11 @@ applyConstraints cs = \case extendSubst :: Monad m => Int -> Type -> InferenceT m () extendSubst n t = - modify (\s -> s { subst = HashMap.insert n t s.subst }) + modify (\s -> s {subst = HashMap.insert n t s.subst}) finalize :: Subst -> TypedExp -> TypedExp finalize subst = - map (\(ann, t) -> (ann, defaultType (replace subst t))) + map (\ann -> ann {typed = defaultType (replace subst ann.typed)}) defaultType :: Type -> Type defaultType = \case diff --git a/src/T/Type/Vocab.hs b/src/T/Type/Vocab.hs index 9b13ad0..7494ff8 100644 --- a/src/T/Type/Vocab.hs +++ b/src/T/Type/Vocab.hs @@ -10,6 +10,7 @@ module T.Type.Vocab , freshVar , TypeError(..) , TypedExp + , Ann(..) , Type(..) , freeVarsType , Scheme(..) @@ -26,6 +27,7 @@ import Data.List qualified as List import Data.Set (Set) import Data.Set qualified as Set import Text.Printf (printf) +import Text.Trifecta (Span) import T.Exp qualified as Exp import T.Name (Name) @@ -71,7 +73,12 @@ data TypeError | OccursCheck Int Type deriving (Show, Eq) -type TypedExp = Cofree Exp.ExpF (Exp.Ann, Type) +type TypedExp = Cofree (Exp.ExpF Ann) Ann + +data Ann = Ann + { spanned :: Span + , typed :: Type + } deriving (Show, Eq) data Type = Unit diff --git a/t.cabal b/t.cabal index 3dda618..f21477c 100644 --- a/t.cabal +++ b/t.cabal @@ -63,6 +63,9 @@ library T.Stdlib.Op T.Tmpl T.Type + T.Type.Infer + T.Type.Unify + T.Type.Vocab T.Value Meta_t other-modules: From 69c0f566e3f86d129a809cf5a6cf72a6e2a04aa2 Mon Sep 17 00:00:00 2001 From: Matvey Aksenov Date: Tue, 20 Jan 2026 19:15:36 +0000 Subject: [PATCH 12/13] sync typechecking and rendertime errors --- src/T/Embed.hs | 14 +++++++------- src/T/Error.hs | 14 +++++++------- src/T/Render.hs | 24 ++++++++++++------------ src/T/Stdlib/Op.hs | 12 ++++++------ src/T/Type/Infer.hs | 8 ++++---- src/T/Type/Vocab.hs | 5 ++--- test/T/RenderSpec.hs | 20 ++++++++++++-------- 7 files changed, 50 insertions(+), 47 deletions(-) diff --git a/src/T/Embed.hs b/src/T/Embed.hs index 3ceeab8..bb2bfdd 100644 --- a/src/T/Embed.hs +++ b/src/T/Embed.hs @@ -73,46 +73,46 @@ instance Eject Bool where Bool b -> pure b value -> - Left (TypeError (varE name) Type.Bool (typeOf value) (sexp value)) + Left (TagMismatch (varE name) Type.Bool (typeOf value) (sexp value)) instance Eject Int where eject name = \case Int n -> pure n value -> - Left (TypeError (varE name) Type.Int (typeOf value) (sexp value)) + Left (TagMismatch (varE name) Type.Int (typeOf value) (sexp value)) instance Eject Double where eject name = \case Double n -> pure n value -> - Left (TypeError (varE name) Type.Double (typeOf value) (sexp value)) + Left (TagMismatch (varE name) Type.Double (typeOf value) (sexp value)) instance Eject Text where eject name = \case String str -> pure str value -> - Left (TypeError (varE name) Type.String (typeOf value) (sexp value)) + Left (TagMismatch (varE name) Type.String (typeOf value) (sexp value)) instance Eject Pcre.Regex where eject name = \case Regexp regexp -> pure regexp value -> - Left (TypeError (varE name) Type.Regexp (typeOf value) (sexp value)) + Left (TagMismatch (varE name) Type.Regexp (typeOf value) (sexp value)) instance (k ~ Name, v ~ Value) => Eject (HashMap k v) where eject name = \case Record o -> pure o value -> - Left (TypeError (varE name) (Type.Record (error "fields")) (typeOf value) (sexp value)) + Left (TagMismatch (varE name) (Type.Record (error "fields")) (typeOf value) (sexp value)) instance Eject a => Eject [a] where eject name = \case Array xs -> map toList (traverse (eject name) xs) value -> - Left (TypeError (varE name) (Type.Array (error "element")) (typeOf value) (sexp value)) + Left (TagMismatch (varE name) (Type.Array (error "element")) (typeOf value) (sexp value)) diff --git a/src/T/Error.hs b/src/T/Error.hs index 1e90499..49cfef4 100644 --- a/src/T/Error.hs +++ b/src/T/Error.hs @@ -24,9 +24,9 @@ import T.SExp qualified as SExp data Error = NotInScope (Ann :+ Name) | OutOfBounds Exp SExp SExp - | MissingProperty Exp SExp SExp + | MissingField Exp SExp SExp | UserError (Ann :+ Name) Text - | TypeError Exp Type Type SExp + | TagMismatch Exp Type Type SExp | NotLValue Exp deriving (Show, Eq) @@ -41,7 +41,7 @@ prettyError = \case "index: " <> PP.pretty idx <> PP.line <> "is out of bounds for array: " <> PP.pretty array <> PP.line <> excerpt ann - MissingProperty (ann :< _) r key -> + MissingField (ann :< _) r key -> header ann <> "key: " <> PP.pretty key <> PP.line <> "is missing from the record: " <> PP.pretty r <> PP.line <> @@ -50,15 +50,15 @@ prettyError = \case header ann <> PP.pretty name <> ": " <> PP.pretty text <> PP.line <> excerpt ann - TypeError (ann :< Var (_ann :+ name)) expected actual value -> + TagMismatch (ann :< Var (_ann :+ name)) expected actual value -> header ann <> - "mismatched types in " <> PP.pretty name <> ": " <> PP.line <> + "mismatched [rendertime] types in " <> PP.pretty name <> ": " <> PP.line <> PP.indent 2 "expected: " <> PP.pretty (show expected) <> PP.line <> PP.indent 2 " but got: " <> PP.pretty value <> " : " <> PP.pretty (show actual) <> PP.line <> excerpt ann - TypeError (ann :< _) expected actual value -> + TagMismatch (ann :< _) expected actual value -> header ann <> - "mismatched types:" <> PP.line <> + "mismatched [rendertime] types:" <> PP.line <> PP.indent 2 "expected: " <> PP.pretty (show expected) <> PP.line <> PP.indent 2 " but got: " <> PP.pretty value <> " : " <> PP.pretty (show actual) <> PP.line <> excerpt ann diff --git a/src/T/Render.hs b/src/T/Render.hs index e9c9625..1247154 100644 --- a/src/T/Render.hs +++ b/src/T/Render.hs @@ -131,7 +131,7 @@ renderTmpl = \case (HashMap.toList o)) pure (bool (Just xs) Nothing (List.null xs)) _ -> - throwError (TypeError exp (error "Type.Iterable") (Value.typeOf value) (sexp value)) + throwError (TagMismatch exp (error "Type.Iterable") (Value.typeOf value) (sexp value)) case itemsQ of Nothing -> maybe (pure ()) renderTmpl elseTmpl @@ -169,7 +169,7 @@ renderExp exp = do Value.String str -> pure str _ -> - throwError (TypeError exp (error "Type.Renderable") (Value.typeOf value) (sexp value)) + throwError (TagMismatch exp (error "Type.Renderable") (Value.typeOf value) (sexp value)) evalExp :: (Ctx m, MonadError Error m) => Exp -> m Value evalExp = \case @@ -223,7 +223,7 @@ evalExp = \case r <- enforceRecord exp case HashMap.lookup key r of Nothing -> - throwError (MissingProperty exp (sexp r) (sexp key)) + throwError (MissingField exp (sexp r) (sexp key)) Just x -> pure x @@ -234,7 +234,7 @@ enforceInt exp = do Value.Int xs -> pure xs _ -> - throwError (TypeError exp Type.Int (Value.typeOf v) (sexp v)) + throwError (TagMismatch exp Type.Int (Value.typeOf v) (sexp v)) enforceArray :: (Ctx m, MonadError Error m) => Exp -> m (Vector Value) enforceArray exp = do @@ -243,7 +243,7 @@ enforceArray exp = do Value.Array xs -> pure xs _ -> - throwError (TypeError exp (Type.Array (Type.tyVar 0)) (Value.typeOf v) (sexp v)) + throwError (TagMismatch exp (Type.Array (Type.tyVar 0)) (Value.typeOf v) (sexp v)) enforceRecord :: (Ctx m, MonadError Error m) => Exp -> m (HashMap Name Value) enforceRecord exp = do @@ -252,7 +252,7 @@ enforceRecord exp = do Value.Record r -> pure r _ -> - throwError (TypeError exp (Type.Record mempty) (Value.typeOf v) (sexp v)) + throwError (TagMismatch exp (Type.Record mempty) (Value.typeOf v) (sexp v)) evalApp :: (Ctx m, MonadError Error m) @@ -273,7 +273,7 @@ evalApp name@(ann0 :+ _) = go g exps -- in every other case something went wrong :-( go v _ = - throwError (TypeError (ann0 :< Var name) (error "Type.Fun") (Value.typeOf v) (sexp v)) + throwError (TagMismatch (ann0 :< Var name) (error "Type.Fun") (Value.typeOf v) (sexp v)) data Path = Path { var :: Ann :+ Name @@ -320,7 +320,7 @@ evalLValue = Just v -> (Right v, path) Right v -> - throwError (TypeError exp (Type.Array (Type.tyVar 0)) (Value.typeOf v) (sexp v)) + throwError (TagMismatch exp (Type.Array (Type.tyVar 0)) (Value.typeOf v) (sexp v)) Left name -> throwError (NotInScope name) _ :< Key exp0 key@(_ :+ key0) -> do @@ -336,7 +336,7 @@ evalLValue = Just v -> (Right v, path) Right v -> - throwError (TypeError exp (Type.Record mempty) (Value.typeOf v) (sexp v)) + throwError (TagMismatch exp (Type.Record mempty) (Value.typeOf v) (sexp v)) Left name -> throwError (NotInScope name) @@ -393,9 +393,9 @@ insertVar Path {var = (ann :+ name), lookups} v = do [] -> pure (Value.Record (HashMap.insert (fromString (Name.toString key)) v r)) _ -> - throwError (MissingProperty (ann0 :< Lit Null) (sexp r) (sexp key)) + throwError (MissingField (ann0 :< Lit Null) (sexp r) (sexp key)) go v0 (K (ann0 :+ _key) : _path) = - throwError (TypeError (ann0 :< Lit Null) (Type.Record mempty) (Value.typeOf v0) (sexp v0)) + throwError (TagMismatch (ann0 :< Lit Null) (Type.Record mempty) (Value.typeOf v0) (sexp v0)) go (Value.Array xs) (I (ann0 :+ idx) : path) = -- this is pretty similar to records except the lack of the aforementioned special -- treatment. @@ -406,7 +406,7 @@ insertVar Path {var = (ann :+ name), lookups} v = do Nothing -> throwError (OutOfBounds (ann0 :< Lit Null) (sexp xs) (sexp idx)) go v0 (I (ann0 :+ _idx) : _path) = - throwError (TypeError (ann0 :< Lit Null) (Type.Record mempty) (Value.typeOf v0) (sexp v0)) + throwError (TagMismatch (ann0 :< Lit Null) (Type.Record mempty) (Value.typeOf v0) (sexp v0)) go _v0 [] = pure v diff --git a/src/T/Stdlib/Op.hs b/src/T/Stdlib/Op.hs index a8639c8..ca2cdc9 100644 --- a/src/T/Stdlib/Op.hs +++ b/src/T/Stdlib/Op.hs @@ -115,15 +115,15 @@ combineNumbers intOp doubleOp name = _ann :+ Int n1 -> pure (Int (n0 `intOp` n1)) ann :+ n -> - Left (TypeError (varE (ann :+ name)) Type.Int (typeOf n) (sexp n)) + Left (TagMismatch (varE (ann :+ name)) Type.Int (typeOf n) (sexp n)) _ann :+ Double n0 -> pure . Lam $ \case _ann :+ Double n1 -> pure (Double (n0 `doubleOp` n1)) ann :+ n -> - Left (TypeError (varE (ann :+ name)) Type.Double (typeOf n) (sexp n)) + Left (TagMismatch (varE (ann :+ name)) Type.Double (typeOf n) (sexp n)) ann :+ n -> - Left (TypeError (varE (ann :+ name)) (error "Type.Number") (typeOf n) (sexp n)) + Left (TagMismatch (varE (ann :+ name)) (error "Type.Number") (typeOf n) (sexp n)) add :: Name -> Value add = @@ -149,15 +149,15 @@ predicateNumbers intOp doubleOp name = _ann :+ Int n1 -> pure (Bool (n0 `intOp` n1)) ann :+ n -> - Left (TypeError (varE (ann :+ name)) Type.Int (typeOf n) (sexp n)) + Left (TagMismatch (varE (ann :+ name)) Type.Int (typeOf n) (sexp n)) _ann :+ Double n0 -> pure . Lam $ \case _ann :+ Double n1 -> pure (Bool (n0 `doubleOp` n1)) ann :+ n -> - Left (TypeError (varE (ann :+ name)) Type.Double (typeOf n) (sexp n)) + Left (TagMismatch (varE (ann :+ name)) Type.Double (typeOf n) (sexp n)) ann :+ n -> - Left (TypeError (varE (ann :+ name)) (error "Type.Number") (typeOf n) (sexp n)) + Left (TagMismatch (varE (ann :+ name)) (error "Type.Number") (typeOf n) (sexp n)) lt :: Name -> Value lt = diff --git a/src/T/Type/Infer.hs b/src/T/Type/Infer.hs index d2b4efc..de4cab0 100644 --- a/src/T/Type/Infer.hs +++ b/src/T/Type/Infer.hs @@ -108,13 +108,13 @@ checkKey r name = do Just t -> pure t Nothing -> - throwError (MissingKey name) + throwError (MissingField name) var@(Var _n _cs) -> do v <- freshVar _ <- unify var (Record (HashMap.singleton name v)) pure v - _ -> - throwError (NotARecord (extractType r)) + t -> + throwError (TypeMismatch (Record mempty) t) inferLiteral :: Monad m => Exp.Literal -> InferenceT m Type inferLiteral = \case @@ -139,7 +139,7 @@ inferLiteral = \case lookupCtx :: Monad m => Name -> InferenceT m Type lookupCtx name = do ctx <- ask - maybe (throwError (MissingVar name)) instantiate (HashMap.lookup name ctx) + maybe (throwError (NotInScope name)) instantiate (HashMap.lookup name ctx) instantiate :: Monad m => Scheme -> InferenceT m Type instantiate (Forall qs t) = do diff --git a/src/T/Type/Vocab.hs b/src/T/Type/Vocab.hs index 7494ff8..82c2770 100644 --- a/src/T/Type/Vocab.hs +++ b/src/T/Type/Vocab.hs @@ -65,9 +65,8 @@ freshVar = do pure (Var n mempty) data TypeError - = MissingKey Name - | MissingVar Name - | NotARecord Type + = NotInScope Name + | MissingField Name | TypeMismatch Type Type | ConstraintViolation Constraint Type | OccursCheck Int Type diff --git a/test/T/RenderSpec.hs b/test/T/RenderSpec.hs index b1ae1e7..d93c665 100644 --- a/test/T/RenderSpec.hs +++ b/test/T/RenderSpec.hs @@ -65,9 +65,9 @@ spec = r_ "{{ [1,2,3][0] }}" `shouldRender` "1" r_ "{{ [[1,2],[3]][0][1] }}" `shouldRender` "2" r_ "{{ 4[0] }}" `shouldRaise` - error "TypeError" (litE_ (Int 4)) Type.Array Type.Int "4" + TagMismatch (litE_ (Int 4)) (Type.Array (Type.tyVar 0)) Type.Int "4" r_ "{{ [1,2,3][\"foo\"] }}" `shouldRaise` - TypeError (litE_ (String "foo")) Type.Int Type.String "\"foo\"" + TagMismatch (litE_ (String "foo")) Type.Int Type.String "\"foo\"" r_ "{{ [1,2,3][-1] }}" `shouldRaise` OutOfBounds (int (-1)) (sexp (array [int 1, int 2, int 3])) "-1" context "keying" $ @@ -76,8 +76,8 @@ spec = r_ "{{ {foo: [1,2,3]}.foo[0] }}" `shouldRender` "1" r_ "{{ {foo: [1,{bar: 7},3]}.foo[1].bar }}" `shouldRender` "7" r_ "{{ 4.foo }}" `shouldRaise` - error "TypeError" (litE_ (Int 4)) "Type.Record" Type.Int "4" - r_ "{{ {}.foo }}" `shouldRaise` MissingProperty (record mempty) (sexp (record mempty)) "foo" + TagMismatch (litE_ (Int 4)) (Type.Record mempty) Type.Int "4" + r_ "{{ {}.foo }}" `shouldRaise` MissingField (record mempty) (sexp (record mempty)) "foo" context "line blocks" $ it "examples" $ do @@ -267,16 +267,20 @@ spec = r_ "{{ \"Foo\" =~ /foo/i }}" `shouldRender` "true" it "not-iterable" $ - r_ "{% for x in 4 %}{% endfor %}" `shouldRaise` error "TypeError" (litE_ (Int 4)) "Type.Iterable" Type.Int "4" + r_ "{% for x in 4 %}{% endfor %}" `shouldRaise` + TagMismatch (litE_ (Int 4)) (Type.Var 0 [Type.Iterable]) Type.Int "4" it "not-renderable" $ - r_ "{{ [] }}" `shouldRaise` error "TypeError" (array []) error "Type.Renderable" "Type.Array" (sexp (array [])) + r_ "{{ [] }}" `shouldRaise` + TagMismatch (array []) (Type.Var 0 [Type.Display]) (Type.Array (Type.tyVar 0)) (sexp (array [])) it "not-a-function" $ - rWith [aesonQQ|{f: "foo"}|] "{{ f(4) }}" `shouldRaise` error "TypeError" (varE "f") "Type.Fun" Type.String "\"foo\"" + rWith [aesonQQ|{f: "foo"}|] "{{ f(4) }}" `shouldRaise` + TagMismatch (varE "f") (Type.tyVar 0 `Type.fun1` Type.tyVar 1) Type.String "\"foo\"" it "type errors" $ - r_ "{{ bool01(\"foo\") }}" `shouldRaise` TypeError (varE "bool01") Type.Bool Type.String "\"foo\"" + r_ "{{ bool01(\"foo\") }}" `shouldRaise` + TagMismatch (varE "bool01") Type.Bool Type.String "\"foo\"" it "defined?" $ rWith [aesonQQ|{foo: {}}|] "{{ defined?(foo.bar.baz) }}" `shouldRender` "false" From 344340054b2704d3dce06ce9c60a14c8c0000752 Mon Sep 17 00:00:00 2001 From: Matvey Aksenov Date: Tue, 20 Jan 2026 19:47:01 +0000 Subject: [PATCH 13/13] back to working tests --- src/T/Embed.hs | 4 +-- src/T/Error.hs | 11 +++++---- src/T/Parse.hs | 2 +- src/T/Render.hs | 6 ++--- src/T/Stdlib/Op.hs | 5 ++-- src/T/Type/Vocab.hs | 27 ++++++++++++++++++--- src/T/Value.hs | 8 ++++-- test/T/Parse/AnnSpec.hs | 54 ++++++++++++++++++++--------------------- test/T/RenderSpec.hs | 2 +- 9 files changed, 73 insertions(+), 46 deletions(-) diff --git a/src/T/Embed.hs b/src/T/Embed.hs index bb2bfdd..133817e 100644 --- a/src/T/Embed.hs +++ b/src/T/Embed.hs @@ -108,11 +108,11 @@ instance (k ~ Name, v ~ Value) => Eject (HashMap k v) where Record o -> pure o value -> - Left (TagMismatch (varE name) (Type.Record (error "fields")) (typeOf value) (sexp value)) + Left (TagMismatch (varE name) (Type.Record mempty) (typeOf value) (sexp value)) instance Eject a => Eject [a] where eject name = \case Array xs -> map toList (traverse (eject name) xs) value -> - Left (TagMismatch (varE name) (Type.Array (error "element")) (typeOf value) (sexp value)) + Left (TagMismatch (varE name) (Type.Array (Type.tyVar 0)) (typeOf value) (sexp value)) diff --git a/src/T/Error.hs b/src/T/Error.hs index 49cfef4..f34142c 100644 --- a/src/T/Error.hs +++ b/src/T/Error.hs @@ -5,6 +5,7 @@ module T.Error , prettyWarning ) where +import Data.Text.Internal.Builder qualified as Builder import Prettyprinter (Doc) import Prettyprinter qualified as PP import Prettyprinter.Render.Terminal (AnsiStyle) @@ -53,18 +54,18 @@ prettyError = \case TagMismatch (ann :< Var (_ann :+ name)) expected actual value -> header ann <> "mismatched [rendertime] types in " <> PP.pretty name <> ": " <> PP.line <> - PP.indent 2 "expected: " <> PP.pretty (show expected) <> PP.line <> - PP.indent 2 " but got: " <> PP.pretty value <> " : " <> PP.pretty (show actual) <> PP.line <> + PP.indent 2 "expected: " <> PP.pretty (Builder.toLazyText (SExp.render expected)) <> PP.line <> + PP.indent 2 " but got: " <> PP.pretty value <> " : " <> PP.pretty (Builder.toLazyText (SExp.render actual)) <> PP.line <> excerpt ann TagMismatch (ann :< _) expected actual value -> header ann <> "mismatched [rendertime] types:" <> PP.line <> - PP.indent 2 "expected: " <> PP.pretty (show expected) <> PP.line <> - PP.indent 2 " but got: " <> PP.pretty value <> " : " <> PP.pretty (show actual) <> PP.line <> + PP.indent 2 "expected: " <> PP.pretty (Builder.toLazyText (SExp.render expected)) <> PP.line <> + PP.indent 2 " but got: " <> PP.pretty value <> " : " <> PP.pretty (Builder.toLazyText (SExp.render actual)) <> PP.line <> excerpt ann NotLValue exp@(ann :< _) -> header ann <> - "expected an L-Value, but got something else: " <> fromString (show (SExp.render (SExp.sexp exp))) <> + "expected an L-Value, but got something else: " <> fromString (show (SExp.render exp)) <> excerpt ann where header (Tri.Span from _to _line) = diff --git a/src/T/Parse.hs b/src/T/Parse.hs index 59aaf7d..f61dd92 100644 --- a/src/T/Parse.hs +++ b/src/T/Parse.hs @@ -42,7 +42,7 @@ import T.Exp.Ann (anning, anned) import T.Name (Name(..)) import T.Name qualified as Name import T.Parse.Macro qualified as Macro -import T.Prelude +import T.Prelude hiding (for) import T.Tmpl (Tmpl) import T.Tmpl qualified as Tmpl import T.Stdlib (Stdlib(..)) diff --git a/src/T/Render.hs b/src/T/Render.hs index 1247154..5ab7716 100644 --- a/src/T/Render.hs +++ b/src/T/Render.hs @@ -131,7 +131,7 @@ renderTmpl = \case (HashMap.toList o)) pure (bool (Just xs) Nothing (List.null xs)) _ -> - throwError (TagMismatch exp (error "Type.Iterable") (Value.typeOf value) (sexp value)) + throwError (TagMismatch exp (Type.Var 0 (Set.fromList [Type.Iterable])) (Value.typeOf value) (sexp value)) case itemsQ of Nothing -> maybe (pure ()) renderTmpl elseTmpl @@ -169,7 +169,7 @@ renderExp exp = do Value.String str -> pure str _ -> - throwError (TagMismatch exp (error "Type.Renderable") (Value.typeOf value) (sexp value)) + throwError (TagMismatch exp (Type.Var 0 (Set.fromList [Type.Render])) (Value.typeOf value) (sexp value)) evalExp :: (Ctx m, MonadError Error m) => Exp -> m Value evalExp = \case @@ -273,7 +273,7 @@ evalApp name@(ann0 :+ _) = go g exps -- in every other case something went wrong :-( go v _ = - throwError (TagMismatch (ann0 :< Var name) (error "Type.Fun") (Value.typeOf v) (sexp v)) + throwError (TagMismatch (ann0 :< Var name) (Type.tyVar 0 `Type.fun1` Type.tyVar 1) (Value.typeOf v) (sexp v)) data Path = Path { var :: Ann :+ Name diff --git a/src/T/Stdlib/Op.hs b/src/T/Stdlib/Op.hs index ca2cdc9..e2c9a72 100644 --- a/src/T/Stdlib/Op.hs +++ b/src/T/Stdlib/Op.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedRecordDot #-} {-# LANGUAGE StrictData #-} {-# LANGUAGE TypeApplications #-} @@ -123,7 +124,7 @@ combineNumbers intOp doubleOp name = ann :+ n -> Left (TagMismatch (varE (ann :+ name)) Type.Double (typeOf n) (sexp n)) ann :+ n -> - Left (TagMismatch (varE (ann :+ name)) (error "Type.Number") (typeOf n) (sexp n)) + Left (TagMismatch (varE (ann :+ name)) (Type.Var 0 [Type.Num]) (typeOf n) (sexp n)) add :: Name -> Value add = @@ -157,7 +158,7 @@ predicateNumbers intOp doubleOp name = ann :+ n -> Left (TagMismatch (varE (ann :+ name)) Type.Double (typeOf n) (sexp n)) ann :+ n -> - Left (TagMismatch (varE (ann :+ name)) (error "Type.Number") (typeOf n) (sexp n)) + Left (TagMismatch (varE (ann :+ name)) (Type.Var 0 [Type.Num]) (typeOf n) (sexp n)) lt :: Name -> Value lt = diff --git a/src/T/Type/Vocab.hs b/src/T/Type/Vocab.hs index 82c2770..7771a54 100644 --- a/src/T/Type/Vocab.hs +++ b/src/T/Type/Vocab.hs @@ -112,8 +112,20 @@ instance SExp.To Type where SExp.curly (concatMap (\(k, v) -> [sexp k, sexp v]) (HashMap.toList fs)) Fun args r -> SExp.round ["->", SExp.square (toList (map sexp args)), sexp r] - Var n cs -> - fromString (printf "#%d{%s}" n (List.intercalate ", " (map show (toList cs)))) + Var n cs + | Set.null cs -> + fromString (printf "#%d" n) + | otherwise -> + fromString (printf "#%d{%s}" n (List.intercalate ", " (map showConstraint (toList cs)))) + where + showConstraint = \case + Num -> "num" :: String + Eq -> "eq" + Render -> "render" + Display -> "display" + Sizeable -> "sizeable" + Iterable -> "iterable" + data Scheme = Forall (Set Int) Type deriving (Show, Eq) @@ -138,7 +150,8 @@ freeVarsScheme (Forall qs t) = data Constraint = Num | Eq - | Display + | Render -- can be rendered directly as the value of a {{ }} chunk + | Display -- i.e. Show | Sizeable | Iterable deriving (Show, Eq, Ord) @@ -163,6 +176,14 @@ satisfies = \case all (satisfies Eq) fs Var _n cs -> Eq `Set.member` cs _ -> False + Render -> \case + Unit -> True + Bool -> True + Int -> True + Double -> True + String -> True + Var _n cs -> Render `Set.member` cs + _ -> False Display -> \case Unit -> True Bool -> True diff --git a/src/T/Value.hs b/src/T/Value.hs index 1d2c7f8..10d9a93 100644 --- a/src/T/Value.hs +++ b/src/T/Value.hs @@ -101,8 +101,12 @@ typeOf = \case Double _ -> Type.Double String _ -> Type.String Regexp _ -> Type.Regexp - Array _ -> - Type.Array (error "element") + Array xs -> + case toList xs of + [] -> + Type.Array (Type.tyVar 0) + x : _xs -> + Type.Array (typeOf x) Record fields -> Type.Record (map typeOf fields) Lam _ -> diff --git a/test/T/Parse/AnnSpec.hs b/test/T/Parse/AnnSpec.hs index 587ea76..d2b1cae 100644 --- a/test/T/Parse/AnnSpec.hs +++ b/test/T/Parse/AnnSpec.hs @@ -12,74 +12,74 @@ spec = context "literals" $ do it "null" $ errorOf "{{ null + 1 }}" `shouldBe` - "(interactive):1:4: error: mismatched types in +: \n\ - \ expected: Number\n\ - \ but got: null : Null\n\ + "(interactive):1:4: error: mismatched [rendertime] types in +: \n\ + \ expected: #0{num}\n\ + \ but got: null : unit\n\ \1 | {{ null + 1 }} \n\ \ | ~~~~ " it "bool" $ errorOf "{{ true + 1 }}" `shouldBe` - "(interactive):1:4: error: mismatched types in +: \n\ - \ expected: Number\n\ - \ but got: true : Bool\n\ + "(interactive):1:4: error: mismatched [rendertime] types in +: \n\ + \ expected: #0{num}\n\ + \ but got: true : bool\n\ \1 | {{ true + 1 }} \n\ \ | ~~~~ " it "int" $ errorOf "{{ 1 <> \"foo\" }}" `shouldBe` - "(interactive):1:4: error: mismatched types in <>: \n\ - \ expected: String\n\ - \ but got: 1 : Int\n\ + "(interactive):1:4: error: mismatched [rendertime] types in <>: \n\ + \ expected: string\n\ + \ but got: 1 : int\n\ \1 | {{ 1 <> \"foo\" }} \n\ \ | ~ " it "double" $ errorOf "{{ 1.0 <> \"foo\" }}" `shouldBe` - "(interactive):1:4: error: mismatched types in <>: \n\ - \ expected: String\n\ - \ but got: 1.0 : Double\n\ + "(interactive):1:4: error: mismatched [rendertime] types in <>: \n\ + \ expected: string\n\ + \ but got: 1.0 : double\n\ \1 | {{ 1.0 <> \"foo\" }} \n\ \ | ~~~ " it "string" $ errorOf "{{ \"foo\" + 1 }}" `shouldBe` - "(interactive):1:4: error: mismatched types in +: \n\ - \ expected: Number\n\ - \ but got: \"foo\" : String\n\ + "(interactive):1:4: error: mismatched [rendertime] types in +: \n\ + \ expected: #0{num}\n\ + \ but got: \"foo\" : string\n\ \1 | {{ \"foo\" + 1 }} \n\ \ | ~~~~~ " it "regexp" $ do errorOf "{{ /foo/ + 1 }}" `shouldBe` - "(interactive):1:4: error: mismatched types in +: \n\ - \ expected: Number\n\ - \ but got: (regexp \"foo\") : Regexp\n\ + "(interactive):1:4: error: mismatched [rendertime] types in +: \n\ + \ expected: #0{num}\n\ + \ but got: (regexp \"foo\") : regexp\n\ \1 | {{ /foo/ + 1 }} \n\ \ | ~~~~~ " it "array" $ do errorOf "{{ [1,2,3] + 1 }}" `shouldBe` - "(interactive):1:4: error: mismatched types in +: \n\ - \ expected: Number\n\ - \ but got: [1 2 3] : Array\n\ + "(interactive):1:4: error: mismatched [rendertime] types in +: \n\ + \ expected: #0{num}\n\ + \ but got: [1 2 3] : [int]\n\ \1 | {{ [1,2,3] + 1 }} \n\ \ | ~~~~~~~ " it "record" $ do errorOf "{{ {foo:4} + 1 }}" `shouldBe` - "(interactive):1:4: error: mismatched types in +: \n\ - \ expected: Number\n\ - \ but got: {foo 4} : Record\n\ + "(interactive):1:4: error: mismatched [rendertime] types in +: \n\ + \ expected: #0{num}\n\ + \ but got: {foo 4} : {foo int}\n\ \1 | {{ {foo:4} + 1 }} \n\ \ | ~~~~~~~ " context "property access" $ do it "record" $ do errorOf "{% set foo = {} foo.bar = [] %}{{ foo.bar }}" `shouldBe` - "(interactive):1:38: error: mismatched types:\n\ - \ expected: Renderable\n\ - \ but got: [] : Array\n\ + "(interactive):1:38: error: mismatched [rendertime] types:\n\ + \ expected: #0{render}\n\ + \ but got: [] : [#0]\n\ \1 | {% set foo = {} foo.bar = [] %}{{ foo.bar }} \n\ \ | ~~~~ " diff --git a/test/T/RenderSpec.hs b/test/T/RenderSpec.hs index d93c665..f21fd31 100644 --- a/test/T/RenderSpec.hs +++ b/test/T/RenderSpec.hs @@ -272,7 +272,7 @@ spec = it "not-renderable" $ r_ "{{ [] }}" `shouldRaise` - TagMismatch (array []) (Type.Var 0 [Type.Display]) (Type.Array (Type.tyVar 0)) (sexp (array [])) + TagMismatch (array []) (Type.Var 0 [Type.Render]) (Type.Array (Type.tyVar 0)) (sexp (array [])) it "not-a-function" $ rWith [aesonQQ|{f: "foo"}|] "{{ f(4) }}" `shouldRaise`