Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/nfuzz.ml
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ let relative_error1 ty =
let relative_error2 si =
match si with SiConst f -> Some (f /. (1.0 -. f)) | _ -> None


let type_check program =
let ty = Ty_bi.get_type program in
main_info dp "Type of the program: @[%a@]" Print.pp_type ty;
Expand Down Expand Up @@ -185,7 +186,8 @@ let main () =
(* Print the results of the parsing phase *)
main_debug dp "Parsed program:@\n@[%a@]@." Print.pp_term program;

(* if comp_enabled TypeChecker then type_check program; *)
let paired_program = Paired.lower_term_to_core program in
if comp_enabled TypeChecker then type_check paired_program;

(if comp_enabled Backend then
match !outfile with
Expand Down
76 changes: 76 additions & 0 deletions src/paired.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
open Syntax
open Support.Error
open Support.Options

(* Assume that we do not have subtraction.
If we encounter a subtraction operation, fail loudly.
Otherwise, term_core contains the same information as term.*)
let lower_op_to_core (op: op) finfo : op_core =
match op with
| AddOp -> AddOpCore
| SubOp ->
error_msg General finfo "Subtraction is not supported as a core operation."
| MulOp -> MulOpCore
| SqrtOp -> SqrtOpCore
| DivOp -> DivOpCore
| GtOp -> GtOpCore
| EqOp -> EqOpCore

(* Lower a program of type term to core, assuming that
subtraction is not used. *)
let rec lower_term_to_core (program: term) : term_core =
match program with
| TmPrim (finfo, x) -> TmPrimCore (finfo, x)
| TmVar (finfo, x) -> TmVarCore (finfo, x)
| TmLet (finfo, b, c, tm1, tm2) -> TmLetCore (finfo, b, c, lower_term_to_core tm1, lower_term_to_core tm2)
| TmAbs (finfo, b, c, tm) -> TmAbsCore (finfo, b, c, lower_term_to_core tm)
| TmRnd16 (finfo, tm) -> TmRnd16Core (finfo, lower_term_to_core tm)
| TmRnd32 (finfo, tm) -> TmRnd32Core (finfo, lower_term_to_core tm)
| TmRnd64 (finfo, tm) -> TmRnd64Core (finfo, lower_term_to_core tm)
| TmRet (finfo, tm) -> TmRetCore (finfo, lower_term_to_core tm)
| TmOp (finfo, op, tm) -> TmOpCore (finfo, lower_op_to_core op finfo, lower_term_to_core tm) (* handle this later *)
| TmBox (finfo, b, tm) -> TmBoxCore (finfo, b, lower_term_to_core tm)
| TmAmp1 (finfo, tm) -> TmAmp1Core (finfo, lower_term_to_core tm)
| TmAmp2 (finfo, tm) -> TmAmp2Core (finfo, lower_term_to_core tm)
| TmInr (finfo, tm) -> TmInrCore (finfo, lower_term_to_core tm)
| TmInl (finfo, tm) -> TmInlCore (finfo, lower_term_to_core tm)
| TmApp (finfo, tm1, tm2) -> TmAppCore (finfo, lower_term_to_core tm1, lower_term_to_core tm2)
| TmTens (finfo, tm1, tm2) -> TmTensCore (finfo, lower_term_to_core tm1, lower_term_to_core tm2)
| TmLetBind (finfo, b, tm1, tm2) -> TmLetBindCore (finfo, b, lower_term_to_core tm1, lower_term_to_core tm2)
| TmTensDest (finfo, b, c, tm1, tm2) -> TmTensDestCore (finfo, b, c, lower_term_to_core tm1, lower_term_to_core tm2)
| TmAmpersand (finfo, tm1, tm2) -> TmAmpersandCore (finfo, lower_term_to_core tm1, lower_term_to_core tm2)
| TmBoxDest (finfo, b, tm1, tm2) -> TmBoxDestCore (finfo, b, lower_term_to_core tm1, lower_term_to_core tm2)
| TmUnionCase (finfo, tm1, c, tm2, e, tm3) -> TmUnionCaseCore (finfo, lower_term_to_core tm1, c, lower_term_to_core tm2, e, lower_term_to_core tm3)

let lift_core_op_to_op (op: op_core) : op =
match op with
| AddOpCore -> AddOp
| MulOpCore -> MulOp
| SqrtOpCore -> SqrtOp
| DivOpCore -> DivOp
| GtOpCore -> GtOp
| EqOpCore -> EqOp

let rec lift_core_to_term (core_program : term_core) : term =
match core_program with
| TmPrimCore (finfo, x) -> TmPrim (finfo, x)
| TmVarCore (finfo, x) -> TmVar (finfo, x)
| TmLetCore (finfo, b, c, tm1, tm2) -> TmLet (finfo, b, c, lift_core_to_term tm1, lift_core_to_term tm2)
| TmAbsCore (finfo, b, c, tm) -> TmAbs (finfo, b, c, lift_core_to_term tm)
| TmRnd16Core (finfo, tm) -> TmRnd16 (finfo, lift_core_to_term tm)
| TmRnd32Core (finfo, tm) -> TmRnd32 (finfo, lift_core_to_term tm)
| TmRnd64Core (finfo, tm) -> TmRnd64 (finfo, lift_core_to_term tm)
| TmRetCore (finfo, tm) -> TmRet (finfo, lift_core_to_term tm)
| TmOpCore (finfo, op, tm) -> TmOp (finfo, lift_core_op_to_op op, lift_core_to_term tm) (* handle this later *)
| TmBoxCore (finfo, b, tm) -> TmBox (finfo, b, lift_core_to_term tm)
| TmAmp1Core (finfo, tm) -> TmAmp1 (finfo, lift_core_to_term tm)
| TmAmp2Core (finfo, tm) -> TmAmp2 (finfo, lift_core_to_term tm)
| TmInrCore (finfo, tm) -> TmInr (finfo, lift_core_to_term tm)
| TmInlCore (finfo, tm) -> TmInl (finfo, lift_core_to_term tm)
| TmAppCore (finfo, tm1, tm2) -> TmApp (finfo, lift_core_to_term tm1, lift_core_to_term tm2)
| TmTensCore (finfo, tm1, tm2) -> TmTens (finfo, lift_core_to_term tm1, lift_core_to_term tm2)
| TmLetBindCore (finfo, b, tm1, tm2) -> TmLetBind (finfo, b, lift_core_to_term tm1, lift_core_to_term tm2)
| TmTensDestCore (finfo, b, c, tm1, tm2) -> TmTensDest (finfo, b, c, lift_core_to_term tm1, lift_core_to_term tm2)
| TmAmpersandCore (finfo, tm1, tm2) -> TmAmpersand (finfo, lift_core_to_term tm1, lift_core_to_term tm2)
| TmBoxDestCore (finfo, b, tm1, tm2) -> TmBoxDest (finfo, b, lift_core_to_term tm1, lift_core_to_term tm2)
| TmUnionCaseCore (finfo, tm1, c, tm2, e, tm3) -> TmUnionCase (finfo, lift_core_to_term tm1, c, lift_core_to_term tm2, e, lift_core_to_term tm3)
40 changes: 20 additions & 20 deletions src/syntax.ml
Original file line number Diff line number Diff line change
Expand Up @@ -216,36 +216,36 @@ type term =
type term_core =
| TmVarCore of info * var_info
(* *)
| TmTensCore of info * term * term
| TmTensDestCore of info * binder_info * binder_info * term * term
| TmInlCore of info * term
| TmInrCore of info * term
| TmUnionCaseCore of info * term * binder_info * term * binder_info * term
| TmTensCore of info * term_core * term_core
| TmTensDestCore of info * binder_info * binder_info * term_core * term_core
| TmInlCore of info * term_core
| TmInrCore of info * term_core
| TmUnionCaseCore of info * term_core * binder_info * term_core * binder_info * term_core
(* t of { inl(x) => tm1 | inl(y) => tm2 } *)
(* Primitive terms *)
(* Primitive core terms *)
| TmPrimCore of info * term_prim
(* Rounding *)
| TmRnd64Core of info * term
| TmRnd32Core of info * term
| TmRnd16Core of info * term
| TmRnd64Core of info * term_core
| TmRnd32Core of info * term_core
| TmRnd16Core of info * term_core
(* Ret *)
| TmRetCore of info * term
| TmRetCore of info * term_core
(* Regular Abstraction and Applicacion *)
| TmAppCore of info * term * term
| TmAbsCore of info * binder_info * ty * term
| TmAppCore of info * term_core * term_core
| TmAbsCore of info * binder_info * ty * term_core
(* & constructor and eliminator *)
| TmAmpersandCore of info * term * term
| TmAmp1Core of info * term
| TmAmp2Core of info * term
| TmAmpersandCore of info * term_core * term_core
| TmAmp1Core of info * term_core
| TmAmp2Core of info * term_core
(* Box constructor and elim *)
| TmBoxCore of info * si * term
| TmBoxDestCore of info * binder_info * term * term
| TmBoxCore of info * si * term_core
| TmBoxDestCore of info * binder_info * term_core * term_core
(* Regular sequencing *)
| TmLetCore of info * binder_info * ty option * term * term
| TmLetCore of info * binder_info * ty option * term_core * term_core
(* Monadic sequencing *)
| TmLetBindCore of info * binder_info * term * term
| TmLetBindCore of info * binder_info * term_core * term_core
(* Basic ops *)
| TmOpCore of info * op_core * term
| TmOpCore of info * op_core * term_core

let map_prim_ty n f p =
match p with
Expand Down
75 changes: 38 additions & 37 deletions src/ty_bi.ml
Original file line number Diff line number Diff line change
Expand Up @@ -324,13 +324,13 @@ module TypeSub = struct
let num = (TyPrim PrimNum) in
let ty_bool = (TyUnion(TyPrim PrimUnit, TyPrim PrimUnit)) in
match op with
| AddOp -> return (TyLollipop((TyAmpersand(num, num)),num))
| MulOp -> return (TyLollipop((TyTensor(num, num)),num))
| SqrtOp -> return (TyLollipop((TyBang(si_hlf, num)),num))
| DivOp -> return (TyLollipop((TyTensor(num, num)),num))
| GtOp ->
| AddOpCore -> return (TyLollipop((TyAmpersand(num, num)),num))
| MulOpCore -> return (TyLollipop((TyTensor(num, num)),num))
| SqrtOpCore -> return (TyLollipop((TyBang(si_hlf, num)),num))
| DivOpCore -> return (TyLollipop((TyTensor(num, num)),num))
| GtOpCore ->
return (TyLollipop((TyTensor(TyBang(si_infty,num),TyBang(si_infty,num))),ty_bool))
| EqOp ->
| EqOpCore ->
return (TyLollipop((TyTensor(TyBang(si_infty,num),TyBang(si_infty,num))),ty_bool))

let check_is_num' ty : bool =
Expand Down Expand Up @@ -460,62 +460,61 @@ let kind_of (i : info) (si : si) : kind checker =
satisfied in order for the type to be valid. Raises an error if it
detects that no typing is possible. *)

