diff --git a/src/nfuzz.ml b/src/nfuzz.ml index 0422f8d..abc636f 100644 --- a/src/nfuzz.ml +++ b/src/nfuzz.ml @@ -138,6 +138,7 @@ let relative_error1 ty = let relative_error2 si = match si with SiConst f -> Some (f /. (1.0 -. f)) | _ -> None + let type_check program = let ty = Ty_bi.get_type program in main_info dp "Type of the program: @[%a@]" Print.pp_type ty; @@ -185,7 +186,8 @@ let main () = (* Print the results of the parsing phase *) main_debug dp "Parsed program:@\n@[%a@]@." Print.pp_term program; - (* if comp_enabled TypeChecker then type_check program; *) + let paired_program = Paired.lower_term_to_core program in + if comp_enabled TypeChecker then type_check paired_program; (if comp_enabled Backend then match !outfile with diff --git a/src/paired.ml b/src/paired.ml new file mode 100644 index 0000000..b846ed6 --- /dev/null +++ b/src/paired.ml @@ -0,0 +1,76 @@ +open Syntax +open Support.Error +open Support.Options + +(* Assume that we do not have subtraction. + If we encounter a subtraction operation, fail loudly. + Otherwise, term_core contains the same information as term.*) +let lower_op_to_core (op: op) finfo : op_core = + match op with + | AddOp -> AddOpCore + | SubOp -> + error_msg General finfo "Subtraction is not supported as a core operation." + | MulOp -> MulOpCore + | SqrtOp -> SqrtOpCore + | DivOp -> DivOpCore + | GtOp -> GtOpCore + | EqOp -> EqOpCore + +(* Lower a program of type term to core, assuming that + subtraction is not used. *) +let rec lower_term_to_core (program: term) : term_core = + match program with + | TmPrim (finfo, x) -> TmPrimCore (finfo, x) + | TmVar (finfo, x) -> TmVarCore (finfo, x) + | TmLet (finfo, b, c, tm1, tm2) -> TmLetCore (finfo, b, c, lower_term_to_core tm1, lower_term_to_core tm2) + | TmAbs (finfo, b, c, tm) -> TmAbsCore (finfo, b, c, lower_term_to_core tm) + | TmRnd16 (finfo, tm) -> TmRnd16Core (finfo, lower_term_to_core tm) + | TmRnd32 (finfo, tm) -> TmRnd32Core (finfo, lower_term_to_core tm) + | TmRnd64 (finfo, tm) -> TmRnd64Core (finfo, lower_term_to_core tm) + | TmRet (finfo, tm) -> TmRetCore (finfo, lower_term_to_core tm) + | TmOp (finfo, op, tm) -> TmOpCore (finfo, lower_op_to_core op finfo, lower_term_to_core tm) (* handle this later *) + | TmBox (finfo, b, tm) -> TmBoxCore (finfo, b, lower_term_to_core tm) + | TmAmp1 (finfo, tm) -> TmAmp1Core (finfo, lower_term_to_core tm) + | TmAmp2 (finfo, tm) -> TmAmp2Core (finfo, lower_term_to_core tm) + | TmInr (finfo, tm) -> TmInrCore (finfo, lower_term_to_core tm) + | TmInl (finfo, tm) -> TmInlCore (finfo, lower_term_to_core tm) + | TmApp (finfo, tm1, tm2) -> TmAppCore (finfo, lower_term_to_core tm1, lower_term_to_core tm2) + | TmTens (finfo, tm1, tm2) -> TmTensCore (finfo, lower_term_to_core tm1, lower_term_to_core tm2) + | TmLetBind (finfo, b, tm1, tm2) -> TmLetBindCore (finfo, b, lower_term_to_core tm1, lower_term_to_core tm2) + | TmTensDest (finfo, b, c, tm1, tm2) -> TmTensDestCore (finfo, b, c, lower_term_to_core tm1, lower_term_to_core tm2) + | TmAmpersand (finfo, tm1, tm2) -> TmAmpersandCore (finfo, lower_term_to_core tm1, lower_term_to_core tm2) + | TmBoxDest (finfo, b, tm1, tm2) -> TmBoxDestCore (finfo, b, lower_term_to_core tm1, lower_term_to_core tm2) + | TmUnionCase (finfo, tm1, c, tm2, e, tm3) -> TmUnionCaseCore (finfo, lower_term_to_core tm1, c, lower_term_to_core tm2, e, lower_term_to_core tm3) + +let lift_core_op_to_op (op: op_core) : op = + match op with + | AddOpCore -> AddOp + | MulOpCore -> MulOp + | SqrtOpCore -> SqrtOp + | DivOpCore -> DivOp + | GtOpCore -> GtOp + | EqOpCore -> EqOp + +let rec lift_core_to_term (core_program : term_core) : term = + match core_program with + | TmPrimCore (finfo, x) -> TmPrim (finfo, x) + | TmVarCore (finfo, x) -> TmVar (finfo, x) + | TmLetCore (finfo, b, c, tm1, tm2) -> TmLet (finfo, b, c, lift_core_to_term tm1, lift_core_to_term tm2) + | TmAbsCore (finfo, b, c, tm) -> TmAbs (finfo, b, c, lift_core_to_term tm) + | TmRnd16Core (finfo, tm) -> TmRnd16 (finfo, lift_core_to_term tm) + | TmRnd32Core (finfo, tm) -> TmRnd32 (finfo, lift_core_to_term tm) + | TmRnd64Core (finfo, tm) -> TmRnd64 (finfo, lift_core_to_term tm) + | TmRetCore (finfo, tm) -> TmRet (finfo, lift_core_to_term tm) + | TmOpCore (finfo, op, tm) -> TmOp (finfo, lift_core_op_to_op op, lift_core_to_term tm) (* handle this later *) + | TmBoxCore (finfo, b, tm) -> TmBox (finfo, b, lift_core_to_term tm) + | TmAmp1Core (finfo, tm) -> TmAmp1 (finfo, lift_core_to_term tm) + | TmAmp2Core (finfo, tm) -> TmAmp2 (finfo, lift_core_to_term tm) + | TmInrCore (finfo, tm) -> TmInr (finfo, lift_core_to_term tm) + | TmInlCore (finfo, tm) -> TmInl (finfo, lift_core_to_term tm) + | TmAppCore (finfo, tm1, tm2) -> TmApp (finfo, lift_core_to_term tm1, lift_core_to_term tm2) + | TmTensCore (finfo, tm1, tm2) -> TmTens (finfo, lift_core_to_term tm1, lift_core_to_term tm2) + | TmLetBindCore (finfo, b, tm1, tm2) -> TmLetBind (finfo, b, lift_core_to_term tm1, lift_core_to_term tm2) + | TmTensDestCore (finfo, b, c, tm1, tm2) -> TmTensDest (finfo, b, c, lift_core_to_term tm1, lift_core_to_term tm2) + | TmAmpersandCore (finfo, tm1, tm2) -> TmAmpersand (finfo, lift_core_to_term tm1, lift_core_to_term tm2) + | TmBoxDestCore (finfo, b, tm1, tm2) -> TmBoxDest (finfo, b, lift_core_to_term tm1, lift_core_to_term tm2) + | TmUnionCaseCore (finfo, tm1, c, tm2, e, tm3) -> TmUnionCase (finfo, lift_core_to_term tm1, c, lift_core_to_term tm2, e, lift_core_to_term tm3) diff --git a/src/syntax.ml b/src/syntax.ml index 1f11e7d..0a7e0b2 100644 --- a/src/syntax.ml +++ b/src/syntax.ml @@ -216,36 +216,36 @@ type term = type term_core = | TmVarCore of info * var_info (* *) - | TmTensCore of info * term * term - | TmTensDestCore of info * binder_info * binder_info * term * term - | TmInlCore of info * term - | TmInrCore of info * term - | TmUnionCaseCore of info * term * binder_info * term * binder_info * term + | TmTensCore of info * term_core * term_core + | TmTensDestCore of info * binder_info * binder_info * term_core * term_core + | TmInlCore of info * term_core + | TmInrCore of info * term_core + | TmUnionCaseCore of info * term_core * binder_info * term_core * binder_info * term_core (* t of { inl(x) => tm1 | inl(y) => tm2 } *) - (* Primitive terms *) + (* Primitive core terms *) | TmPrimCore of info * term_prim (* Rounding *) - | TmRnd64Core of info * term - | TmRnd32Core of info * term - | TmRnd16Core of info * term + | TmRnd64Core of info * term_core + | TmRnd32Core of info * term_core + | TmRnd16Core of info * term_core (* Ret *) - | TmRetCore of info * term + | TmRetCore of info * term_core (* Regular Abstraction and Applicacion *) - | TmAppCore of info * term * term - | TmAbsCore of info * binder_info * ty * term + | TmAppCore of info * term_core * term_core + | TmAbsCore of info * binder_info * ty * term_core (* & constructor and eliminator *) - | TmAmpersandCore of info * term * term - | TmAmp1Core of info * term - | TmAmp2Core of info * term + | TmAmpersandCore of info * term_core * term_core + | TmAmp1Core of info * term_core + | TmAmp2Core of info * term_core (* Box constructor and elim *) - | TmBoxCore of info * si * term - | TmBoxDestCore of info * binder_info * term * term + | TmBoxCore of info * si * term_core + | TmBoxDestCore of info * binder_info * term_core * term_core (* Regular sequencing *) - | TmLetCore of info * binder_info * ty option * term * term + | TmLetCore of info * binder_info * ty option * term_core * term_core (* Monadic sequencing *) - | TmLetBindCore of info * binder_info * term * term + | TmLetBindCore of info * binder_info * term_core * term_core (* Basic ops *) - | TmOpCore of info * op_core * term + | TmOpCore of info * op_core * term_core let map_prim_ty n f p = match p with diff --git a/src/ty_bi.ml b/src/ty_bi.ml index d24c67e..97c6ce1 100644 --- a/src/ty_bi.ml +++ b/src/ty_bi.ml @@ -324,13 +324,13 @@ module TypeSub = struct let num = (TyPrim PrimNum) in let ty_bool = (TyUnion(TyPrim PrimUnit, TyPrim PrimUnit)) in match op with - | AddOp -> return (TyLollipop((TyAmpersand(num, num)),num)) - | MulOp -> return (TyLollipop((TyTensor(num, num)),num)) - | SqrtOp -> return (TyLollipop((TyBang(si_hlf, num)),num)) - | DivOp -> return (TyLollipop((TyTensor(num, num)),num)) - | GtOp -> + | AddOpCore -> return (TyLollipop((TyAmpersand(num, num)),num)) + | MulOpCore -> return (TyLollipop((TyTensor(num, num)),num)) + | SqrtOpCore -> return (TyLollipop((TyBang(si_hlf, num)),num)) + | DivOpCore -> return (TyLollipop((TyTensor(num, num)),num)) + | GtOpCore -> return (TyLollipop((TyTensor(TyBang(si_infty,num),TyBang(si_infty,num))),ty_bool)) - | EqOp -> + | EqOpCore -> return (TyLollipop((TyTensor(TyBang(si_infty,num),TyBang(si_infty,num))),ty_bool)) let check_is_num' ty : bool = @@ -460,39 +460,39 @@ let kind_of (i : info) (si : si) : kind checker = satisfied in order for the type to be valid. Raises an error if it detects that no typing is possible. *) -let rec type_of (t : term) : (ty * bsi list) checker = - - ty_debug (tmInfo t) "--> [%3d] Enter type_of: @[%a@]" !ty_seq - (Print.limit_boxes Print.pp_term) t; incr ty_seq; +let rec type_of (t : term_core) : (ty * bsi list) checker = + let t_lifted = Paired.lift_core_to_term t in + ty_debug (tmInfo t_lifted) "--> [%3d] Enter type_of: @[%a@]" !ty_seq + (Print.limit_boxes Print.pp_term) t_lifted; incr ty_seq; (match t with (* Variables *) - | TmVar(_i, x) -> + | TmVarCore(_i, x) -> get_ctx_length >>= fun len -> get_var_ty x >>= fun ty_x -> return (ty_x, singleton len x) (* Primitive terms *) - | TmPrim(_, pt) -> + | TmPrimCore(_, pt) -> get_ctx_length >>= fun len -> return (type_of_prim pt, zeros len) (* Rounding *) - | TmRnd64(i, v) -> + | TmRnd64Core(i, v) -> type_of v >>= fun (ty_v, sis_v) -> check_is_num i ty_v >> let eps = SiConst ( (2.220446049250313e-16)) in return (TyMonad(eps, TyPrim PrimNum), sis_v) - | TmRnd32(i, v) -> + | TmRnd32Core(i, v) -> type_of v >>= fun (ty_v, sis_v) -> check_is_num i ty_v >> let eps = SiConst ( (1.192092895507812e-7)) in return (TyMonad(eps, TyPrim PrimNum), sis_v) - | TmRnd16(i, v) -> + | TmRnd16Core(i, v) -> type_of v >>= fun (ty_v, sis_v) -> check_is_num i ty_v >> @@ -500,7 +500,7 @@ let rec type_of (t : term) : (ty * bsi list) checker = return (TyMonad(eps, TyPrim PrimNum), sis_v) (* Ret *) - | TmRet(_i, v) -> + | TmRetCore(_i, v) -> type_of v >>= fun (ty_v, sis_v) -> return (TyMonad(si_zero, ty_v), sis_v) @@ -508,14 +508,13 @@ let rec type_of (t : term) : (ty * bsi list) checker = (* Abstraction and Application *) (* λ (x : tya_x) { tm } *) - | TmAbs(i, b_x, tya_x, tm) -> - + | TmAbsCore(i, b_x, tya_x, tm) -> with_extended_ctx i b_x.b_name tya_x (type_of tm) >>= fun (ty_tm, si_x, sis) -> let si_x1 = si_of_bsi si_x in let si_x2 = Simpl.si_simpl_compute si_x1 in - ty_debug (tmInfo t) "### [%3d] Inferred sensitivity for binder @[%a@] is @[%a@]" !ty_seq P.pp_binfo b_x P.pp_si si_x2; + ty_debug (tmInfo t_lifted) "### [%3d] Inferred sensitivity for binder @[%a@] is @[%a@]" !ty_seq P.pp_binfo b_x P.pp_si si_x2; let si_x3 = Simpl.si_simpl si_x2 in let si_x4 = Simpl.si_simpl_compute si_x3 in @@ -524,7 +523,7 @@ let rec type_of (t : term) : (ty * bsi list) checker = return (TyLollipop (tya_x, ty_tm), sis) (* tm1 β → α, tm2: β *) - | TmApp(i, tm1, tm2) -> + | TmAppCore(i, tm1, tm2) -> type_of tm1 >>= fun (ty1, sis1) -> type_of tm2 >>= fun (ty2, sis2) -> @@ -537,7 +536,7 @@ let rec type_of (t : term) : (ty * bsi list) checker = (* Standard let-binding *) (* x : oty_x = tm_x ; e *) - | TmLet(i, x, oty_x, tm_x, e) -> + | TmLetCore(i, x, oty_x, tm_x, e) -> type_of tm_x >>= fun (ty_x, sis_x) -> @@ -558,7 +557,7 @@ let rec type_of (t : term) : (ty * bsi list) checker = return (ty_e, add_sens sis_e (scale_sens (Some si_x) sis_x)) (* Monadic let-binding x = v ; e *) - | TmLetBind(i, x, v, e) -> + | TmLetBindCore(i, x, v, e) -> type_of v >>= fun (ty_v, sis_v) -> @@ -577,28 +576,28 @@ let rec type_of (t : term) : (ty * bsi list) checker = return (TyMonad(si_total,ty_e'), add_sens sis_e (scale_sens (Some si_x) sis_v)) (* Tensor product and Cartesian product (ampersand &)*) - | TmAmpersand(_i, tm1, tm2) -> + | TmAmpersandCore(_i, tm1, tm2) -> type_of tm1 >>= fun (ty1, sis1) -> type_of tm2 >>= fun (ty2, sis2) -> return (TyAmpersand(ty1, ty2), lub_sens sis1 sis2) - | TmAmp1(i, tm1) -> + | TmAmp1Core(i, tm1) -> type_of tm1 >>= fun (ty, sis1) -> check_amp_shape i ty >>= fun(ty_1, _ty_2) -> return (ty_1, sis1) - | TmAmp2(i, tm2) -> + | TmAmp2Core(i, tm2) -> type_of tm2 >>= fun (ty, sis2) -> check_amp_shape i ty >>= fun(_ty_1, ty_2) -> return (ty_2, sis2) - | TmTens(_i, e1, e2) -> + | TmTensCore(_i, e1, e2) -> type_of e1 >>= fun (ty1, sis1) -> type_of e2 >>= fun (ty2, sis2) -> @@ -606,7 +605,7 @@ let rec type_of (t : term) : (ty * bsi list) checker = return @@ (TyTensor(ty1, ty2), add_sens sis1 sis2) (* let (x,y) = v in e *) - | TmTensDest(i, x, y, v, e) -> + | TmTensDestCore(i, x, y, v, e) -> type_of v >>= fun (ty_v, sis_v) -> check_tensor_shape i ty_v >>= fun (ty_x, ty_y) -> @@ -621,14 +620,14 @@ let rec type_of (t : term) : (ty * bsi list) checker = return (ty_e, add_sens sis_e (scale_sens (Some si_max) sis_v)) (* Exponentials (bangs and boxes) *) - | TmBox(_i,si_v,v) -> + | TmBoxCore(_i,si_v,v) -> type_of v >>= fun (ty_v, sis_v) -> return(TyBang(si_v, ty_v), scale_sens (Some si_v) sis_v) (* let [x] = v in e *) - | TmBoxDest(i,x,tm_v,tm_e) -> + | TmBoxDestCore(i,x,tm_v,tm_e) -> type_of tm_v >>= fun (sity_v, sis_v) -> @@ -648,7 +647,7 @@ let rec type_of (t : term) : (ty * bsi list) checker = (* Case analysis *) (* case v of inl(x) => e_l | inr(y) => f_r *) - | TmUnionCase(i, v, b_x, e_l, b_y, f_r) -> + | TmUnionCaseCore(i, v, b_x, e_l, b_y, f_r) -> type_of v >>= fun (ty_v, sis_v) -> @@ -660,9 +659,11 @@ let rec type_of (t : term) : (ty * bsi list) checker = check_ty_union i tyl tyr >>= fun ty_exp -> - ty_debug (tmInfo v) "### In case, [%3d] Inferred sensitivity for binder @[%a@] is @[%a@]" !ty_seq P.pp_binfo b_x P.pp_si (si_of_bsi si_x); + let v_lifted = Paired.lift_core_to_term v in + + ty_debug (tmInfo v_lifted) "### In case, [%3d] Inferred sensitivity for binder @[%a@] is @[%a@]" !ty_seq P.pp_binfo b_x P.pp_si (si_of_bsi si_x); - ty_debug (tmInfo v) "*** Context: @[%a@]" (Print.pp_list Print.pp_si) (bsi_sens sis_v); + ty_debug (tmInfo v_lifted) "*** Context: @[%a@]" (Print.pp_list Print.pp_si) (bsi_sens sis_v); let si_x = si_of_bsi si_x in let si_y = si_of_bsi si_y in @@ -676,18 +677,18 @@ let rec type_of (t : term) : (ty * bsi list) checker = else return (ty_exp, add_sens theta (scale_sens (Some si_x) sis_v)) - | TmInl(_i, tm_l) -> + | TmInlCore(_i, tm_l) -> type_of tm_l >>= fun (ty, sis) -> return (TyUnion(ty, TyPrim PrimUnit), sis) - | TmInr(_i, tm_r) -> + | TmInrCore(_i, tm_r) -> type_of tm_r >>= fun (ty, sis) -> return (TyUnion(TyPrim PrimUnit, ty), sis) (* Ops *) - | TmOp(i, fop, v) -> + | TmOpCore(i, fop, v) -> type_of v >>= fun (ty_v, sis_v) -> @@ -702,8 +703,8 @@ let rec type_of (t : term) : (ty * bsi list) checker = decr ty_seq; (* We limit pp_term *) - ty_debug (tmInfo t) "<-- [%3d] Exit type_of : @[%a@] with type @[%a@]" !ty_seq - (Print.limit_boxes Print.pp_term) t Print.pp_type ty; + ty_debug (tmInfo t_lifted) "<-- [%3d] Exit type_of : @[%a@] with type @[%a@]" !ty_seq + (Print.limit_boxes Print.pp_term) t_lifted Print.pp_type ty; (* TODO: pretty printer for sensitivities *) (* ty_debug2 (tmInfo t) "<-- Context: @[%a@]" Print.pp_context ctx; *)