Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dune-project
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
dune
base
menhir
re2
ppx_jane
ppx_deriving
(core_unix :with-test)
Expand Down
10 changes: 10 additions & 0 deletions lib/IntermediateLanguages/Ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ module Type = struct
and 's forall = ('s, ('s, Kind.t) Source.annotate) abstraction
and 's pi = ('s, ('s, Sort.t) Source.annotate) abstraction
and 's sigma = ('s, ('s, Sort.t) Source.annotate) abstraction
and 's tuple = ('s, 's t list) Source.annotate

and 's raw =
| Ref of ref
Expand All @@ -59,6 +60,7 @@ module Type = struct
| Forall of 's forall
| Pi of 's pi
| Sigma of 's sigma
| Tuple of 's tuple

and 's t = ('s, 's raw) Source.annotate [@@deriving sexp_of]
end
Expand Down Expand Up @@ -151,6 +153,12 @@ module Expr = struct

and 's reifyShape = 's Index.t
and 's reifyDimension = 's Index.t
and 's tupleExpr = ('s, 's t list) Source.annotate

and 's tupleDeref =
{ tuple : 's t
; position : int
}

and 's raw =
| Ref of ref
Expand All @@ -171,6 +179,8 @@ module Expr = struct
| Reshape of 's reshape
| ReifyShape of 's reifyShape
| ReifyDimension of 's reifyDimension
| TupleExpr of 's tupleExpr
| TupleDeref of 's tupleDeref
| IntLiteral of int
| FloatLiteral of float
| CharacterLiteral of char
Expand Down
41 changes: 41 additions & 0 deletions lib/IntermediateLanguages/Explicit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,17 @@ module Expr = struct
and primitive = Typed.Expr.primitive
and literal = Typed.Expr.literal

and tupleExpr =
{ elements : atom list
; type' : Type.tuple
}

and tupleDeref =
{ expr : array
; position : int
; type' : Type.array
}

and contiguousSubArray =
{ arrayArg : array
; indexArg : array
Expand Down Expand Up @@ -134,13 +145,15 @@ module Expr = struct
| Primitive of primitive
| ContiguousSubArray of contiguousSubArray
| Map of map
| TupleDeref of tupleDeref

and atom =
| TermLambda of termLambda
| TypeLambda of typeLambda
| IndexLambda of indexLambda
| Box of box
| Literal of literal
| TupleExpr of tupleExpr

and t =
| Array of array
Expand All @@ -152,6 +165,7 @@ module Expr = struct
| TypeLambda typeLambda -> Forall typeLambda.type'
| IndexLambda indexLambda -> Pi indexLambda.type'
| Box box -> Sigma box.type'
| TupleExpr tuple -> Tuple tuple.type'
| Literal (IntLiteral _) -> Literal IntLiteral
| Literal (FloatLiteral _) -> Literal FloatLiteral
| Literal (CharacterLiteral _) -> Literal CharacterLiteral
Expand All @@ -170,6 +184,7 @@ module Expr = struct
| ReifyIndex reifyIndex -> Arr reifyIndex.type'
| Primitive primitive -> primitive.type'
| ContiguousSubArray contiguousSubArray -> contiguousSubArray.type'
| TupleDeref tupleDeref -> tupleDeref.type'
| Map map -> map.type'
;;
end
Expand Down Expand Up @@ -256,6 +271,12 @@ module Substitute = struct
; l
; type' = Type.subTypesIntoArray types type'
}
| TupleDeref { expr; position; type' } ->
TupleDeref
{ expr = subTypesIntoArray types expr
; position
; type' = Type.subTypesIntoArray types type'
}

and subTypesIntoRef types { id; type' } =
{ id; type' = Type.subTypesIntoArray types type' }
Expand All @@ -281,6 +302,11 @@ module Substitute = struct
; bodyType = Type.subTypesIntoArray types bodyType
; type' = Type.subTypesIntoSigma types type'
}
| TupleExpr { elements; type' } ->
TupleExpr
{ elements = List.map ~f:(subTypesIntoAtom types) elements
; type' = List.map ~f:(Type.subTypesIntoAtom types) type'
}
| Literal _ as literal -> literal

and subTypesIntoTermLambda types { params; body; type' } =
Expand Down Expand Up @@ -372,6 +398,12 @@ module Substitute = struct
; frameShape = Index.subIndicesIntoShape indices frameShape
; type' = Type.subIndicesIntoArray indices type'
}
| TupleDeref { expr; position; type' } ->
TupleDeref
{ expr = subIndicesIntoArray indices expr
; position
; type' = Type.subIndicesIntoArray indices type'
}

and subIndicesIntoRef indices { id; type' } =
{ id; type' = Type.subIndicesIntoArray indices type' }
Expand All @@ -397,6 +429,11 @@ module Substitute = struct
; bodyType = Type.subIndicesIntoArray indices bodyType
; type' = Type.subIndicesIntoSigma indices type'
}
| TupleExpr { elements; type' } ->
TupleExpr
{ elements = List.map ~f:(subIndicesIntoAtom indices) elements
; type' = List.map ~f:(Type.subIndicesIntoAtom indices) type'
}
| Literal _ as literal -> literal

and subIndicesIntoTermLambda indices { params; body; type' } =
Expand Down Expand Up @@ -465,6 +502,8 @@ module Substitute = struct
; frameShape
; type'
}
| TupleDeref { expr; position; type' } ->
TupleDeref { expr = subRefsIntoArray refs expr; position; type' }

and subRefsIntoAtom refs =
let open Expr in
Expand All @@ -477,6 +516,8 @@ module Substitute = struct
IndexLambda { params; body = subRefsIntoArray refs body; type' }
| Box { indices; body; bodyType; type' } ->
Box { indices; body = subRefsIntoArray refs body; bodyType; type' }
| TupleExpr { elements; type' } ->
TupleExpr { elements = List.map ~f:(subRefsIntoAtom refs) elements; type' }
| Literal _ as literal -> literal

and subRefsIntoRef refs { id; type' } =
Expand Down
44 changes: 44 additions & 0 deletions lib/IntermediateLanguages/Typed.ml
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,16 @@ module Type = struct
| ArrayRef of Identifier.t
| Arr of arr

and tuple = atom list

and atom =
| AtomRef of Identifier.t
| Func of func
| Forall of forall
| Pi of pi
| Sigma of sigma
| Literal of literal
| Tuple of tuple

and t =
| Array of array
Expand Down Expand Up @@ -267,6 +270,12 @@ module Expr = struct
; type' : Type.array [@sexp_drop_if fun _ -> true]
}

and tupleDeref =
{ expr : array
; position : int
; type' : Type.array
}

and literal =
| IntLiteral of int
| FloatLiteral of float
Expand All @@ -286,13 +295,20 @@ module Expr = struct
| Primitive of primitive
| Lift of lift
| ContiguousSubArray of contiguousSubArray
| TupleDeref of tupleDeref

and tuple =
{ elements : atom list
; type' : Type.tuple
}

and atom =
| TermLambda of termLambda
| TypeLambda of typeLambda
| IndexLambda of indexLambda
| Box of box
| Literal of literal
| TupleExpr of tuple

and t =
| Array of array
Expand All @@ -304,6 +320,7 @@ module Expr = struct
| TypeLambda typeLambda -> Forall typeLambda.type'
| IndexLambda indexLambda -> Pi indexLambda.type'
| Box box -> Sigma box.type'
| TupleExpr tuple -> Tuple tuple.type'
| Literal (IntLiteral _) -> Literal IntLiteral
| Literal (FloatLiteral _) -> Literal FloatLiteral
| Literal (CharacterLiteral _) -> Literal CharacterLiteral
Expand All @@ -323,6 +340,7 @@ module Expr = struct
| Primitive primitive -> primitive.type'
| Lift lift -> lift.type'
| ContiguousSubArray contiguousSubArray -> contiguousSubArray.type'
| TupleDeref tupleDeref -> tupleDeref.type'
;;

let type' : t -> Type.t = function
Expand Down Expand Up @@ -443,13 +461,16 @@ end = struct
| ArrayRef of Ref.t
| Arr of arr

and tuple = atom list

and atom =
| AtomRef of Ref.t
| Func of func
| Forall of forall
| Pi of pi
| Sigma of sigma
| Literal of literal
| Tuple of tuple

and t =
| Array of array
Expand Down Expand Up @@ -505,6 +526,7 @@ end = struct
let depth = depth + 1 in
arrayFrom env depth body)
}
| Type.Tuple tupleElements -> Tuple (List.map ~f:(atomFrom env depth) tupleElements)
| Type.Literal IntLiteral -> Literal IntLiteral
| Type.Literal FloatLiteral -> Literal FloatLiteral
| Type.Literal CharacterLiteral -> Literal CharacterLiteral
Expand Down Expand Up @@ -592,6 +614,8 @@ module Substitute = struct
| Forall forall -> Forall (subIndicesIntoForall indices forall)
| Pi pi -> Pi (subIndicesIntoPi indices pi)
| Sigma sigma -> Sigma (subIndicesIntoSigma indices sigma)
| Tuple tupleElements ->
Tuple (List.map ~f:(subIndicesIntoAtom indices) tupleElements)
| Literal IntLiteral -> Literal IntLiteral
| Literal FloatLiteral -> Literal FloatLiteral
| Literal CharacterLiteral -> Literal CharacterLiteral
Expand Down Expand Up @@ -648,6 +672,7 @@ module Substitute = struct
| Forall forall -> Forall (subTypesIntoForall types forall)
| Pi pi -> Pi (subTypesIntoPi types pi)
| Sigma sigma -> Sigma (subTypesIntoSigma types sigma)
| Tuple tupleElements -> Tuple (List.map ~f:(subTypesIntoAtom types) tupleElements)
| Literal IntLiteral -> Literal IntLiteral
| Literal FloatLiteral -> Literal FloatLiteral
| Literal CharacterLiteral -> Literal CharacterLiteral
Expand Down Expand Up @@ -749,6 +774,12 @@ module Substitute = struct
; l
; type' = Type.subTypesIntoArray types type'
}
| TupleDeref { expr; position; type' } ->
TupleDeref
{ expr = subTypesIntoArray types expr
; position
; type' = Type.subTypesIntoArray types type'
}

and subTypesIntoAtom types = function
| TermLambda lambda -> TermLambda (subTypesIntoTermLambda types lambda)
Expand All @@ -771,6 +802,8 @@ module Substitute = struct
; bodyType = Type.subTypesIntoArray types bodyType
; type' = Type.subTypesIntoSigma types type'
}
| TupleExpr { elements; type' } ->
TupleExpr { elements = List.map ~f:(subTypesIntoAtom types) elements; type' }
| Literal _ as literal -> literal

and subTypesIntoTermLambda types { params; body; type' } =
Expand Down Expand Up @@ -857,6 +890,12 @@ module Substitute = struct
; l = Index.subIndicesIntoDim indices l
; type' = Type.subIndicesIntoArray indices type'
}
| TupleDeref { expr; position; type' } ->
TupleDeref
{ expr = subIndicesIntoArray indices expr
; position
; type' = Type.subIndicesIntoArray indices type'
}

and subIndicesIntoAtom indices = function
| TermLambda lambda -> TermLambda (subIndicesIntoTermLambda indices lambda)
Expand All @@ -879,6 +918,11 @@ module Substitute = struct
; bodyType = Type.subIndicesIntoArray indices bodyType
; type' = Type.subIndicesIntoSigma indices type'
}
| TupleExpr { elements; type' } ->
TupleExpr
{ elements = List.map ~f:(subIndicesIntoAtom indices) elements
; type' = List.map ~f:(Type.subIndicesIntoAtom indices) type'
}
| Literal _ as literal -> literal

and subIndicesIntoTermLambda indices { params; body; type' } =
Expand Down
Loading