let rec type_of (t : term) : (ty * bsi list) checker =

ty_debug (tmInfo t) "--> [%3d] Enter type_of: @[%a@]" !ty_seq
(Print.limit_boxes Print.pp_term) t; incr ty_seq;
let rec type_of (t : term_core) : (ty * bsi list) checker =
let t_lifted = Paired.lift_core_to_term t in
ty_debug (tmInfo t_lifted) "--> [%3d] Enter type_of: @[%a@]" !ty_seq
(Print.limit_boxes Print.pp_term) t_lifted; incr ty_seq;

(match t with
(* Variables *)
| TmVar(_i, x) ->
| TmVarCore(_i, x) ->
get_ctx_length >>= fun len ->
get_var_ty x >>= fun ty_x ->
return (ty_x, singleton len x)

(* Primitive terms *)
| TmPrim(_, pt) ->
| TmPrimCore(_, pt) ->
get_ctx_length >>= fun len ->
return (type_of_prim pt, zeros len)

(* Rounding *)
| TmRnd64(i, v) ->
| TmRnd64Core(i, v) ->
type_of v >>= fun (ty_v, sis_v) ->
check_is_num i ty_v >>

let eps = SiConst ( (2.220446049250313e-16)) in
return (TyMonad(eps, TyPrim PrimNum), sis_v)

| TmRnd32(i, v) ->
| TmRnd32Core(i, v) ->
type_of v >>= fun (ty_v, sis_v) ->
check_is_num i ty_v >>

let eps = SiConst ( (1.192092895507812e-7)) in
return (TyMonad(eps, TyPrim PrimNum), sis_v)

| TmRnd16(i, v) ->
| TmRnd16Core(i, v) ->
type_of v >>= fun (ty_v, sis_v) ->
check_is_num i ty_v >>

let eps = SiConst ( (0.0009765625)) in
return (TyMonad(eps, TyPrim PrimNum), sis_v)

(* Ret *)
| TmRet(_i, v) ->
| TmRetCore(_i, v) ->
type_of v >>= fun (ty_v, sis_v) ->

return (TyMonad(si_zero, ty_v), sis_v)

(* Abstraction and Application *)

(* λ (x : tya_x) { tm } *)
| TmAbs(i, b_x, tya_x, tm) ->

| TmAbsCore(i, b_x, tya_x, tm) ->
with_extended_ctx i b_x.b_name tya_x (type_of tm) >>= fun (ty_tm, si_x, sis) ->

let si_x1 = si_of_bsi si_x in
let si_x2 = Simpl.si_simpl_compute si_x1 in

ty_debug (tmInfo t) "### [%3d] Inferred sensitivity for binder @[%a@] is @[%a@]" !ty_seq P.pp_binfo b_x P.pp_si si_x2;
ty_debug (tmInfo t_lifted) "### [%3d] Inferred sensitivity for binder @[%a@] is @[%a@]" !ty_seq P.pp_binfo b_x P.pp_si si_x2;

let si_x3 = Simpl.si_simpl si_x2 in
let si_x4 = Simpl.si_simpl_compute si_x3 in
Expand All @@ -524,7 +523,7 @@ let rec type_of (t : term) : (ty * bsi list) checker =
return (TyLollipop (tya_x, ty_tm), sis)

(* tm1 β → α, tm2: β *)
| TmApp(i, tm1, tm2) ->
| TmAppCore(i, tm1, tm2) ->

type_of tm1 >>= fun (ty1, sis1) ->
type_of tm2 >>= fun (ty2, sis2) ->
Expand All @@ -537,7 +536,7 @@ let rec type_of (t : term) : (ty * bsi list) checker =

(* Standard let-binding *)
(* x : oty_x = tm_x ; e *)
| TmLet(i, x, oty_x, tm_x, e) ->
| TmLetCore(i, x, oty_x, tm_x, e) ->

type_of tm_x >>= fun (ty_x, sis_x) ->

Expand All @@ -558,7 +557,7 @@ let rec type_of (t : term) : (ty * bsi list) checker =
return (ty_e, add_sens sis_e (scale_sens (Some si_x) sis_x))

(* Monadic let-binding x = v ; e *)
| TmLetBind(i, x, v, e) ->
| TmLetBindCore(i, x, v, e) ->

type_of v >>= fun (ty_v, sis_v) ->

Expand All @@ -577,36 +576,36 @@ let rec type_of (t : term) : (ty * bsi list) checker =
return (TyMonad(si_total,ty_e'), add_sens sis_e (scale_sens (Some si_x) sis_v))

(* Tensor product and Cartesian product (ampersand &)*)
| TmAmpersand(_i, tm1, tm2) ->
| TmAmpersandCore(_i, tm1, tm2) ->

type_of tm1 >>= fun (ty1, sis1) ->
type_of tm2 >>= fun (ty2, sis2) ->

return (TyAmpersand(ty1, ty2), lub_sens sis1 sis2)

| TmAmp1(i, tm1) ->
| TmAmp1Core(i, tm1) ->

type_of tm1 >>= fun (ty, sis1) ->
check_amp_shape i ty >>= fun(ty_1, _ty_2) ->

return (ty_1, sis1)

| TmAmp2(i, tm2) ->
| TmAmp2Core(i, tm2) ->

type_of tm2 >>= fun (ty, sis2) ->
check_amp_shape i ty >>= fun(_ty_1, ty_2) ->

return (ty_2, sis2)

| TmTens(_i, e1, e2) ->
| TmTensCore(_i, e1, e2) ->

type_of e1 >>= fun (ty1, sis1) ->
type_of e2 >>= fun (ty2, sis2) ->

return @@ (TyTensor(ty1, ty2), add_sens sis1 sis2)

(* let (x,y) = v in e *)
| TmTensDest(i, x, y, v, e) ->
| TmTensDestCore(i, x, y, v, e) ->

type_of v >>= fun (ty_v, sis_v) ->
check_tensor_shape i ty_v >>= fun (ty_x, ty_y) ->
Expand All @@ -621,14 +620,14 @@ let rec type_of (t : term) : (ty * bsi list) checker =
return (ty_e, add_sens sis_e (scale_sens (Some si_max) sis_v))

(* Exponentials (bangs and boxes) *)
| TmBox(_i,si_v,v) ->
| TmBoxCore(_i,si_v,v) ->

type_of v >>= fun (ty_v, sis_v) ->

return(TyBang(si_v, ty_v), scale_sens (Some si_v) sis_v)

(* let [x] = v in e *)
| TmBoxDest(i,x,tm_v,tm_e) ->
| TmBoxDestCore(i,x,tm_v,tm_e) ->

type_of tm_v >>= fun (sity_v, sis_v) ->

Expand All @@ -648,7 +647,7 @@ let rec type_of (t : term) : (ty * bsi list) checker =

(* Case analysis *)
(* case v of inl(x) => e_l | inr(y) => f_r *)
| TmUnionCase(i, v, b_x, e_l, b_y, f_r) ->
| TmUnionCaseCore(i, v, b_x, e_l, b_y, f_r) ->

type_of v >>= fun (ty_v, sis_v) ->

Expand All @@ -660,9 +659,11 @@ let rec type_of (t : term) : (ty * bsi list) checker =

check_ty_union i tyl tyr >>= fun ty_exp ->

ty_debug (tmInfo v) "### In case, [%3d] Inferred sensitivity for binder @[%a@] is @[%a@]" !ty_seq P.pp_binfo b_x P.pp_si (si_of_bsi si_x);
let v_lifted = Paired.lift_core_to_term v in

ty_debug (tmInfo v_lifted) "### In case, [%3d] Inferred sensitivity for binder @[%a@] is @[%a@]" !ty_seq P.pp_binfo b_x P.pp_si (si_of_bsi si_x);

ty_debug (tmInfo v) "*** Context: @[%a@]" (Print.pp_list Print.pp_si) (bsi_sens sis_v);
ty_debug (tmInfo v_lifted) "*** Context: @[%a@]" (Print.pp_list Print.pp_si) (bsi_sens sis_v);

let si_x = si_of_bsi si_x in
let si_y = si_of_bsi si_y in
Expand All @@ -676,18 +677,18 @@ let rec type_of (t : term) : (ty * bsi list) checker =
else
return (ty_exp, add_sens theta (scale_sens (Some si_x) sis_v))

| TmInl(_i, tm_l) ->
| TmInlCore(_i, tm_l) ->

type_of tm_l >>= fun (ty, sis) ->
return (TyUnion(ty, TyPrim PrimUnit), sis)

| TmInr(_i, tm_r) ->
| TmInrCore(_i, tm_r) ->

type_of tm_r >>= fun (ty, sis) ->
return (TyUnion(TyPrim PrimUnit, ty), sis)

(* Ops *)
| TmOp(i, fop, v) ->
| TmOpCore(i, fop, v) ->

type_of v >>= fun (ty_v, sis_v) ->

Expand All @@ -702,8 +703,8 @@ let rec type_of (t : term) : (ty * bsi list) checker =

decr ty_seq;
(* We limit pp_term *)
ty_debug (tmInfo t) "<-- [%3d] Exit type_of : @[%a@] with type @[%a@]" !ty_seq
(Print.limit_boxes Print.pp_term) t Print.pp_type ty;
ty_debug (tmInfo t_lifted) "<-- [%3d] Exit type_of : @[%a@] with type @[%a@]" !ty_seq
(Print.limit_boxes Print.pp_term) t_lifted Print.pp_type ty;

(* TODO: pretty printer for sensitivities *)
(* ty_debug2 (tmInfo t) "<-- Context: @[%a@]" Print.pp_context ctx; *)
Expand Down