diff --git a/dune-project b/dune-project index 8b82118..8074932 100644 --- a/dune-project +++ b/dune-project @@ -26,6 +26,7 @@ dune base menhir + re2 ppx_jane ppx_deriving (core_unix :with-test) diff --git a/lib/IntermediateLanguages/Ast.ml b/lib/IntermediateLanguages/Ast.ml index 7f1f2bb..c56d748 100644 --- a/lib/IntermediateLanguages/Ast.ml +++ b/lib/IntermediateLanguages/Ast.ml @@ -51,6 +51,7 @@ module Type = struct and 's forall = ('s, ('s, Kind.t) Source.annotate) abstraction and 's pi = ('s, ('s, Sort.t) Source.annotate) abstraction and 's sigma = ('s, ('s, Sort.t) Source.annotate) abstraction + and 's tuple = ('s, 's t list) Source.annotate and 's raw = | Ref of ref @@ -59,6 +60,7 @@ module Type = struct | Forall of 's forall | Pi of 's pi | Sigma of 's sigma + | Tuple of 's tuple and 's t = ('s, 's raw) Source.annotate [@@deriving sexp_of] end @@ -151,6 +153,12 @@ module Expr = struct and 's reifyShape = 's Index.t and 's reifyDimension = 's Index.t + and 's tupleExpr = ('s, 's t list) Source.annotate + + and 's tupleDeref = + { tuple : 's t + ; position : int + } and 's raw = | Ref of ref @@ -171,6 +179,8 @@ module Expr = struct | Reshape of 's reshape | ReifyShape of 's reifyShape | ReifyDimension of 's reifyDimension + | TupleExpr of 's tupleExpr + | TupleDeref of 's tupleDeref | IntLiteral of int | FloatLiteral of float | CharacterLiteral of char diff --git a/lib/IntermediateLanguages/Explicit.ml b/lib/IntermediateLanguages/Explicit.ml index 370cae5..ee3aa23 100644 --- a/lib/IntermediateLanguages/Explicit.ml +++ b/lib/IntermediateLanguages/Explicit.ml @@ -104,6 +104,17 @@ module Expr = struct and primitive = Typed.Expr.primitive and literal = Typed.Expr.literal + and tupleExpr = + { elements : atom list + ; type' : Type.tuple + } + + and tupleDeref = + { expr : array + ; position : int + ; type' : Type.array + } + and contiguousSubArray = { arrayArg : array ; indexArg : array @@ -134,6 +145,7 @@ module Expr = struct | Primitive of primitive | ContiguousSubArray of contiguousSubArray | Map of map + | TupleDeref of tupleDeref and atom = | TermLambda of termLambda @@ -141,6 +153,7 @@ module Expr = struct | IndexLambda of indexLambda | Box of box | Literal of literal + | TupleExpr of tupleExpr and t = | Array of array @@ -152,6 +165,7 @@ module Expr = struct | TypeLambda typeLambda -> Forall typeLambda.type' | IndexLambda indexLambda -> Pi indexLambda.type' | Box box -> Sigma box.type' + | TupleExpr tuple -> Tuple tuple.type' | Literal (IntLiteral _) -> Literal IntLiteral | Literal (FloatLiteral _) -> Literal FloatLiteral | Literal (CharacterLiteral _) -> Literal CharacterLiteral @@ -170,6 +184,7 @@ module Expr = struct | ReifyIndex reifyIndex -> Arr reifyIndex.type' | Primitive primitive -> primitive.type' | ContiguousSubArray contiguousSubArray -> contiguousSubArray.type' + | TupleDeref tupleDeref -> tupleDeref.type' | Map map -> map.type' ;; end @@ -256,6 +271,12 @@ module Substitute = struct ; l ; type' = Type.subTypesIntoArray types type' } + | TupleDeref { expr; position; type' } -> + TupleDeref + { expr = subTypesIntoArray types expr + ; position + ; type' = Type.subTypesIntoArray types type' + } and subTypesIntoRef types { id; type' } = { id; type' = Type.subTypesIntoArray types type' } @@ -281,6 +302,11 @@ module Substitute = struct ; bodyType = Type.subTypesIntoArray types bodyType ; type' = Type.subTypesIntoSigma types type' } + | TupleExpr { elements; type' } -> + TupleExpr + { elements = List.map ~f:(subTypesIntoAtom types) elements + ; type' = List.map ~f:(Type.subTypesIntoAtom types) type' + } | Literal _ as literal -> literal and subTypesIntoTermLambda types { params; body; type' } = @@ -372,6 +398,12 @@ module Substitute = struct ; frameShape = Index.subIndicesIntoShape indices frameShape ; type' = Type.subIndicesIntoArray indices type' } + | TupleDeref { expr; position; type' } -> + TupleDeref + { expr = subIndicesIntoArray indices expr + ; position + ; type' = Type.subIndicesIntoArray indices type' + } and subIndicesIntoRef indices { id; type' } = { id; type' = Type.subIndicesIntoArray indices type' } @@ -397,6 +429,11 @@ module Substitute = struct ; bodyType = Type.subIndicesIntoArray indices bodyType ; type' = Type.subIndicesIntoSigma indices type' } + | TupleExpr { elements; type' } -> + TupleExpr + { elements = List.map ~f:(subIndicesIntoAtom indices) elements + ; type' = List.map ~f:(Type.subIndicesIntoAtom indices) type' + } | Literal _ as literal -> literal and subIndicesIntoTermLambda indices { params; body; type' } = @@ -465,6 +502,8 @@ module Substitute = struct ; frameShape ; type' } + | TupleDeref { expr; position; type' } -> + TupleDeref { expr = subRefsIntoArray refs expr; position; type' } and subRefsIntoAtom refs = let open Expr in @@ -477,6 +516,8 @@ module Substitute = struct IndexLambda { params; body = subRefsIntoArray refs body; type' } | Box { indices; body; bodyType; type' } -> Box { indices; body = subRefsIntoArray refs body; bodyType; type' } + | TupleExpr { elements; type' } -> + TupleExpr { elements = List.map ~f:(subRefsIntoAtom refs) elements; type' } | Literal _ as literal -> literal and subRefsIntoRef refs { id; type' } = diff --git a/lib/IntermediateLanguages/Typed.ml b/lib/IntermediateLanguages/Typed.ml index eb0fc76..29226d7 100644 --- a/lib/IntermediateLanguages/Typed.ml +++ b/lib/IntermediateLanguages/Typed.ml @@ -82,6 +82,8 @@ module Type = struct | ArrayRef of Identifier.t | Arr of arr + and tuple = atom list + and atom = | AtomRef of Identifier.t | Func of func @@ -89,6 +91,7 @@ module Type = struct | Pi of pi | Sigma of sigma | Literal of literal + | Tuple of tuple and t = | Array of array @@ -267,6 +270,12 @@ module Expr = struct ; type' : Type.array [@sexp_drop_if fun _ -> true] } + and tupleDeref = + { expr : array + ; position : int + ; type' : Type.array + } + and literal = | IntLiteral of int | FloatLiteral of float @@ -286,6 +295,12 @@ module Expr = struct | Primitive of primitive | Lift of lift | ContiguousSubArray of contiguousSubArray + | TupleDeref of tupleDeref + + and tuple = + { elements : atom list + ; type' : Type.tuple + } and atom = | TermLambda of termLambda @@ -293,6 +308,7 @@ module Expr = struct | IndexLambda of indexLambda | Box of box | Literal of literal + | TupleExpr of tuple and t = | Array of array @@ -304,6 +320,7 @@ module Expr = struct | TypeLambda typeLambda -> Forall typeLambda.type' | IndexLambda indexLambda -> Pi indexLambda.type' | Box box -> Sigma box.type' + | TupleExpr tuple -> Tuple tuple.type' | Literal (IntLiteral _) -> Literal IntLiteral | Literal (FloatLiteral _) -> Literal FloatLiteral | Literal (CharacterLiteral _) -> Literal CharacterLiteral @@ -323,6 +340,7 @@ module Expr = struct | Primitive primitive -> primitive.type' | Lift lift -> lift.type' | ContiguousSubArray contiguousSubArray -> contiguousSubArray.type' + | TupleDeref tupleDeref -> tupleDeref.type' ;; let type' : t -> Type.t = function @@ -443,6 +461,8 @@ end = struct | ArrayRef of Ref.t | Arr of arr + and tuple = atom list + and atom = | AtomRef of Ref.t | Func of func @@ -450,6 +470,7 @@ end = struct | Pi of pi | Sigma of sigma | Literal of literal + | Tuple of tuple and t = | Array of array @@ -505,6 +526,7 @@ end = struct let depth = depth + 1 in arrayFrom env depth body) } + | Type.Tuple tupleElements -> Tuple (List.map ~f:(atomFrom env depth) tupleElements) | Type.Literal IntLiteral -> Literal IntLiteral | Type.Literal FloatLiteral -> Literal FloatLiteral | Type.Literal CharacterLiteral -> Literal CharacterLiteral @@ -592,6 +614,8 @@ module Substitute = struct | Forall forall -> Forall (subIndicesIntoForall indices forall) | Pi pi -> Pi (subIndicesIntoPi indices pi) | Sigma sigma -> Sigma (subIndicesIntoSigma indices sigma) + | Tuple tupleElements -> + Tuple (List.map ~f:(subIndicesIntoAtom indices) tupleElements) | Literal IntLiteral -> Literal IntLiteral | Literal FloatLiteral -> Literal FloatLiteral | Literal CharacterLiteral -> Literal CharacterLiteral @@ -648,6 +672,7 @@ module Substitute = struct | Forall forall -> Forall (subTypesIntoForall types forall) | Pi pi -> Pi (subTypesIntoPi types pi) | Sigma sigma -> Sigma (subTypesIntoSigma types sigma) + | Tuple tupleElements -> Tuple (List.map ~f:(subTypesIntoAtom types) tupleElements) | Literal IntLiteral -> Literal IntLiteral | Literal FloatLiteral -> Literal FloatLiteral | Literal CharacterLiteral -> Literal CharacterLiteral @@ -749,6 +774,12 @@ module Substitute = struct ; l ; type' = Type.subTypesIntoArray types type' } + | TupleDeref { expr; position; type' } -> + TupleDeref + { expr = subTypesIntoArray types expr + ; position + ; type' = Type.subTypesIntoArray types type' + } and subTypesIntoAtom types = function | TermLambda lambda -> TermLambda (subTypesIntoTermLambda types lambda) @@ -771,6 +802,8 @@ module Substitute = struct ; bodyType = Type.subTypesIntoArray types bodyType ; type' = Type.subTypesIntoSigma types type' } + | TupleExpr { elements; type' } -> + TupleExpr { elements = List.map ~f:(subTypesIntoAtom types) elements; type' } | Literal _ as literal -> literal and subTypesIntoTermLambda types { params; body; type' } = @@ -857,6 +890,12 @@ module Substitute = struct ; l = Index.subIndicesIntoDim indices l ; type' = Type.subIndicesIntoArray indices type' } + | TupleDeref { expr; position; type' } -> + TupleDeref + { expr = subIndicesIntoArray indices expr + ; position + ; type' = Type.subIndicesIntoArray indices type' + } and subIndicesIntoAtom indices = function | TermLambda lambda -> TermLambda (subIndicesIntoTermLambda indices lambda) @@ -879,6 +918,11 @@ module Substitute = struct ; bodyType = Type.subIndicesIntoArray indices bodyType ; type' = Type.subIndicesIntoSigma indices type' } + | TupleExpr { elements; type' } -> + TupleExpr + { elements = List.map ~f:(subIndicesIntoAtom indices) elements + ; type' = List.map ~f:(Type.subIndicesIntoAtom indices) type' + } | Literal _ as literal -> literal and subIndicesIntoTermLambda indices { params; body; type' } = diff --git a/lib/Stages/Explicitize.ml b/lib/Stages/Explicitize.ml index 0e48d52..c7bcb6d 100644 --- a/lib/Stages/Explicitize.ml +++ b/lib/Stages/Explicitize.ml @@ -34,6 +34,7 @@ let rec funcParamNamesArray env : Typed.Expr.array -> string list option = funct | Let l -> let env = Map.set env ~key:l.binding ~data:(funcParamNamesArray env l.value) in funcParamNamesArray env l.body + | TupleDeref { expr; position = _; type' = _ } -> funcParamNamesArray env expr | Primitive p -> (match p.name with | Func func -> @@ -78,12 +79,22 @@ let rec funcParamNamesArray env : Typed.Expr.array -> string list option = funct List.init (List.length argTypes) ~f:(fun i -> [%string "%{name}-arg%{i#Int}"])) | Val _ -> None) +and funcParamNamesTuple env : Typed.Expr.tuple -> string list option = function + | { elements; type' = _ } -> + Some + (elements + |> List.map ~f:(funcParamNamesAtom env) + |> List.bind ~f:(function + | None -> [] + | Some list -> list)) + and funcParamNamesAtom env : Typed.Expr.atom -> string list option = function | TermLambda lambda -> Some (List.map lambda.params ~f:(fun p -> Identifier.name p.binding)) | TypeLambda lambda -> funcParamNamesArray env lambda.body | IndexLambda lambda -> funcParamNamesArray env lambda.body | Box box -> funcParamNamesArray env box.body + | TupleExpr tuple -> funcParamNamesTuple env tuple | Literal _ -> None ;; @@ -243,6 +254,37 @@ let rec explicitizeArray paramNamesEnv array let%map body = explicitizeArray extendedParamNamesEnv body in E.Map { body; args = [ { binding; value } ]; frameShape = []; type' } | T.Primitive { name; type' } -> return (E.Primitive { name; type' }) + | T.TupleDeref { expr; position; type' } -> + let%map expr = explicitizeArray paramNamesEnv expr + and mapArg = ExplicitState.createId "tupleDerefMapArg" in + let exprType = E.arrayType expr in + let tupleType = + match exprType with + | Arr { element; shape = _ } -> Explicit.Type.Arr { element; shape = [] } + | ArrayRef _ -> raise Unreachable.default + in + let atomType, shape = + match type' with + | Arr { element; shape } -> Explicit.Type.Arr { element; shape = [] }, shape + | ArrayRef _ -> raise Unreachable.default + in + E.Map + { body = + (* E.Scalar *) + (* { element = *) + E.TupleDeref + { expr = E.Ref { id = mapArg; type' = tupleType } + ; position + ; type' = atomType + } + (* ; type' = { element = atomType; shape = [] } *) + (* } *) + ; args = [ { binding = mapArg; value = expr } ] + ; frameShape = shape + ; type' + } +(* let%bind expr = explicitizeArray paramNamesEnv expr in *) +(* return (E.TupleDeref { expr; position; type' }) *) and explicitizeAtom paramNamesEnv atom : (CompilerState.state, Explicit.Expr.atom, _) ExplicitState.t @@ -263,6 +305,11 @@ and explicitizeAtom paramNamesEnv atom | T.Box { indices; body; bodyType; type' } -> let%map body = explicitizeArray paramNamesEnv body in E.Box { indices; body; bodyType; type' } + | T.TupleExpr { elements; type' } -> + let%map elements = + elements |> List.map ~f:(explicitizeAtom paramNamesEnv) |> ExplicitState.all + in + E.TupleExpr { elements; type' } | T.Literal literal -> return (E.Literal literal) and explicitizeTermApplication diff --git a/lib/Stages/Fuse.ml b/lib/Stages/Fuse.ml index f641ba6..4317a50 100644 --- a/lib/Stages/Fuse.ml +++ b/lib/Stages/Fuse.ml @@ -1124,6 +1124,7 @@ let rec fuseLoops (scope : Set.M(Identifier).t) (match%bind tryFusing opportunity with | Some result -> let%bind () = FuseState.markFusion in + (* return result *) fuseLoops scope result | None -> tryFusingList rest) | [] -> return (Let { args; body; type' }) diff --git a/lib/Stages/FuseAndSimplify.ml b/lib/Stages/FuseAndSimplify.ml index 9cae62c..320f651 100644 --- a/lib/Stages/FuseAndSimplify.ml +++ b/lib/Stages/FuseAndSimplify.ml @@ -7,6 +7,8 @@ let rec fuseAndSimplify (prog : Nested.t) : (CompilerState.state, Nested.t, _) S if fusionResult.fusedAny then fuseAndSimplify fusionResult.result else return simplified ;; +(* return simplified *) + module Stage (SB : Source.BuilderT) = struct type state = CompilerState.state type input = Nested.t diff --git a/lib/Stages/Inline.ml b/lib/Stages/Inline.ml index 062bd72..04fd509 100644 --- a/lib/Stages/Inline.ml +++ b/lib/Stages/Inline.ml @@ -137,6 +137,11 @@ let rec inlineAtomTypeWithStack appStack : Typed.Type.atom -> Nucleus.Type.array | Literal FloatLiteral -> { element = Literal FloatLiteral; shape = [] } | Literal BooleanLiteral -> { element = Literal BooleanLiteral; shape = [] } | Func _ -> { element = Tuple []; shape = [] } + | Tuple t -> + { element = + Tuple (List.map ~f:(fun e -> (inlineAtomTypeWithStack appStack e).element) t) + ; shape = [] + } | Forall { parameters; body } -> (match appStack with | TypeApp types :: restStack -> @@ -213,6 +218,7 @@ let assertValueRestriction value = | Forall _ -> true | Pi _ -> true | Sigma sigma -> isPolymorphicArray sigma.body + | Tuple elements -> List.for_all ~f:isPolymorphicAtom elements | Literal IntLiteral -> false | Literal FloatLiteral -> false | Literal CharacterLiteral -> false @@ -235,11 +241,13 @@ let assertValueRestriction value = | Primitive _ -> true | Map _ -> false | ContiguousSubArray _ -> false + | TupleDeref tupleDeref -> isValueArray tupleDeref.expr and isValueAtom = function | TermLambda _ -> true | TypeLambda _ -> true | IndexLambda _ -> true | Box box -> isValueArray box.body + | TupleExpr { elements; type' = _ } -> List.for_all ~f:isValueAtom elements | Literal (IntLiteral _ | FloatLiteral _ | CharacterLiteral _ | BooleanLiteral _) -> true in @@ -324,6 +332,9 @@ let rec genNewBindingsArray env (expr : Explicit.Expr.array) = return @@ Expr.ContiguousSubArray { arrayArg; indexArg; originalShape; resultShape; cellShape; l; type' } + | TupleDeref { expr; position; type' } -> + let%bind expr = genNewBindingsArray env expr in + return @@ Expr.TupleDeref { expr; position; type' } and genNewBindingsAtom env (expr : Explicit.Expr.atom) = let open Explicit in @@ -341,6 +352,11 @@ and genNewBindingsAtom env (expr : Explicit.Expr.atom) = | Box { indices; bodyType; body; type' } -> let%bind body = genNewBindingsArray env body in return @@ Expr.Box { indices; bodyType; body; type' } + | TupleExpr { elements; type' } -> + let%bind elements = + elements |> List.map ~f:(genNewBindingsAtom env) |> InlineState.all + in + return @@ Expr.TupleExpr { elements; type' } | Literal (IntLiteral _ | FloatLiteral _ | CharacterLiteral _ | BooleanLiteral _) as lit -> return lit @@ -518,6 +534,30 @@ let rec inlineArray indexEnv (appStack : appStack) (array : Explicit.Expr.array) ; type' = inlineArrayTypeWithStack [] type' }) , functions ) + | TupleDeref { expr; position; type' } -> + let arrayType = E.arrayType expr in + let atomArrayType = + match arrayType with + | Arr { element; shape = _ } -> element + | ArrayRef _ -> raise Unreachable.default + in + let inlineAtomArrayType = inlineAtomTypeWithStack [] atomArrayType in + let inlineType' = inlineArrayTypeWithStack [] type' in + assert (List.is_empty inlineAtomArrayType.shape); + assert (List.is_empty inlineType'.shape); + let%map expr, functions = inlineArray indexEnv appStack expr in + let newDeref = + I.TupleDeref + { tuple = I.ArrayAsAtom { array = expr; type' = inlineAtomArrayType.element } + ; index = position + ; type' = inlineType'.element + } + in + let newArrayDeref = + I.AtomAsArray + { element = newDeref; type' = { element = inlineType'.element; shape = [] } } + in + newArrayDeref, functions and inlineAtom indexEnv (appStack : appStack) (atom : Explicit.Expr.atom) : (Nucleus.Expr.array * FunctionSet.t) InlineState.u @@ -588,6 +628,7 @@ and inlineAtom indexEnv (appStack : appStack) (atom : Explicit.Expr.atom) Set.diff variablesUsed variablesDeclared | E.ReifyIndex { index; type' = _ } -> indexCaptures index | E.Primitive { name = _; type' = _ } -> Set.empty (module Identifier) + | E.TupleDeref { expr; position = _; type' = _ } -> arrayCaptures expr | E.ContiguousSubArray { arrayArg; indexArg; originalShape; resultShape; cellShape; l; type' = _ } -> let arrayArgCaptures = arrayCaptures arrayArg @@ -627,6 +668,8 @@ and inlineAtom indexEnv (appStack : appStack) (atom : Explicit.Expr.atom) let indexCaptures = List.map indices ~f:indexCaptures and bodyCaptures = arrayCaptures body in Set.union_list (module Identifier) (bodyCaptures :: indexCaptures) + | E.TupleExpr { elements; type' = _ } -> + Set.union_list (module Identifier) (List.map ~f:atomCaptures elements) | E.Literal (IntLiteral _ | FloatLiteral _ | CharacterLiteral _ | BooleanLiteral _) -> Set.empty (module Identifier) in @@ -688,6 +731,21 @@ and inlineAtom indexEnv (appStack : appStack) (atom : Explicit.Expr.atom) ; type' = inlineSigmaTypeWithStack appStack type' }) , functions ) + | TupleExpr { elements; type' = _ } -> + let%bind elements, functions = + List.map ~f:(inlineAtom indexEnv appStack) elements + |> InlineState.all + |> InlineState.unzip + in + let functions = FunctionSet.merge functions in + let elements, type' = + elements + |> List.map ~f:(fun elt -> + let type' = (I.arrayType elt).element in + Nucleus.Expr.ArrayAsAtom { array = elt; type' }, type') + |> List.unzip + in + return (scalar (I.Values { elements; type' }), functions) | Literal (CharacterLiteral c) -> return (scalar (I.Literal (CharacterLiteral c)), FunctionSet.Empty) | Literal (IntLiteral i) -> return (scalar (I.Literal (IntLiteral i)), FunctionSet.Empty) diff --git a/lib/Stages/Parse.ml b/lib/Stages/Parse.ml index 8f78fb0..60e2b19 100644 --- a/lib/Stages/Parse.ml +++ b/lib/Stages/Parse.ml @@ -248,6 +248,18 @@ module Make (SB : Source.BuilderT) = struct } ; source = esexpSource arrExp } + | ParenList + { elements = Symbol ("Values", _) :: tupleElements + ; braceSources = _, rParenSource + } as tupleExp -> + let%map parsedTuple = + parseInfixList + ~f:parseType + ~before:(esexpSource tupleExp) + ~after:rParenSource + tupleElements + in + Source.{ elem = Type.Tuple parsedTuple; source = esexpSource tupleExp } | type' -> MResult.err ("Bad type syntax", esexpSource type') and parseExpr : 's Esexp.t -> ('s Expr.t, error) MResult.t = @@ -361,6 +373,30 @@ module Make (SB : Source.BuilderT) = struct ; source } | [] -> MResult.err ("[] sugar can only be used for non-empty frames", source)) + | ParenList + { elements = Symbol ("values", _) :: tupleElements + ; braceSources = lParenSource, rParenSource + } as tupleExpr -> + let%bind tupleElements = + parseInfixList tupleElements ~f:parseExpr ~before:lParenSource ~after:rParenSource + in + MOk Source.{ elem = Expr.TupleExpr tupleElements; source = esexpSource tupleExpr } + | ParenList { elements = [ Symbol (accessor, _); tupleElt ]; braceSources = _, _ } as + derefExp + when Re2.matches (Re2.create_exn "#([0-9]+)") accessor -> + let index = + Int.of_string + (Re2.find_first_exn + ?sub:(Some (`Index 1)) + (Re2.create_exn "#([0-9]+)") + accessor) + in + let%bind parsedExpr = parseExpr tupleElt in + MOk + Source. + { elem = Expr.TupleDeref { tuple = parsedExpr; position = index } + ; source = esexpSource derefExp + } | ParenList { elements = Symbol ("array", arrSource) :: components ; braceSources = _, rParenSource diff --git a/lib/Stages/Simplify.ml b/lib/Stages/Simplify.ml index 7f4a63f..fd73592 100644 --- a/lib/Stages/Simplify.ml +++ b/lib/Stages/Simplify.ml @@ -594,8 +594,8 @@ let rec optimize : Expr.t -> Expr.t = |> List.mapi ~f:(fun i e -> i, e) |> List.for_all ~f:(fun (expectedIndex, e) -> match e with - | TupleDeref { tuple = Ref eRef; index = eI; type' = _ } - when Identifier.equal ref.id eRef.id && eI = expectedIndex -> true + | TupleDeref { tuple = Ref eRef; index = eI; type' = _ } -> + Identifier.equal ref.id eRef.id && eI = expectedIndex | _ -> false) in if sizesMatch && allSequentialDerefs then Ref ref else values diff --git a/lib/Stages/TupleReduce.ml b/lib/Stages/TupleReduce.ml index a89352a..b3a576d 100644 --- a/lib/Stages/TupleReduce.ml +++ b/lib/Stages/TupleReduce.ml @@ -928,7 +928,8 @@ let rec reduceTuplesInExpr (request : TupleRequest.t) expr = in let mapResults, _ = List.unzip resultIdsAndRequests in let resultRequestsFromMap = - List.map resultIdsAndRequests ~f:(fun (resultId, request) -> resultId, request) + resultIdsAndRequests + (* List.map resultIdsAndRequests ~f:(fun (resultId, request) -> resultId, request) *) in let resultRequestsFromConsumer = consumerUsages @@ -986,8 +987,6 @@ let rec reduceTuplesInExpr (request : TupleRequest.t) expr = | _ -> raise @@ Unreachable.Error "expected array type" in let unpackersRaw = createUnpackersFromCache cache [] ~insideWhole:false in - let argRef = Expr.Ref { id = binding; type' = argArrayType.element } in - let unpackers = List.map unpackersRaw ~f:(fun unpacker -> unpacker argRef) in let argRequest = TupleRequest.Collection { subRequest = createRequestFromCache cache @@ -996,7 +995,14 @@ let rec reduceTuplesInExpr (request : TupleRequest.t) expr = in let%map value = reduceTuplesInExpr argRequest (Expr.Ref ref) and valueBinding = ReduceTupleState.createId (Identifier.name ref.id) in + let valueType = + match Expr.type' value with + | Array array -> array + | _ -> raise @@ Unreachable.Error "expected array type" + in + let argRef = Expr.Ref { id = binding; type' = valueType.element } in let valueRef : Expr.ref = { id = valueBinding; type' = Expr.type' value } in + let unpackers = List.map unpackersRaw ~f:(fun unpacker -> unpacker argRef) in let valueUnpacker : Expr.letArg = { binding = valueBinding; value } in Expr.{ binding; ref = valueRef }, unpackers, valueUnpacker)) |> List.filter_opt diff --git a/lib/Stages/TypeCheck.ml b/lib/Stages/TypeCheck.ml index 2f4880c..72891c1 100644 --- a/lib/Stages/TypeCheck.ml +++ b/lib/Stages/TypeCheck.ml @@ -58,6 +58,11 @@ type error = ; ref : Identifier.t } | LiftIndexValueNotInteger of Typed.Type.atom + | TupleDerefNotTuple of Typed.Type.t + | TupleDerefTupleNotEnoughElements of + { tuple : Typed.Type.tuple + ; position : int + } | WrongFunctionBodyType of { expected : Typed.Type.t ; actual : Typed.Type.t @@ -136,6 +141,9 @@ end = struct | Sigma { parameters; body } -> let paramsString = parameters |> showList showParam in [%string "(Σ (%{paramsString}) %{showArray body})"] + | Tuple elements -> + let elementsString = elements |> showList showAtom in + [%string "(Values %{elementsString})"] | Literal IntLiteral -> "int" | Literal FloatLiteral -> "float" | Literal CharacterLiteral -> "char" @@ -227,6 +235,18 @@ let errorMessage = function ref}`, but lifting to a shape requires ending in a dimension"] | LiftIndexValueNotInteger t -> [%string "Lifted index must have integer elements, got `%{Show.type' (Atom t)}`"] + | TupleDerefNotTuple t -> + [%string + "Tried to dereference a type that is not a tuple or tuple array: `%{Show.type' t}`"] + | TupleDerefTupleNotEnoughElements { tuple; position } -> + let tupleStrings = + String.concat + ?sep:(Some ", ") + (List.map ~f:(fun e -> [%string "%{Show.type' (Atom e)}"]) tuple) + in + [%string + "Not enough elements in tuple to dereference: tuple `%{tupleStrings}`, position \ + %{position#Int}"] | WrongFunctionBodyType { expected; actual } -> [%string "Function was declared to have return type `%{Show.type' expected}`,got \ @@ -266,6 +286,8 @@ let errorType = function | LiftShapeGotRef _ | WrongFunctionBodyType _ | LiftShapeGotScalar + | TupleDerefNotTuple _ + | TupleDerefTupleNotEnoughElements _ | LiftIndexValueNotInteger _ -> `Type ;; @@ -528,6 +550,11 @@ module KindChecker = struct let extendedEnv = { env with sorts = extendeSorts } in let%map body = checkAndExpectArray extendedEnv body in T.Atom (T.Sigma { parameters; body }) + | U.Tuple elements -> + let%map atoms = + elements.elem |> List.map ~f:(checkAndExpectAtom env) |> CompilerState.all + in + T.Atom (T.Tuple atoms) and checkAndExpectArray env type' = let open Typed.Type in @@ -626,6 +653,12 @@ module TypeCheck = struct | Reshape _ -> ok () | ReifyDimension _ -> ok () | ReifyShape _ -> ok () + | TupleExpr elements -> + elements.elem + |> List.map ~f:requireValue + |> CheckerState.all + |> CheckerState.ignore_m + | TupleDeref { tuple; position = _ } -> requireValue tuple | IntLiteral _ -> ok () | FloatLiteral _ -> ok () | CharacterLiteral _ -> ok () @@ -692,6 +725,8 @@ module TypeCheck = struct } in findInType extendedEnv (Array body) + | Atom (Tuple elements) -> + List.bind ~f:(fun elt -> findInType env (Atom elt)) elements | Atom (Literal IntLiteral) -> [] | Atom (Literal FloatLiteral) -> [] | Atom (Literal CharacterLiteral) -> [] @@ -1482,11 +1517,57 @@ module TypeCheck = struct ; body = bodyTyped ; type' = T.arrayType bodyTyped }) + | U.TupleExpr t -> + let%bind elements = + t.elem |> List.map ~f:(checkAndExpectAtom env) |> CheckerState.all + in + let type' = List.map ~f:T.atomType elements in + CheckerState.return (T.Atom (TupleExpr { elements; type' })) + | U.TupleDeref { tuple; position } -> + (* type check tuple *) + (* if atom -> grab the nth type *) + (* if array -> find the 'frame', check that atom is tuple and grab its nth type *) + let%bind typedTuple = check env tuple in + let typedTuple = + match typedTuple with + | T.Atom a -> + T.Scalar + { element = a; type' = { element = Typed.Expr.atomType a; shape = [] } } + | T.Array a -> a + in + let%bind type' = extractTupleArray env typedTuple position tuple.source in + CompilerState.return (T.Array (T.TupleDeref { expr = typedTuple; position; type' })) | U.IntLiteral i -> CheckerState.return (T.Atom (Literal (IntLiteral i))) | U.FloatLiteral f -> CheckerState.return (T.Atom (Literal (FloatLiteral f))) | U.CharacterLiteral c -> CheckerState.return (T.Atom (Literal (CharacterLiteral c))) | U.BooleanLiteral b -> CheckerState.return (T.Atom (Literal (BooleanLiteral b))) + and extractTupleArrayType + (env : Environment.t) + (arrType : Typed.Type.array) + position + source + = + match arrType with + | ArrayRef id -> + let type' = Map.find_exn env.types (Identifier.name id) in + extractTupleArray env type' position source + | Arr { element = Typed.Type.Tuple elements; shape } -> + let%bind extractedAtom = extractTupleAtomType elements position source in + CheckerState.return (Typed.Type.Arr { element = extractedAtom; shape }) + | Arr _ -> CheckerState.err { source; elem = TupleDerefNotTuple (Array arrType) } + + and extractTupleArray env array position source = + let type' = Typed.Expr.arrayType array in + extractTupleArrayType env type' position source + + and extractTupleAtomType atomType position source = + if List.length atomType <= position + then + CheckerState.err + { source; elem = TupleDerefTupleNotEnoughElements { tuple = atomType; position } } + else CheckerState.return (List.nth_exn atomType position) + and checkAndExpectArray env expr = let open Typed.Expr in match%bind check env expr with diff --git a/lib/dune b/lib/dune index 86a32ca..5af135e 100644 --- a/lib/dune +++ b/lib/dune @@ -9,7 +9,7 @@ (name remora) (public_name remora) (modules :standard EsexpParser EsexpLexer) - (libraries base) + (libraries base re2) (preprocess (pps ppx_jane ppx_deriving.eq ppx_deriving.show ppx_deriving.ord))) diff --git a/remora.opam b/remora.opam index 153ed92..269b845 100644 --- a/remora.opam +++ b/remora.opam @@ -14,6 +14,7 @@ depends: [ "dune" {>= "3.7"} "base" "menhir" + "re2" "ppx_jane" "ppx_deriving" "core_unix" {with-test}