diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000..1810e6abc9 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "libs/lospecs/tests/simde"] + path = libs/lospecs/tests/simde + url = git@github.com:simd-everywhere/simde.git diff --git a/config/tests.config b/config/tests.config index f7df574a8f..8e017a6e82 100644 --- a/config/tests.config +++ b/config/tests.config @@ -8,7 +8,7 @@ exclude = theories/prelude [test-examples] okdirs = !examples -exclude = examples/MEE-CBC examples/old examples/old/list-ddh !examples/incomplete examples/to-port +exclude = examples/MEE-CBC examples/exclude examples/old examples/old/list-ddh !examples/incomplete examples/to-port [test-mee-cbc] okdirs = examples/MEE-CBC diff --git a/dune b/dune index 7c8edf7096..6e918e80d4 100644 --- a/dune +++ b/dune @@ -1,4 +1,9 @@ -(dirs 3rdparty src etc theories examples assets scripts) +(env + (dev (flags -rectypes -warn-error -a+31 -w +28+33-9-23-32-58-67-69)) + (release (flags -rectypes -warn-error -a+31 -w +28+33-9-23-32-58-67-69) + (ocamlopt_flags -O3 -unbox-closures))) + +(dirs 3rdparty src etc libs theories examples assets scripts) (install (section (site (easycrypt commands))) diff --git a/dune-project b/dune-project index 85f142616e..64b6a5eaf7 100644 --- a/dune-project +++ b/dune-project @@ -13,7 +13,8 @@ (sites (lib theories) (libexec commands) (lib doc) (lib config)) (depends (ocaml (>= 4.08.0)) - (batteries (>= 3)) + (batteries (>= 3.9)) + bitwuzla (camlp-streams (>= 5)) camlzip dune @@ -22,6 +23,12 @@ markdown (pcre2 (>= 8)) (why3 (and (>= 1.8.0) (< 1.9))) + ppx_deriving + ppx_deriving_yojson + hex + iter + cmdliner + progress yojson (zarith (>= 1.10)) )) diff --git a/easycrypt.opam b/easycrypt.opam index 08bdb40eac..d957b69428 100644 --- a/easycrypt.opam +++ b/easycrypt.opam @@ -1,7 +1,8 @@ # This file is generated by dune, edit dune-project instead depends: [ "ocaml" {>= "4.08.0"} - "batteries" {>= "3"} + "batteries" {>= "3.9"} + "bitwuzla" "camlp-streams" {>= "5"} "camlzip" "dune" {>= "3.13"} @@ -10,6 +11,12 @@ depends: [ "markdown" "pcre2" {>= "8"} "why3" {>= "1.8.0" & < "1.9"} + "ppx_deriving" + "ppx_deriving_yojson" + "hex" + "iter" + "cmdliner" + "progress" "yojson" "zarith" {>= "1.10"} "odoc" {with-doc} diff --git a/examples/bindings.ec b/examples/bindings.ec new file mode 100644 index 0000000000..1dc74cb333 --- /dev/null +++ b/examples/bindings.ec @@ -0,0 +1,410 @@ +require import AllCore Bool IntDiv CoreMap List Distr QFABV. +from Jasmin require import JModel JArray. + +clone import PolyArray as Array2 with op size <- 2. + +bind array Array2."_.[_]" Array2."_.[_<-_]" Array2.to_list Array2.of_list Array2.t 2. +realize tolistP by admit. +realize eqP by admit. +realize get_setP by admit. +realize get_out by admit. + +export Array2. + +(* ----------- BEGIN BOOL BINDINGS ---------- *) +op bool2bits (b : bool) : bool list = [b]. +op bits2bool (b: bool list) : bool = List.nth false b 0. + +op i2b (i : int) = (i %% 2 <> 0). +op b2si (b: bool) = 0. + +bind bitstring bool2bits bits2bool b2i b2si i2b bool 1. +realize size_tolist by auto. +realize tolistP by auto. +realize oflistP by admit. +realize ofintP by admit. +realize touintP by admit. +realize tosintP by auto. + +bind op bool (&&) "mul". +realize bvmulP by admit. + +bind op bool (^^) "add". +realize bvaddP by admit. + +op sub (a : bool, b: bool) : bool = + a ^^ b. + +bind op bool sub "sub". +realize bvsubP by admit. + +(* bind op bool udiv "udiv". + realize bvudivP by admit. + +bind op bool umod "urem". +realize bvuremP by admit. *) + +bind op bool (/\) "and". +realize bvandP by admit. + +bind op bool (\/) "or". +realize bvorP by admit. + +bind op bool [!] "not". +realize bvnotP by admit. + +(* TODO: Add shifts once we have truncate/extend *) + + +(* ----------- BEGIN W8 BINDINGS ---------- *) +bind bitstring W8.w2bits W8.bits2w W8.to_uint W8.to_sint W8.of_int W8.t 8. +realize size_tolist by auto. +realize tolistP by auto. +realize oflistP by admit. +realize ofintP by admit. +realize touintP by admit. +realize tosintP by admit. + +bind op W8.t W8.( + ) "add". +realize bvaddP by admit. + +bind op W8.t W8.( * ) "mul". +realize bvmulP by admit. + +op W8_sub (a : W8.t, b: W8.t) : W8.t = + a - b. + +bind op W8.t W8_sub "sub". +realize bvsubP by admit. + +bind op W8.t W8.\udiv "udiv". +realize bvudivP by admit. + +bind op W8.t W8.\umod "urem". +realize bvuremP by admit. + +bind op W8.t W8.andw "and". +realize bvandP by admit. + +bind op W8.t W8.orw "or". +realize bvorP by admit. + +bind op W8.t W8.(+^) "xor". +realize bvxorP by admit. + +bind op W8.t W8.invw "not". +realize bvnotP by admit. + +bind op [bool & W8.t] W8.\ult "ult". +realize bvultP by admit. + +bind op [bool & W8.t] W8.\ule "ule". +realize bvuleP by admit. + +bind op [bool & W8.t] W8.\slt "slt". +realize bvsltP by admit. + +bind op [bool & W8.t] W8.\sle "sle". +realize bvsleP by admit. + +bind op W8.t W8.(`>>`) "shr". +realize bvshrP by admit. + +bind op W8.t W8.(`<<`) "shl". +realize bvshlP by admit. + +bind op W8.t W8.(`|>>`) "ashr". +realize bvashrP by admit. + + + +(* ----------- BEGIN W16 BINDINGS ---------- *) + +bind bitstring W16.w2bits W16.bits2w W16.to_uint W16.to_sint W16.of_int W16.t 16. +realize size_tolist by auto. +realize tolistP by auto. +realize oflistP by admit. +realize ofintP by admit. +realize touintP by admit. +realize tosintP by admit. + +bind op W16.t W16.( + ) "add". +realize bvaddP by admit. + +bind op W16.t W16.( * ) "mul". +realize bvmulP by admit. + +op W16_sub (a : W16.t, b: W16.t) : W16.t = + a - b. + +bind op W16.t W16_sub "sub". +realize bvsubP by admit. + +bind op W16.t W16.\udiv "udiv". +realize bvudivP by admit. + +bind op W16.t W16.\umod "urem". +realize bvuremP by admit. + +bind op W16.t W16.andw "and". +realize bvandP by admit. + +bind op W16.t W16.orw "or". +realize bvorP by admit. + +bind op W16.t W16.(+^) "xor". +realize bvxorP by admit. + +bind op W16.t W16.invw "not". +realize bvnotP by admit. + +bind op [bool & W16.t] W16.\ult "ult". +realize bvultP by admit. + +bind op [bool & W16.t] W16.\ule "ule". +realize bvuleP by admit. + +bind op [bool & W16.t] W16.\sle "sle". +realize bvsleP by admit. + +bind op [bool & W16.t] W16.\slt "slt". +realize bvsltP by admit. + +op uext8_16 (w: W8.t) : W16.t = + W16.of_int (W8.to_uint w). + +bind op [W8.t & W16.t] uext8_16 "zextend". +realize bvzextendP by admit. + +op sext8_16 (w: W8.t) : W16.t = + W16.of_int (W8.to_sint w). + +bind op [W8.t & W16.t] sext8_16 "sextend". +realize bvsextendP by admit. + +op concat8_8_16 (w: W8.t) (w: W8.t) : W16.t. + +bind op [W8.t & W8.t & W16.t] concat8_8_16 "concat". +realize bvconcatP by admit. + + +op shl16 (w: W16.t) (sa: W16.t) : W16.t. + +lemma shl_shift w sa : + W16.(`<<`) w sa = shl16 w (uext8_16 sa) by admit. + +bind op W16.t shl16 "shl". +realize bvshlP by admit. + +(* TODO: Add shifts once we have truncate/extend *) + + +(* ----------- BEGIN W32 BINDINGS ---------- *) +bind bitstring W32.w2bits W32.bits2w W32.to_uint W32.to_sint W32.of_int W32.t 32. +realize size_tolist by auto. +realize tolistP by auto. +realize oflistP by admit. +realize touintP by admit. +realize tosintP by admit. +realize ofintP by admit. + +bind op W32.t W32.( + ) "add". +realize bvaddP by admit. + +bind op W32.t W32.( * ) "mul". +realize bvmulP by admit. + +op W32_sub (a : W32.t, b: W32.t) : W32.t = + a - b. + +bind op W32.t W32_sub "sub". +realize bvsubP by admit. + +bind op W32.t W32.\udiv "udiv". +realize bvudivP by admit. + +bind op W32.t W32.\umod "urem". +realize bvuremP by admit. + +bind op W32.t W32.andw "and". +realize bvandP by admit. + +bind op W32.t W32.orw "or". +realize bvorP by admit. + +bind op W32.t W32.(+^) "xor". +realize bvxorP by admit. + +bind op W32.t W32.invw "not". +realize bvnotP by admit. + +bind op [W32.t & bool] W32."_.[_]" "get". +realize bvgetP by admit. + +(* TODO: Add shifts once we have truncate/extend *) + + +(* ----------- BEGIN W64 BINDINGS ---------- *) + +bind bitstring W64.w2bits W64.bits2w W64.to_uint W64.to_sint W64.of_int W64.t 64. +realize size_tolist by auto. +realize tolistP by auto. +realize oflistP by admit. +realize touintP by admit. +realize tosintP by admit. +realize ofintP by admit. + +bind op W64.t W64.( + ) "add". +realize bvaddP by admit. + +bind op W64.t W64.( * ) "mul". +realize bvmulP by admit. + +op W64_sub (a : W64.t, b: W64.t) : W64.t = + a - b. + +bind op W64.t W64_sub "sub". +realize bvsubP by admit. + +bind op W64.t W64.\udiv "udiv". +realize bvudivP by admit. + +bind op W64.t W64.\umod "urem". +realize bvuremP by admit. + +bind op W64.t W64.andw "and". +realize bvandP by admit. + +bind op W64.t W64.orw "or". +realize bvorP by admit. + +bind op W64.t W64.(+^) "xor". +realize bvxorP by admit. + +bind op W64.t W64.invw "not". +realize bvnotP by admit. + +(* TODO: Add shifts once we have truncate/extend *) + + +(* ----------- BEGIN W128 BINDINGS ---------- *) + +bind bitstring W128.w2bits W128.bits2w W128.to_uint W128.to_sint W128.of_int W128.t 128. +realize size_tolist by auto. +realize tolistP by auto. +realize oflistP by admit. +realize touintP by admit. +realize tosintP by admit. +realize ofintP by admit. + +bind op W128.t W128.( + ) "add". +realize bvaddP by admit. + +bind op W128.t W128.( * ) "mul". +realize bvmulP by admit. + +op W128_sub (a : W128.t, b: W128.t) : W128.t = + a - b. + +bind op W128.t W128_sub "sub". +realize bvsubP by admit. + +bind op W128.t W128.\udiv "udiv". +realize bvudivP by admit. + +bind op W128.t W128.\umod "urem". +realize bvuremP by admit. + +bind op W128.t W128.andw "and". +realize bvandP by admit. + +bind op W128.t W128.orw "or". +realize bvorP by admit. + +bind op W128.t W128.(+^) "xor". +realize bvxorP by admit. + +bind op W128.t W128.invw "not". +realize bvnotP by admit. + +(* TODO: Add shifts once we have truncate/extend *) + +(* ----------- BEGIN W256 BINDINGS ---------- *) + +bind bitstring W256.w2bits W256.bits2w W256.to_uint W256.to_sint W256.of_int W256.t 256. +realize size_tolist by auto. +realize tolistP by auto. +realize oflistP by admit. +realize touintP by admit. +realize tosintP by admit. +realize ofintP by admit. + +bind op W256.t W256.( + ) "add". +realize bvaddP by admit. + +bind op W256.t W256.( * ) "mul". +realize bvmulP by admit. + +op W256_sub (a : W256.t, b: W256.t) : W256.t = + a - b. + +bind op W256.t W256_sub "sub". +realize bvsubP by admit. + +bind op W256.t W256.\udiv "udiv". +realize bvudivP by admit. + +bind op W256.t W256.\umod "urem". +realize bvuremP by admit. + +bind op W256.t W256.andw "and". +realize bvandP by admit. + +bind op W256.t W256.orw "or". +realize bvorP by admit. + +bind op W256.t W256.(+^) "xor". +realize bvxorP by admit. + +bind op W256.t W256.invw "not". +realize bvnotP by admit. + +(* TODO: Add shifts once we have truncate/extend *) + +(* ----------- BEGIN SPEC FILE BINDINDS ---------- *) + + +(* +bind circuit W32.(`<<`) "LSHIFT32". + bind circuit W32.(`>>`) "RSHIFTL_32". + *) + +print VPSUB_16u16. + +(* -- AVX2 VECTORIZED -- *) +bind circuit VPSUB_16u16 "VPSUB_16u16". +bind circuit VPADD_16u16 "VPADD_16u16". +bind circuit VPBROADCAST_16u16 "VPBROADCAST_16u16". +bind circuit VPMULH_16u16 "VPMULH_16u16". +bind circuit VPMULL_16u16 "VPMULL_16u16". +bind circuit VPMULHRS_16u16 "VPMULHRS_16u16". +bind circuit VPACKUS_16u16 "VPACKUS_16u16". +bind circuit VPMADDUBSW_256 "VPMADDUBSW_256". +bind circuit VPERMD "VPERMD". + + +(* FIXME: Check new types *) +bind circuit VPSRA_16u16 "VPSRA_16u16_new". + + +bind op [bool & W16.t] W16.init "init". +realize bvinitP by admit. + +bind op [bool & W32.t] W32.init "init". +realize bvinitP by admit. + +op map_test (f: W16.t -> W16.t) (arr: W16.t Array2.t) : W16.t Array2.t = + Array2.map f arr. + +bind op [W16.t & W16.t & Array2.t] map_test "map". +realize mapP by admit. diff --git a/examples/circuit_test.ec b/examples/circuit_test.ec new file mode 100644 index 0000000000..123b8ad872 --- /dev/null +++ b/examples/circuit_test.ec @@ -0,0 +1,158 @@ +require import AllCore Bool IntDiv CoreMap List Distr QFABV. +from Jasmin require import JModel JArray. + +require import Bindings. + + +op sub16 (a b: W16.t) = a - b. + +bind op W16.t sub16 "sub". +realize bvsubP by admit. + +type word = W32.t. + +op ROR_W32(w1 w2: W32.t) = + w1 `|>>>|` (W32.to_uint w2). + +bind op W32.t ROR_W32 "ror". +realize bvrorP by admit. + +print (`|>>|`). + +op SHR_W32(w1 w2: W32.t) = + w1 `|>>|` (W8.of_int (W32.to_uint w2)). + +bind op W32.t SHR_W32 "shr". +realize bvshrP by admit. + +lemma rw_RORw (w1: W32.t) (i: int) : + w1 `|>>|` (W8.of_int i) = ROR_W32 w1 (W32.of_int i). +by admit. qed. + +lemma rw_SHLw (w1: W32.t) (i: int) : + w1 `>>` (W8.of_int i) = SHR_W32 w1 (W32.of_int i). +by admit. qed. + + +module M = { + proc and_or_test (a: W16.t) : W16.t = { + var b : W16.t; + b <- W16.andw a (W16.of_int 514); + b <- W16.orw b (W16.of_int 1028); + return b; + } + + proc vp_test (a: W256.t) : W256.t = { + a <- VPADD_16u16 a a; + return a; + } + + proc test_of_list (a: W16.t Array2.t) : W16.t Array2.t = { + a <- Array2.of_list witness [W16.of_int 2; W16.of_int 2]; + return a; + } + + proc test_bvinit (a: W16.t) : W16.t = { + a <- W16.init (fun i => a.[i] ^^ a.[i]); + return a; + } + + proc test_init (a: W16.t Array2.t) : W16.t Array2.t = { + a <- Array2.init (fun i => a.[i]); + return a; + } + + proc __sigma_0 (w:W32.t) : W32.t = { + var w1:W32.t; + var w2:W32.t; + w1 <- w; + w2 <- w; + w <- (w `|>>|` (W8.of_int 7)); + w1 <- (w1 `|>>|` (W8.of_int 18)); + w2 <- (w2 `>>` (W8.of_int 3)); + w <- (w `^` w1); + w <- (w `^` w2); + return w; + } +}. + + +op ident_W16 (w: W16.t) : W16.t = w. +op predT_W16 (w: W16.t) : bool = true. +op times2_W16 (w: W16.t) : W16.t = w + w. +op const2_W16 (w: W16.t) : W16.t = W16.of_int 2. +op const0_W16 (w: W16.t) : W16.t = W16.of_int 0. + +op predT_W32 (w: W32.t) : bool = true. + +bind op W32.t W32.(+^) "xor". +realize bvxorP by admit. + + +bind op [bool & W32.t] W32.init "init". +realize bvinitP by admit. + +bind op [W32.t & bool] W32."_.[_]" "get". +realize bvgetP by admit. + +op small_sig0 (w : word) : word = + let x = w `|>>>|` 7 in + let y = w `|>>>|` 18 in + let z = w `>>>` 3 in + x +^ y +^ z. + +lemma small_sig (w_: W32.t) : hoare [ M.__sigma_0 : w_ = w ==> res = small_sig0 w_]. +proof. +proc. +print (`|>>|`). +proc change 3 : (w `|>>>|` ((to_uint (W8.of_int 7)) %% 32)).auto. +proc change 4 : (w1 `|>>>|` ((to_uint (W8.of_int 18)) %% 32)). auto. +proc change 5 : (w2 `>>>` ((to_uint (W8.of_int 3)) %% 32)). auto. +proc rewrite 3 /=. +proc rewrite 4 /=. +proc rewrite 5 /=. +bdep 32 32 [w_] [w] [w] small_sig0 predT_W32. +admitted. + + + +lemma small_sig_orig (w_: W32.t) : hoare [ M.__sigma_0 : w_ = w ==> res = small_sig0 w_]. +proof. +proc. +bdep 32 32 [w_] [w] [w] small_sig0 predT_W32. + +op predT_W8 (w: W8.t) : bool = true. +op and2_W8 (w: W8.t) : W8.t = W8.orw (W8.andw w (W8.of_int 2)) (W8.of_int 4). + + +print W16.( [-] ). + +lemma test_add_sub (w_: W16.t) : +hoare [ M.and_or_test : (w_ = a) ==> res = w_ ]. + proof. + proc. + bdep 8 8 [w_] [a] [b] and2_W8 predT_W8. + admitted. + +lemma test_vp (w_: W256.t) : +hoare [ M.vp_test : (w_ = a) ==> res = w_ ]. + proof. + proc. + bdep 16 16 [w_] [a] [a] times2_W16 predT_W16. + admitted. + +lemma test_of_list (w_: W16.t Array2.t) : +hoare [ M.test_of_list : (w_ = a) ==> res = w_ ]. + proof. + proc. + bdep 16 16 [w_] [a] [a] const2_W16 predT_W16. + admitted. + +lemma test_bvinit (w_: W16.t) : +hoare [ M.test_bvinit : (w_ = a) ==> res = w_ ]. + proof. + proc. + bdep 16 16 [w_] [a] [a] const0_W16 predT_W16. + admitted. + + diff --git a/examples/exclude/rejection.ec b/examples/exclude/rejection.ec new file mode 100644 index 0000000000..cf163d88e8 --- /dev/null +++ b/examples/exclude/rejection.ec @@ -0,0 +1,156 @@ +(* -------------------------------------------------------------------- *) +require import AllCore List. + +(* -------------------------------------------------------------------- *) +from Jasmin require import JWord. + +(* -------------------------------------------------------------------- *) +type w8 = W8.t. +type w16 = W16.t. +type w32 = W32.t. +type w64 = W64.t. +type w128 = W128.t. +type w256 = W256.t. + +(* -------------------------------------------------------------------- *) +op VPERMQ : w256 -> w8 -> w256. +op VPSHUFB_256 : w256 -> w256 -> w256. +op VPSRL_16u16 : w256 -> w8 -> w256. +op VPBLEND_16u16 : w256 -> w256 -> w8 -> w256. +op VPBROADCAST_16u16 : w16 -> w256. +op VPAND_256 : w256 -> w256 -> w256. +op VPCMPGT_16u16 : w256 -> w256 -> w256. +op VPACKSS_16u16 : w256 -> w256 -> w256. +op VPMOVMSKB_u256u64 : w256 -> w64. +op VINSERTI128 : w256 -> w128 -> int -> w256. +op VEXTRACTI128 : w256 -> int -> w128. +op VPADD_32u8 : w256 -> w256 -> w256. +op VPUNPCKL_32u8 : w256 -> w256 -> w256. + +(* -------------------------------------------------------------------- *) +op sst : int -> W64.t. + +(* -------------------------------------------------------------------- *) +module M = { + proc gen_matrix_sample_iterate_x3_fast_filter48( + r0 : w64, + r1 : w64, + r2 : w64, + r3 : w64, + r4 : w64, + r5 : w64, + r6 : w64 + ) = { + var permq : w8; (* VPERMQ mask *) + var shfb : w256; (* VPSHUFB mask *) + var andm : w256; + var bounds : w256; + var ones : w256; + + var f0, f1, g0, g1, g : w256; + var good : w64; + + var t0_0, t0_1, t1_0, t1_1 : w64; + + var shuffle_0 : w256; + var shuffle_0_1 : w128; + + var shuffle_1 : w256; + var shuffle_1_1 : w128; + + var shuffle_t : w256; + + var counter : w64 <- W64.zero; + + permq <- W8.of_int 148; (* FIXME: hex/bin notations *) + shfb <- W32u8.pack32 (List.map W8.of_int [ + 0; 1; 1; 2; 3; 4; 4; 5; + 6; 7; 7; 8; 9; 10; 10; 11; + 4; 5; 5; 6; 7; 8; 8; 9; + 10; 11; 11; 12; 13; 14; 14; 15 + ]); + + f0 <- VPERMQ (W4u64.pack4 [r0; r1; r2; r3]) permq; + f1 <- VPERMQ (W4u64.pack4 [r3; r4; r5; r6]) permq; + + f0 <- VPSHUFB_256 f0 shfb; + f1 <- VPSHUFB_256 f1 shfb; + + g0 <- VPSRL_16u16 f0 (W8.of_int 4); + g1 <- VPSRL_16u16 f1 (W8.of_int 4); + + f0 <- VPBLEND_16u16 f0 g0 (W8.of_int 170); (* 0xaa *) + f1 <- VPBLEND_16u16 f1 g1 (W8.of_int 170); (* 0xaa *) + + andm <- VPBROADCAST_16u16 (W16.of_int 4095); (* 0x0fff *) + f0 <- VPAND_256 f0 andm; + f1 <- VPAND_256 f1 andm; + + bounds <- VPBROADCAST_16u16 (W16.of_int 3309); + g0 <- VPCMPGT_16u16 bounds f0; + g1 <- VPCMPGT_16u16 bounds f1; + + g <- VPACKSS_16u16 g0 g1; + good <- VPMOVMSKB_u256u64 g; + + t0_0 <- good; + t0_0 <- t0_0 `&` W64.of_int 255; + shuffle_0 <- W256.of_int (W64.to_sint (sst (W64.to_uint t0_0))); + t0_0 <- (POPCNT_64 t0_0).`6; + counter <- counter + t0_0; + + t0_1 <- good; + t0_1 <- t0_1 `>>>` 16; + t0_1 <- t0_1 `&` W64.of_int 255; + shuffle_0_1 <- W128.of_int (W64.to_sint (sst (W64.to_uint t0_1))); + t0_1 <- (POPCNT_64 t0_1).`6; + counter <- counter + t0_1; + t0_1 <- t0_1 + t0_0; + + t1_0 <- good; + t1_0 <- t1_0 `>>>` 8; + t1_0 <- t1_0 `&` W64.of_int 255; + shuffle_1 <- W256.of_int (W64.to_sint (sst (W64.to_uint t1_0))); + t1_0 <- (POPCNT_64 t1_0).`6; + counter <- counter + t1_0; + t1_0 <- t1_0 + t0_1; + + t1_1 <- good; + t1_1 <- t1_1 `>>>` 24; + t1_1 <- t1_1 `&` W64.of_int 255; + shuffle_1_1 <- W128.of_int (W64.to_sint (sst (W64.to_uint t1_1))); + t1_1 <- (POPCNT_64 t1_1).`6; + counter <- counter + t1_1; + t1_1 <- t1_1 + t1_0; + + shuffle_0 <- VINSERTI128 shuffle_0 shuffle_0_1 1; + shuffle_1 <- VINSERTI128 shuffle_1 shuffle_1_1 1; + + ones <- VPBROADCAST_16u16 (W16.of_int 1); + + shuffle_t <- VPADD_32u8 shuffle_0 ones; + shuffle_0 <- VPUNPCKL_32u8 shuffle_0 shuffle_t; + + shuffle_t <- VPADD_32u8 shuffle_1 ones; + shuffle_1 <- VPUNPCKL_32u8 shuffle_1 shuffle_t; + + f0 <- VPSHUFB_256 f0 shuffle_0; + f1 <- VPSHUFB_256 f1 shuffle_1; + + (* + matrix.[u128 2*(int) matrix_offset] = (128u)f0; + matrix.[u128 2*(int) t0_0] = #VEXTRACTI128(f0, 1); + matrix.[u128 2*(int) t0_1] = (128u)f1; + matrix.[u128 2*(int) t1_0] = #VEXTRACTI128(f1, 1); + matrix_offset = t1_1; + + return counter, matrix, matrix_offset; + *) + } +}. + +hoare H : M.gen_matrix_sample_iterate_x3_fast_filter48 : true ==> false. +proof. +proc. + +idassign ^t0_0<-{2} t0_0. diff --git a/examples/mapreduce_paper.ec b/examples/mapreduce_paper.ec new file mode 100644 index 0000000000..68dc6a1c9f --- /dev/null +++ b/examples/mapreduce_paper.ec @@ -0,0 +1,128 @@ +require import AllCore Bool IntDiv CoreMap List Distr QFABV. +from Jasmin require import JModel JArray. + + +bind bitstring W8.w2bits W8.bits2w W8.to_uint W8.to_sint W8.of_int W8.t 8. +realize size_tolist by auto. +realize tolistP by auto. +realize oflistP by admit. +realize ofintP by admit. +realize touintP by admit. +realize tosintP by admit. + +bind op W8.t W8.(+^) "xor". +realize bvxorP by admit. + +op bool2bits (b : bool) : bool list = [b]. +op bits2bool (b: bool list) : bool = List.nth false b 0. + +op i2b (i : int) = (i %% 2 <> 0). +op b2si (b: bool) = 0. + +bind bitstring bool2bits bits2bool b2i b2si i2b bool 1. +realize size_tolist by auto. +realize tolistP by auto. +realize oflistP by admit. +realize ofintP by admit. +realize touintP by admit. +realize tosintP by auto. + +bind op bool (^^) "add". +realize bvaddP by admit. + +op predT_bool : bool -> bool = fun _ => true. +op xor1_bool (b: bool) = b ^^ true. + +op xor_left (w1 : W8.t) = + (w1 +^ (W8.of_int 42)) +^ (W8.of_int 213). + +op xor_right (w1 : W8.t) = + w1 +^ ((W8.of_int 42)) +^ (W8.of_int 213). + +op xor_left_spec : W8.t -> W8.t. + +bind circuit xor_left_spec "XOR_LEFT8". + +op predT_W8 : W8.t -> bool = fun _ => true. + +module M = { + proc xor_left_proc (w1: W8.t) = { + w1 <- w1 +^ (W8.of_int 42); + w1 <- w1 +^ (W8.of_int 213); + return w1; + } + + proc xor_right_proc (w1: W8.t) = { + var w2 : W8.t; + w2 <- (W8.of_int 42); + w2 <- w2 +^ (W8.of_int 213); + w1 <-w1 +^ w2; + return w1; + } +}. + +lemma xor_left_corr (w: W8.t) : + hoare [ M.xor_left_proc : w = w1 ==> res = xor_left w]. +proof. +proc. +bdep 8 8 [w] [w1] [w1] xor_left predT_W8. +admit. +admit. +qed. + +lemma xor_left_equiv_xor_right_proc (w: W8.t) : + equiv [ M.xor_left_proc ~ M.xor_right_proc : w = arg{1} /\ arg{1} = arg{2} ==> res{1} = res{2} ]. +proof. +proc. +bdepeq 8 [w1] [w1] {8 : [w1 ~ w1]} predT_W8. +admit. +auto. +qed. + +lemma xor_left_equiv_xor_right_proc_lanes (w: W8.t) : + equiv [ M.xor_left_proc ~ M.xor_right_proc : w = arg{1} /\ arg{1} = arg{2} ==> res{1} = res{2} ]. +proof. +proc. +bdepeq 1 [w1] [w1] {1 : [w1 ~ w1]} predT_bool. +admit. +auto. +qed. + + +lemma xor_left_corr_lanes (w: W8.t) : + hoare [ M.xor_left_proc : w = w1 ==> res = xor_left w]. +proof. + proc. +bdep 1 1 [w] [w1] [w1] xor1_bool predT_bool. +admit. +admit. +qed. + +lemma xor_left_corr_spec (w: W8.t) : + hoare [ M.xor_left_proc : w = w1 ==> res = xor_left w]. +proof. +proc. +bdep 8 8 [w] [w1] [w1] xor_left_spec predT_W8. +admit. +admit. +qed. + +lemma xor_left_eq_xor_right (w: W8.t) : xor_left w = xor_right w. + proof. + bdep solve. + qed. + +lemma xor_left_corr_wp (w: W8.t) : + hoare [ M.xor_left_proc : w = w1 ==> res = xor_left w]. +proof. + proc. + wp; skip => &hr. by bdep solve. +qed. + +lemma xor_left_corr_wp_alt (w: W8.t) : + hoare [ M.xor_left_proc : w = w1 ==> res = xor_left w]. +proof. + proc. + wp; skip => &hr eq. + by bdep solve. +qed. diff --git a/flake.lock b/flake.lock index d66af42062..0ef29822ea 100644 --- a/flake.lock +++ b/flake.lock @@ -1,13 +1,32 @@ { "nodes": { + "emacs-overlay": { + "inputs": { + "nixpkgs": "nixpkgs", + "nixpkgs-stable": "nixpkgs-stable" + }, + "locked": { + "lastModified": 1757668180, + "narHash": "sha256-pqxwsvg8cVOY4bgEy5PUsWLVGDbgYFDnGP20bdWhjiM=", + "owner": "nix-community", + "repo": "emacs-overlay", + "rev": "b21511280c6e1ea516e551fc5e7bb27372f6c8c3", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "emacs-overlay", + "type": "github" + } + }, "flake-compat": { "flake": false, "locked": { - "lastModified": 1696426674, - "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=", + "lastModified": 1747046372, + "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", "owner": "edolstra", "repo": "flake-compat", - "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33", + "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", "type": "github" }, "original": { @@ -39,11 +58,11 @@ "systems": "systems_2" }, "locked": { - "lastModified": 1726560853, - "narHash": "sha256-X6rJYSESBVr3hBoH0WbKE5KvhPU5bloyZ2L4K60/fPQ=", + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", "owner": "numtide", "repo": "flake-utils", - "rev": "c1dfcf08411b08f6b8615f7d8971a2bfa81d5e8a", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", "type": "github" }, "original": { @@ -70,11 +89,43 @@ }, "nixpkgs": { "locked": { - "lastModified": 1730785428, - "narHash": "sha256-Zwl8YgTVJTEum+L+0zVAWvXAGbWAuXHax3KzuejaDyo=", + "lastModified": 1757487488, + "narHash": "sha256-zwE/e7CuPJUWKdvvTCB7iunV4E/+G0lKfv4kk/5Izdg=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "ab0f3607a6c7486ea22229b92ed2d355f1482ee0", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs-stable": { + "locked": { + "lastModified": 1751274312, + "narHash": "sha256-/bVBlRpECLVzjV19t5KMdMFWSwKLtb5RyXdjz3LJT+g=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "50ab793786d9de88ee30ec4e4c24fb4236fc2674", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-24.11", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_2": { + "locked": { + "lastModified": 1751792365, + "narHash": "sha256-J1kI6oAj25IG4EdVlg2hQz8NZTBNYvIS0l4wpr9KcUo=", "owner": "nixos", "repo": "nixpkgs", - "rev": "4aa36568d413aca0ea84a1684d2d46f55dbabad7", + "rev": "1fd8bada0b6117e6c7eb54aad5813023eed37ccb", "type": "github" }, "original": { @@ -89,17 +140,17 @@ "flake-compat": "flake-compat", "flake-utils": "flake-utils_2", "mirage-opam-overlays": "mirage-opam-overlays", - "nixpkgs": "nixpkgs", + "nixpkgs": "nixpkgs_2", "opam-overlays": "opam-overlays", "opam-repository": "opam-repository", "opam2json": "opam2json" }, "locked": { - "lastModified": 1736955560, - "narHash": "sha256-9I42xwKXH7h+jQGJQ8t797j/mWylIItIljRLm44CHS8=", + "lastModified": 1756988401, + "narHash": "sha256-S+zc1RYWZBGKnbrEWbyJ6fGt8ft/9d4BzpigSN2PpqE=", "owner": "tweag", "repo": "opam-nix", - "rev": "5f760f445d6693eb086327fa7d7ae8e43c906718", + "rev": "0c9c0e0c058dfb8de56adff612f2c776530f7f1e", "type": "github" }, "original": { @@ -111,11 +162,11 @@ "opam-overlays": { "flake": false, "locked": { - "lastModified": 1726822209, - "narHash": "sha256-bwM18ydNT9fYq91xfn4gmS21q322NYrKwfq0ldG9GYw=", + "lastModified": 1741116009, + "narHash": "sha256-Z0PIW82fHJFvAv/JYpAffnp2DaOjLhsPutqyIrORZd0=", "owner": "dune-universe", "repo": "opam-overlays", - "rev": "f2bec38beca4aea9e481f2fd3ee319c519124649", + "rev": "e031bb64e33bf93be963e9a38b28962e6e14381f", "type": "github" }, "original": { @@ -127,11 +178,11 @@ "opam-repository": { "flake": false, "locked": { - "lastModified": 1736935757, - "narHash": "sha256-LNkGSkZJXJmxpUd+luDUIIV/1B5MZIBMTB1qZqypa4o=", + "lastModified": 1756946712, + "narHash": "sha256-jo24cfjG/Yf1yPppKtL5ogjw6YBCMaMNsfkktRUm018=", "owner": "ocaml", "repo": "opam-repository", - "rev": "a8b00ead922e2049581ab16994586ed4ddbdb784", + "rev": "e28312d8e0d10f256ec9998ff7e868cb6e010778", "type": "github" }, "original": { @@ -145,14 +196,15 @@ "nixpkgs": [ "opam-nix", "nixpkgs" - ] + ], + "systems": "systems_3" }, "locked": { - "lastModified": 1671540003, - "narHash": "sha256-5pXfbUfpVABtKbii6aaI2EdAZTjHJ2QntEf0QD2O5AM=", + "lastModified": 1749457947, + "narHash": "sha256-+QVm+HOYikF3wUhqSIV8qJbE/feSG+p48fgxIosbHS0=", "owner": "tweag", "repo": "opam2json", - "rev": "819d291ea95e271b0e6027679de6abb4d4f7f680", + "rev": "0ecd66fc2bfb25d910522c990dd36412259eac1f", "type": "github" }, "original": { @@ -178,42 +230,43 @@ "type": "github" } }, - "prover_cvc5_1_0_9": { + "prover_cvc5_1_3_0": { "flake": false, "locked": { - "lastModified": 1702998934, - "narHash": "sha256-AwUQHFftn51Xt6HtmDsWAdkOS8i64r2FhaHu31KYwZA=", + "lastModified": 1750292852, + "narHash": "sha256-w8rIGPG9BTEPV9HG2U40A4DYYnC6HaWbzqDKCRhaT00=", "owner": "cvc5", "repo": "cvc5", - "rev": "8fca72aebcb5293434c3207dca081a845ff8d6fe", + "rev": "02c4e43d191f86b67a8a6d615544630a8df0f18e", "type": "github" }, "original": { "owner": "cvc5", - "ref": "cvc5-1.0.9", + "ref": "cvc5-1.3.0", "repo": "cvc5", "type": "github" } }, - "prover_z3_4_12_6": { + "prover_z3_4_14_1": { "flake": false, "locked": { - "lastModified": 1708814107, - "narHash": "sha256-X4wfPWVSswENV0zXJp/5u9SQwGJWocLKJ/CNv57Bt+E=", + "lastModified": 1741647008, + "narHash": "sha256-pTsDzf6Frk4mYAgF81wlR5Kb1x56joFggO5Fa3G2s70=", "owner": "z3prover", "repo": "z3", - "rev": "fa2c0e027894a8d55d2b841e27cbeecc99692a3f", + "rev": "3c0d786e6e86b6a10cbc14703c3f863c568b85dd", "type": "github" }, "original": { "owner": "z3prover", - "ref": "z3-4.12.6", + "ref": "z3-4.14.1", "repo": "z3", "type": "github" } }, "root": { "inputs": { + "emacs-overlay": "emacs-overlay", "flake-utils": "flake-utils", "nixpkgs": [ "opam-nix", @@ -221,23 +274,23 @@ ], "opam-nix": "opam-nix", "prover_cvc4_1_8": "prover_cvc4_1_8", - "prover_cvc5_1_0_9": "prover_cvc5_1_0_9", - "prover_z3_4_12_6": "prover_z3_4_12_6", + "prover_cvc5_1_3_0": "prover_cvc5_1_3_0", + "prover_z3_4_14_1": "prover_z3_4_14_1", "stable": "stable" } }, "stable": { "locked": { - "lastModified": 1717179513, - "narHash": "sha256-vboIEwIQojofItm2xGCdZCzW96U85l9nDW3ifMuAIdM=", + "lastModified": 1751290243, + "narHash": "sha256-kNf+obkpJZWar7HZymXZbW+Rlk3HTEIMlpc6FCNz0Ds=", "owner": "nixos", "repo": "nixpkgs", - "rev": "63dacb46bf939521bdc93981b4cbb7ecb58427a0", + "rev": "5ab036a8d97cb9476fbe81b09076e6e91d15e1b6", "type": "github" }, "original": { "owner": "nixos", - "ref": "24.05", + "ref": "release-24.11", "repo": "nixpkgs", "type": "github" } @@ -271,6 +324,21 @@ "repo": "default", "type": "github" } + }, + "systems_3": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } } }, "root": "root", diff --git a/flake.nix b/flake.nix index 77d38a85ff..085ea5555c 100644 --- a/flake.nix +++ b/flake.nix @@ -4,22 +4,23 @@ flake-utils.url = "github:numtide/flake-utils"; - nixpkgs.url = "github:nixos/nixpkgs/24.05"; - stable.url = "github:nixos/nixpkgs/24.05"; + # nixpkgs.url = "github:nixos/nixpkgs/release-24.11"; + stable.url = "github:nixos/nixpkgs/release-24.11"; nixpkgs.follows = "opam-nix/nixpkgs"; + emacs-overlay.url = "github:nix-community/emacs-overlay"; prover_cvc4_1_8 = { url = "github:CVC4/CVC4-archived/1.8"; flake = false; }; - prover_cvc5_1_0_9 = { - url = "github:cvc5/cvc5/cvc5-1.0.9"; + prover_cvc5_1_3_0 = { + url = "github:cvc5/cvc5/cvc5-1.3.0"; flake = false; }; - prover_z3_4_12_6 = { - url = "github:z3prover/z3/z3-4.12.6"; + prover_z3_4_14_1 = { + url = "github:z3prover/z3/z3-4.14.1"; flake = false; }; }; @@ -40,7 +41,7 @@ }; query = devPackagesQuery // { - ocaml-base-compiler = "4.14.2"; + ocaml-base-compiler = "4.14.1"; }; scope = on.buildOpamProject' { } ./. query; @@ -54,9 +55,23 @@ ''; doNixSupport = false; }); - conf-pkg-config = prev.conf-pkg-config.overrideAttrs (oa: { - nativeBuildInputs = oa.nativeBuildInputs ++ [pkgs.pkg-config]; + conf-zlib = prev.conf-zlib.overrideAttrs (finalAttrs: prevAttrs: rec { + nativeBuildInputs = prevAttrs.nativeBuildInputs + ++ (with pkgs; [ pkg-config ]); }); + conf-git = prev.conf-git.overrideAttrs (finalAttrs: prevAttrs: rec { + nativeBuildInputs = prevAttrs.nativeBuildInputs + ++ (with pkgs; [ git ]); + buildInputs = prevAttrs.buildInputs + ++ (with pkgs; [ git ]); + }); + alt-ergo = prev.alt-ergo.overrideAttrs (finalAttrs: prevAttrs: rec { + nativeBuildInputs = prevAttrs.nativeBuildInputs + ++ (with pkgs; [ darwin.sigtool ]); + }); + frama-c = prev.frama-c.overrideAttrs (finalAttrs: prevAttrs: rec { + configureFlags = (prevAttrs.configureFlags or []) ++ ["--prefix=${prev.frama-c}/lib"]; + }); }; scope' = scope.overrideScope overlay; @@ -78,20 +93,48 @@ src = inputs."${"prover_" + pkg + "_" + builtins.replaceStrings ["."] ["_"] version}"; }); - mkAltErgo = version: - ((on.queryToScope { } (query // { alt-ergo = version; })).overrideScope overlay).alt-ergo; + mkAltErgo = version: (on.queryToScope { } (query // { alt-ergo = version; })).alt-ergo; + + devTools = + (let + overlays = [ (import inputs.emacs-overlay) ]; + pkgs = import nixpkgs { + inherit system overlays; + }; + in + (with pkgs; [ + (emacsWithPackagesFromUsePackage { + config = ''(setq easycrypt-prog-name "ec.native")''; + defaultInitFile = true; + alwaysEnsure = true; + package = pkgs.emacs; + extraEmacsPackages = epkgs: [ epkgs.proof-general ]; + }) + bashInteractive + git + difftastic + ]) + ++ + (with pkgs; + lib.optionals (!stdenv.isDarwin) [ perf-tools ]) + ); in rec { legacyPackages = scope'; packages = rec { - z3 = mkProverPackage "z3" "4.12.6"; + z3 = mkProverPackage "z3" "4.14.1"; cvc4 = mkProverPackage "cvc4" "1.8"; - cvc5 = mkProverPackage "cvc5" "1.0.9"; - altErgo = mkAltErgo "2.4.3"; + cvc5 = mkProverPackage "cvc5" "1.3.0"; + altErgo = mkAltErgo "2.4.2"; provers = pkgs.symlinkJoin { name = "provers"; - paths = [ altErgo z3 cvc4 cvc5 ]; + paths = [ + # altErgo + z3 + # cvc4 + cvc5 + ]; }; with_provers = pkgs.symlinkJoin { @@ -102,12 +145,40 @@ default = main; }; - devShells.default = pkgs.mkShell { + devShells.barebones = pkgs.mkShell { inputsFrom = [ scope'.easycrypt ]; buildInputs = - devPackages - ++ [ pkgs.git scope'.why3 packages.provers ] - ++ (with pkgs.python3Packages; [ pyyaml ]); + devPackages + ++ [ scope'.why3 ] + ++ (with pkgs.python3Packages; [ pyyaml ]); }; + + devShells.noProvers = pkgs.mkShell rec { + inputsFrom = [ scope'.easycrypt ]; + buildInputs = + devPackages + ++ devTools + ++ [ scope'.why3 ] + ++ (with pkgs.python3Packages; [ pyyaml ]); + SHELL = ''${pkgs.bashInteractive + "/bin/bash"}''; + shellHook = builtins.replaceStrings ["\n"] [" "] '' + export SHELL=${SHELL} && + export PATH=$PATH:`realpath .` + ''; + }; + + devShells.withDevTools = pkgs.mkShell rec { + inputsFrom = [ scope'.easycrypt ]; + buildInputs = + devPackages + ++ devTools + ++ [ scope'.why3 packages.provers ] + ++ (with pkgs.python3Packages; [ pyyaml ]); + SHELL = ''${pkgs.bashInteractive + "/bin/bash"}''; + shellHook = builtins.replaceStrings ["\n"] [" "] '' + export SHELL=${SHELL} && + export PATH=$PATH:`realpath .` + ''; + }; }); } diff --git a/libs/lospecs/aig.ml b/libs/lospecs/aig.ml new file mode 100644 index 0000000000..89a398cd85 --- /dev/null +++ b/libs/lospecs/aig.ml @@ -0,0 +1,685 @@ +(* -------------------------------------------------------------------- *) +type name = int +[@@deriving yojson] + +(* -------------------------------------------------------------------- *) +type var = name * int +[@@deriving yojson] + +(* -------------------------------------------------------------------- *) +type node_r = + | False + | Input of var + | And of node * node +[@@deriving yojson] + +and node = { + gate : node_r; + id : int; + neg : node; +} +[@@deriving yojson] + +(* -------------------------------------------------------------------- *) +let fresh = + let counter = ref 0 in + fun () -> incr counter; !counter + +(* -------------------------------------------------------------------- *) +type reg = node array +[@@deriving yojson] + +(* -------------------------------------------------------------------- *) +module HCons : sig + val hashcons : node_r -> node + + val clear : unit -> unit +end = struct + module H = Weak.Make(struct + type t = node + + let hash (x : t) : int = + match x.gate with + | False -> + Hashtbl.hash False + | Input v -> + Hashtbl.hash v + | And (n1, n2) -> + Hashtbl.hash (abs n1.id, abs n2.id) + + let equal (n1 : node) (n2 : node) = + match n1.gate, n2.gate with + | False, False -> + true + | Input v1, Input v2 -> + v1 = v2 + | And (n1, m1), And (n2, m2) -> + n1 == n2 && m1 == m2 + | _, _ -> + false + end) + + let tag = ref 1 + + let htable = H.create 5003 + + let clear = fun () -> H.clear htable + + let hashcons (n : node_r) = + let rec pos = { gate = n; id = !tag; neg = neg; } + and neg = { gate = n; id = - !tag; neg = pos; } in + + let o = H.merge htable pos in + + if o == pos then incr tag; o +end + +(* -------------------------------------------------------------------- *) +let rec pp_node (fmt : Format.formatter) (n : node) = + match n with + | { gate = False; id } when 0 < id -> + Format.fprintf fmt "⊥" + + | { gate = False; } -> + Format.fprintf fmt "⊤" + + | { gate = Input (n, i); id; } -> + Format.fprintf fmt "%s%d#%0.4x" + (if 0 < id then "" else "¬") n i + + | { gate = And (n1, n2); id; } when 0 < id -> + Format.fprintf fmt "(%a) ∧ (%a)" pp_node n1 pp_node n2 + + | { gate = And (n1, n2); } -> + Format.fprintf fmt "¬((%a) ∧ (%a))" pp_node n1 pp_node n2 + +(* -------------------------------------------------------------------- *) +let mk (n : node_r) : node = + HCons.hashcons n + +(* -------------------------------------------------------------------- *) +let false_ : node = + mk False + +(* -------------------------------------------------------------------- *) +let true_ : node = + false_.neg + +(* -------------------------------------------------------------------- *) +let input (i : var) : node = + mk (Input i) + +(* -------------------------------------------------------------------- *) +let constant (b : bool) : node = + if b then true_ else false_ + +(* -------------------------------------------------------------------- *) +let neg (n : node) : node = + n.neg + +(* -------------------------------------------------------------------- *) +let and_ (n1 : node) (n2 : node) : node = + match () with + | _ when n1 == n2 -> n1 + | _ when n1 == n2.neg -> false_ + | _ when n1 == false_ -> false_ + | _ when n2 == false_ -> false_ + | _ when n1 == true_ -> n2 + | _ when n2 == true_ -> n1 + | _ -> mk (And (n1, n2)) + +(* -------------------------------------------------------------------- *) +let nand (n1 : node) (n2 : node) : node = + neg (and_ n1 n2) + +(* -------------------------------------------------------------------- *) +let or_ (n1 : node) (n2 : node) : node = + nand (neg n1) (neg n2) + +(* -------------------------------------------------------------------- *) +let xor (n1 : node) (n2 : node) : node = + let n = nand n1 n2 in nand (nand n1 n) (nand n2 n) + +(* -------------------------------------------------------------------- *) +let xnor (n1 : node) (n2 : node) : node = + neg (xor n1 n2) + +(* -------------------------------------------------------------------- *) +let get_bit (b : bytes) (i : int) = + Char.code (Bytes.get b (i / 8)) lsr (i mod 8) land 0b1 <> 0 + +(* -------------------------------------------------------------------- *) +let env_of_regs (rs : bytes list) = + let rs = Array.of_list rs in + fun ((n, i) : var) -> get_bit rs.(n) i + +(* ==================================================================== *) +let map (env : var -> node option) : node -> node = + let cache : (int, node) Hashtbl.t = Hashtbl.create 0 in + + let rec doit (n : node) : node = + let mn = + match Hashtbl.find_option cache (abs n.id) with + | None -> + let mn = doit_r n.gate in + Hashtbl.add cache (abs n.id) mn; + mn + | Some mn -> + mn + in + if 0 < n.id then mn else neg mn + + and doit_r (n : node_r) = + match n with + | False -> + false_ + | Input v -> + Option.default (input v) (env v) + | And (n1, n2) -> + and_ (doit n1) (doit n2) + + in fun (n : node) -> doit n + +(* -------------------------------------------------------------------- *) +let maps (env : var -> node option) : reg -> reg = + fun r -> Array.map (map env) r + +(* ==================================================================== *) +let equivs (inputs : (var * var) list) (c1 : reg) (c2 : reg) : bool = + let inputs = Map.of_seq (List.to_seq inputs) in + let env (v : var) = Option.map input (Map.find_opt v inputs) in + Array.for_all2 (==) (maps env c1) c2 + +(* ==================================================================== *) +let eval (env : var -> bool) = + let cache : (int, bool) Hashtbl.t = Hashtbl.create 0 in + + let rec for_node (n : node) = + let value = + match Hashtbl.find_option cache (abs n.id) with + | None -> + let value = for_node_r n.gate in + Hashtbl.add cache (abs n.id) value; + value + | Some value -> + value + + in if 0 < n.id then value else not value + + and for_node_r (n : node_r) = + match n with + | False -> false + | Input x -> env x + | And (n1, n2) -> for_node n1 && for_node n2 + + in fun (n : node) -> for_node n + +(* -------------------------------------------------------------------- *) +let evals (env : var -> bool) = + List.map (eval env) + +(* -------------------------------------------------------------------- *) +let eval0 (n : node) = + eval (fun (_ : var) -> false) n + +(* ==================================================================== *) +module VarRange : sig + type 'a t + + val empty : 'a t + + val push : 'a t -> ('a * int) -> 'a t + + val contents : 'a t -> ('a * (int * int) list) list + + val pp : + (Format.formatter -> 'a -> unit) + -> Format.formatter + -> 'a t + -> unit +end = struct + type range = int * int + + type ranges = range list + + type 'a dep1 = 'a * ranges + + type 'a t = ('a, ranges) Map.t + + let empty : 'a t = + Map.empty + + let rec add (rg : ranges) (v : int) = + match rg with + | [] -> + [(v, v)] + + (* join two segments *) + | (lo, hi) :: (lo', hi') :: tl when hi+1 = v && v+1 = lo' -> + (lo, hi') :: tl + + (* add to the front of a segment *) + | (lo, hi) :: tl when v+1 = lo -> + (v, hi) :: tl + + (* add to the back of a segment *) + | (lo, hi) :: tl when hi+1 = v -> + (lo, v) :: tl + + | hd :: tl -> + hd :: add tl v + + let push (r : 'a t) ((n, i) : 'a * int) : 'a t = + let change (rg : ranges option) = + Some (add (Option.default [] rg) i) + in Map.modify_opt n change r + + let contents (r : 'a t) : ('a * ranges) list = + Map.bindings r + + let pp + (pp : Format.formatter -> 'a -> unit) + (fmt : Format.formatter) + (r : 'a t) + = + let pp_range (fmt : Format.formatter) ((lo, hi) : range) = + if lo = hi then + Format.fprintf fmt "%d" lo + else + Format.fprintf fmt "%d-%d" lo hi in + + let pp_ranges (fmt : Format.formatter) (rgs : ranges) = + Format.fprintf fmt "%a" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ",") + pp_range) + rgs in + + let pp_dep1 (fmt : Format.formatter) ((v, rgs) : 'a dep1) = + Format.fprintf fmt "%a#%a" pp v pp_ranges rgs in + + Format.fprintf fmt "%a" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "; ") + pp_dep1) + (Map.bindings r) +end + +(* ==================================================================== *) +let deps_ () = + let cache : (int, var Set.t) Hashtbl.t = Hashtbl.create 0 in + + let rec doit_force (n : node) = + match n.gate with + | False -> Set.empty + | Input v -> Set.singleton v + | And (n1, n2) -> Set.union (doit n1) (doit n2) + + and doit (n : node) = + match Hashtbl.find_option cache (abs n.id) with + | Some value -> + value + | None -> + let value = doit_force n in + Hashtbl.add cache (abs n.id) value; value + + in fun (n : node) -> doit n + +(* -------------------------------------------------------------------- *) +let deps (r : reg) = + let out = ref [] in + + let push (hi : int) (dhi : var Set.t) = + match !out with + | _ when Set.is_empty dhi -> + () + | ((lo, v), dlo) :: tl when v+1 = hi && not (Set.disjoint dlo dhi) -> + out := ((lo, hi), Set.union dlo dhi) :: tl + | _ -> + out := ((hi, hi), dhi) :: !out in + + Array.iteri push (Array.map (deps_ ()) r); + !out + |> List.rev_map (fun (r, vs) -> + let vs = + Set.fold + (fun v vs -> VarRange.push vs v) + vs VarRange.empty + in (r, vs) + ) + |> List.sort (fun (r1, _) (r2, _) -> compare r1 r2) + +exception AigerError of string + +(* -------------------------------------------------------------------- *) +(* SERIALIZATION *) +(* Return map of indice renaming + list of and gates (increasing order) + (max variable index, and gate count, input gate count) *) +let aiger_preprocess ~(input_count: int) (r: reg) : (node -> int) * (node list) * (int * int * int) = + let cache : (int, int) Hashtbl.t = Hashtbl.create 0 in + let count_and = ref 0 in + let and_gates = ref [] in + + let rec doit (n: node) : unit = + match Hashtbl.find_option cache (abs n.id) with + | Some v -> () + | None -> + let value = doit_force n in + Hashtbl.add cache (abs n.id) value + + and doit_force (n: node) = + match n.gate with + | False -> 0 + | Input (v, i) -> 64*v + i + | And (n1, n2) -> + doit n1; doit n2; + incr count_and; + and_gates := n::(!and_gates); + !count_and + in + + Array.iter doit r; + let and_cnt = !count_and in + let inp_cnt = input_count in + let id_map = + Hashtbl.to_seq cache |> Map.of_seq + in + let id_map = (function + | { gate = False; id } -> (if 0 < id then 0 else 1) + | { gate = And _; id } -> ((Map.find (abs id) id_map) + inp_cnt) lsl 1 + (if 0 < id then 0 else 1) + | { gate = Input _; id } -> (Map.find (abs id) id_map) lsl 1 + (if 0 < id then 0 else 1) + ) in + id_map, + List.sort (fun n1 n2 -> compare (id_map n1) (id_map n2)) !and_gates, + (and_cnt + inp_cnt, and_cnt, inp_cnt) + +let aiger_serialize_int (id: int) : string = + if not (id > 0) then raise (AigerError "serialize_int"); + let mask = 0x7f in + let rec doit (id: int) : int list = + if id < 0x80 then + [id] + else + ((id land mask) lor (0x80))::(doit (id lsr 7)) + in + + List.fold_left (fun acc id -> (Format.sprintf "%c" (char_of_int id)) ^ acc) "" (List.rev (doit id)) + +let pp_aiger_int fmt (id: int) : unit = + Format.fprintf fmt "%s" (aiger_serialize_int id) + +(* FIXME PR: Look at correction of this and after making sure it is correct *) +(* we can remove or do something else with the asserts *) +(* but they should not be triggered on a normal execution *) +let pp_aiger_and fmt ((gid, id1, id2): int * int * int) : unit = + if not (gid > id1 && id1 > id2) then Format.eprintf "gid : %d | id1: %d | id2: %d@." gid id1 id2; + assert (gid > id1 && id1 > id2); + let delta0 = gid - id1 in + let delta1 = id1 - id2 in + assert(delta0 > 0 && delta1 > 0); + assert(id1 = gid - delta0); + assert(gid - delta0 - delta1 = id2); + Format.fprintf fmt "%a%a" pp_aiger_int (gid - id1) pp_aiger_int (id1 - id2) + +(* + mvi -> Max Variable Index + agc -> And Gate Count + igc -> Input Gate Count + lgc -> Latch Gate Count + ogc -> Output Gate Count +*) +let write_aiger_bin + ~(input_count: int) + ?(inp_name_map : int -> string = fun (i: int) -> "inp" ^ (string_of_int i)) + oc + (r: reg) = + let aiger_id_of_node, and_gates, (mvi, agc, igc) = aiger_preprocess ~input_count r in + + let ogc = Array.length r in + let lgc = 0 in + Printf.fprintf oc "aig %d %d %d %d %d\n" mvi igc lgc ogc agc; + Array.iter (fun n -> Printf.fprintf oc "%d\n" (aiger_id_of_node n)) r; + List.iter (function + | { gate = And (n1, n2); } as n -> + let id = aiger_id_of_node n in + let id1 = aiger_id_of_node n1 in + let id2 = aiger_id_of_node n2 in + let id = id - (id land 1) in + let id1, id2 = if id1 > id2 then id1, id2 else id2, id1 in + Printf.fprintf oc "%s" (Format.asprintf "%a" pp_aiger_and (id, id1, id2)) + | _ -> assert false (* Should not be triggered *) + ) and_gates; + for i = 0 to igc-1 do + Printf.fprintf oc "i%d %s@\n" i (inp_name_map i) + done + +let write_aiger_bin_temp + ~(input_count: int) + ?(inp_name_map: (int -> string) option) + ?(name: string = "circuit") + (r: reg) = + let tf_name, tf_oc = Filename.open_temp_file ~mode:[Open_binary] name ".aig" in + let tf_oc = BatIO.output_channel ~cleanup:true tf_oc in + write_aiger_bin ~input_count ?inp_name_map tf_oc r; + tf_name + +(* Assumes inputs are already matched *) +let abc_check_equiv + ?(r1_name = "r1") + ?(r2_name = "r2") + ~(input_count: int) + ?(inp_name_map: (int -> string) option) + (r1: reg) (r2: reg) : bool = + + let tf1_name, tf1_oc = Filename.open_temp_file ~mode:[Open_binary] r1_name ".aig" in + let tf2_name, tf2_oc = Filename.open_temp_file ~mode:[Open_binary] r2_name ".aig" in + Format.eprintf "Created temp files (%s) (%s)!@." tf1_name tf2_name; + let tf1_oc = BatIO.output_channel ~cleanup:true tf1_oc in + let tf2_oc = BatIO.output_channel ~cleanup:true tf2_oc in + write_aiger_bin ~input_count ?inp_name_map tf1_oc r1; + write_aiger_bin ~input_count ?inp_name_map tf2_oc r2; + Format.eprintf "Wrote aig files!@."; + BatIO.close_out tf1_oc; BatIO.close_out tf2_oc; + let abc_command = Format.sprintf "cec %s %s" tf1_name tf2_name in + Format.eprintf "Abc command: %s@." abc_command; + let abc_output_c, abc_in = Unix.open_process "abc" in +(* let abc_in = BatIO.output_channel ~cleanup:true abc_in in *) + BatIO.write_string abc_in (abc_command ^ "\n"); + BatIO.close_out abc_in; +(* let abc_output_c = BatIO.input_channel ~autoclose:true ~cleanup:true abc_output_c in *) + (* FIXME: Get the actual output in all cases from abc *) + let re = Str.regexp {|.*Networks are equivalent.*|} in + Format.eprintf "Before read@."; + let abc_output = BatIO.read_all abc_output_c in + Format.eprintf "====== BEGIN ABC OUTPUT =====@.%s@.======= END ABC OUTPUT =====@." abc_output; + let abc_output = String.replace_chars (function | '\n' -> "|" | c -> String.of_char c) abc_output in + if Str.string_match re abc_output 0 then true else false + +(* -------------------------------------------------------------------- *) +exception InvalidWire + +(* -------------------------------------------------------------------- *) +(* true -> positive wire *) +let u2si (u : int) : bool * int = + if u < 0 then raise InvalidWire; + let s = (u land 0b1) = 0 in + let i = u lsr 1 in (* We divide by 2 *) + (s, i) + +(* -------------------------------------------------------------------- *) +let si2u ((b, i) : bool * int) : int = + assert (0 <= i); + (i lsl 1) lor (match b with true -> 0 | false -> 1) + +(* -------------------------------------------------------------------- *) +exception InvalidAIG of string + +(* -------------------------------------------------------------------- *) +(* Load an aig file *) +let load (inp : IO.input) : reg * (Set.String.t * string array) option = + let parse_asuint = + let re = Str.regexp "^[0-9]+$" in + + let doit (x : string) = + if not (Str.string_match re x 0) then + raise (InvalidAIG ("not a valid uint: " ^ x)); + int_of_string x (* FIXME: overflow *) + in fun x -> doit x in + + let header = String.trim (IO.read_line inp) in + let header = Str.split (Str.regexp "[ \t]+") header in + let header = Array.of_list header in + + if Array.length header <> 6 || header.(0) <> "aig" then + raise (InvalidAIG "invalid header"); + + let c_m = parse_asuint header.(1) in (* maximum variable index *) + let c_i = parse_asuint header.(2) in (* number of inputs *) + let c_l = parse_asuint header.(3) in (* number of latches *) + let c_o = parse_asuint header.(4) in (* number of outputs *) + let c_a = parse_asuint header.(5) in (* number of AND gates *) + + (* We have c_l = 0 so /\ c_m = c_i + c_l + c_a + * + * Hence: c_m = c_i + c_a + *) + + if c_m <> c_i + c_l + c_a || c_l <> 0 then + raise (InvalidAIG "invalid header (sum)"); + + let outputs = ref [] in + + (* Reading outputs *) + for _ = 1 to c_o do + let output = String.trim (IO.read_line inp) in + let (_, u) as output = u2si (parse_asuint output) in + + if not (0 <= u && u <= c_m) then + raise (InvalidAIG "invalid output"); + + outputs := output :: !outputs + done; + + let outputs = Array.of_list (List.rev !outputs) in + + (* Reading arguments of AND gate *) + let read_uint () = + let exception Done in + + let i, o = ref 0, ref 0 in + try + while true do + assert (!o < 4); + let d = IO.read_byte inp in + i := !i lor ((d land 0x7f) lsl (7 * !o)); + o := !o + 1; + if (d land 0x80) = 0 then + raise Done + done; + assert false + with Done -> !i + in + + + let gates = List.fold_left (fun map -> function + | 0 -> + Map.add 0 false_ map + + | i when 0 < i && i <= c_i -> + Map.add i (input (0, i-1)) map + + | i when c_i < i && i <= c_i + c_a -> + let delta0 = read_uint () in + let delta1 = read_uint () in + + if delta0 = 0 then + raise (InvalidAIG "invalid delta0"); + + (* delta0 = lhs - rhs0, delta1 = rhs0 - rhs1 *) + + let lhs = 2 * i in + let rhs0_ = lhs - delta0 in + let rhs1_ = rhs0_ - delta1 in + + if lhs = c_i*2 + 2 then + Format.eprintf "Lhs: %d | Rhs0: %d | Rhs1: %d@." lhs rhs0_ rhs1_; + + let (b1, u1) = try + u2si rhs0_ + with InvalidWire -> + Format.eprintf "Invalid wire for rhs0 for params: lhs: %d | rhs0: %d | rhs1: %d@." lhs rhs0_ rhs1_; assert false + in + let (b2, u2) = try + u2si rhs1_ + with InvalidWire -> + Format.eprintf "Invalid wire for rhs1 for params: lhs: %d | rhs0: %d | rhs1: %d@." lhs rhs0_ rhs1_; assert false + in + + let n1 = Map.find u1 map in + let n1 = if b1 then n1 else n1.neg in + let n2 = Map.find u2 map in + let n2 = if b2 then n2 else n2.neg in + + if not (u1 <= c_m && u2 <= c_m) then + raise (InvalidAIG "invalid delta1"); + + Map.add i (and_ n1 n2) map + + | _ -> + assert false + ) Map.empty (List.init (c_i + c_a + 1) (fun i -> i)) in + + (* Reading annotations *) + let ainputs = Array.make c_i None in + + begin try + while true do + let exception Continue in + + try + let line = String.trim (IO.read_line inp) in + + if line = "" then + raise Continue; + if line = "c" then + raise IO.No_more_input; + + if not ( + Str.string_match + (Str.regexp "^i\\([0-9]+\\)[ \t]+\\(.*\\)$") + line 0 + ) then raise (InvalidAIG ("invalid annotation: " ^ line)); + + let s = Str.matched_group 2 line in + let i = parse_asuint (Str.matched_group 1 line) in + + if not (i < c_i) then + raise (InvalidAIG "invalid annotation (index)"); + + if Option.is_some ainputs.(i) then + raise (InvalidAIG "invalid annotation (dup. index)"); + + ainputs.(i) <- Some s + + with Continue -> () + done + + with IO.No_more_input -> () end; + + let ainputs = + if Array.for_all Option.is_none ainputs then + None + else if Array.exists Option.is_none ainputs then + raise (InvalidAIG "invalid annotation (partial)") + else + let ainputs = Array.map Option.get ainputs in + let keys = Set.String.of_array ainputs in + + if Set.String.cardinal keys <> Array.length ainputs then + raise (InvalidAIG "invalid annotation (dup)"); + Some (keys, ainputs) + in + + (* Construct network *) + Array.map (fun (b, i) -> + if b then (Map.find i gates).neg else Map.find i gates + ) outputs, ainputs diff --git a/libs/lospecs/ast.ml b/libs/lospecs/ast.ml new file mode 100644 index 0000000000..7df7bd130e --- /dev/null +++ b/libs/lospecs/ast.ml @@ -0,0 +1,104 @@ +(* -------------------------------------------------------------------- *) +type symbol = Ptree.symbol [@@deriving yojson] + +(* FIXME PR: Maybe get a decl file to declare errors and other common things? *) +exception DestrError of string + +(* -------------------------------------------------------------------- *) +module Ident : sig + type ident [@@deriving yojson] + + val create : string -> ident + val name : ident -> string + val id : ident -> int +end = struct + type ident = symbol * int [@@deriving yojson] + + let create (x : string) : ident = (x, Oo.id (object end)) + let name ((x, _) : ident) : string = x + let id ((_, i) : ident) : int = i +end + +module IdentMap = Map.Make(struct + type t = Ident.ident + let compare a b = (Ident.id a) - (Ident.id b) +end) + +(* -------------------------------------------------------------------- *) +type ident = Ident.ident [@@deriving yojson] + +(* -------------------------------------------------------------------- *) +type aword = [ `W of int ] [@@deriving yojson] + +(* -------------------------------------------------------------------- *) +type atype = [ aword | `Signed | `Unsigned ] [@@deriving yojson] + +(* -------------------------------------------------------------------- *) +type aarg = ident * aword [@@deriving yojson] + +(* -------------------------------------------------------------------- *) +type aargs = aarg list [@@deriving yojson] + +(* -------------------------------------------------------------------- *) +type lr = [`L | `R] [@@deriving yojson] +type la = [`L | `A] [@@deriving yojson] +type us = [`U | `S] [@@deriving yojson] +type hl = [`H | `L] [@@deriving yojson] +type hld = [hl | `D] [@@deriving yojson] +type mulk = [`U of hld | `S of hld | `US] [@@deriving yojson] + +(* -------------------------------------------------------------------- *) +type aexpr_ = + | EVar of ident + | EInt of int + | ESlice of aexpr * (aexpr * int * int) + | EAssign of aexpr * (aexpr * int * int) * aexpr + | EApp of ident * aexpr list + | EMap of (aword * aword) * (aargs * aexpr) * aexpr list + | EConcat of aword * aexpr list + | ERepeat of aword * (aexpr * int) + | EShift of lr * la * (aexpr * aexpr) + | EExtend of us * aword * aexpr + | ESat of us * aword * aexpr + | ELet of (ident * aargs option * aexpr) * aexpr + | ECond of aexpr * (aexpr * aexpr) + | ENot of aword * aexpr + | EIncr of aword * aexpr + | EAdd of aword * [`Sat of us | `Word] * (aexpr * aexpr) + | ESub of aword * (aexpr * aexpr) + | EMul of mulk * aword * (aexpr * aexpr) + | EOr of aword * (aexpr * aexpr) + | EXor of aword * (aexpr * aexpr) + | EAnd of aword * (aexpr * aexpr) + | ECmp of aword * us * [`Gt | `Ge] * (aexpr * aexpr) + | EPopCount of aword * aexpr +[@@deriving yojson] + +and aexpr = { node : aexpr_; type_ : atype } [@@deriving yojson] + +(* -------------------------------------------------------------------- *) +type adef = { + name: string; + arguments : aargs; + body : aexpr; + rettype : aword; +} [@@deriving yojson] + +(* -------------------------------------------------------------------- *) +let atype_as_aword (ty : atype) = + match ty with `W n -> n | _ -> raise (DestrError "atype_as_aword") + +(* -------------------------------------------------------------------- *) +let get_size (`W w : aword) : int = + w + +(* -------------------------------------------------------------------- *) +let pp_aword (fmt : Format.formatter) (`W n : aword) = + Format.fprintf fmt "@%d" n + +(* -------------------------------------------------------------------- *) +let pp_atype (fmt : Format.formatter) (t : atype) = + match t with + | `W _ as w -> Format.fprintf fmt "%a" pp_aword w + | `Unsigned -> Format.fprintf fmt "%s" "unsigned" + | `Signed -> Format.fprintf fmt "%s" "signed" diff --git a/libs/lospecs/circuit.ml b/libs/lospecs/circuit.ml new file mode 100644 index 0000000000..8176dc21ad --- /dev/null +++ b/libs/lospecs/circuit.ml @@ -0,0 +1,764 @@ +(* ==================================================================== *) +open Aig + +(* ==================================================================== *) +let rec log2 n = + if n <= 1 then 0 else 1 + log2 (n lsr 1) + +(* ==================================================================== *) +let sint_of_bools (bs : bool array) : int = + assert (Array.length bs <= Sys.int_size); + + let bs = + match Array.length bs with + | 0 -> + Array.make Sys.int_size false + | n -> + Array.append (Array.left bs (n - 1)) (Array.make (Sys.int_size - (n-1)) (bs.(n - 1))) + in + + Array.fold_lefti + (fun v i b -> if b then (1 lsl i) lor v else v) + 0 bs + +let split_at_arr (type t) (n: int) (r: t array) : t array * t array = + Array.sub r 0 n, Array.right r (Array.length r - n) + +(* -------------------------------------------------------------------- *) +let uint_of_bools (bs : bool array) : int = + assert (Array.length bs <= Sys.int_size - 1); + + Array.fold_lefti + (fun v i b -> if b then (1 lsl i) lor v else v) + 0 bs + +(* -------------------------------------------------------------------- *) +let int32_of_bools (bs : bool array) : int32 = + Array.fold_lefti + (fun v i b -> + if b then + Int32.logor (Int32.shift_left 1l i) v + else + v) + 0l bs + +let int64_of_bools (bs : bool array) : int64 = + Array.fold_lefti + (fun v i b -> + if b then + Int64.(logor (shift_left 1L i) v) + else + v) + 0L bs + +let ubigint_of_bools (bs: bool array) : Z.t = + Array.fold_right + (fun b acc -> + Z.(+) (Z.shift_left acc 1) (if b then Z.one else Z.zero)) + bs + Z.zero + +(* FIXME: Check this *) +let sbigint_of_bools (bs: bool array) : Z.t = + let bs = Array.rev bs in + let msb = bs.(0) in + Array.fold_left + (fun acc b -> + Z.(+) (Z.shift_left acc 1) (if b then Z.one else Z.zero)) + (if msb then Z.neg Z.one else Z.zero) + bs + +(* -------------------------------------------------------------------- *) +let explode (type t) ~(size : int) (r : t array) = + assert (Array.length r mod size == 0); + + Array.init ((Array.length r) / size) (fun i -> + Array.init size (fun j -> r.(i * size + j))) + + +(* -------------------------------------------------------------------- *) +let bytes_of_bools (bs : bool array) : bytes = + let bs = (Array.to_seq (explode ~size:8 bs)) in + let bs = Seq.map (uint_of_bools %> Char.chr) bs in + Bytes.of_seq bs + +(* -------------------------------------------------------------------- *) +let bools_of_reg (r: reg) : bool array = + Array.map (function + | { gate = False; id } when id > 0 -> false + | { gate = False; id } -> true + | _ -> raise (Invalid_argument "Can't convert non constant reg to bool array") + ) r + +let bool_list_of_reg : reg -> bool list = fun r -> bools_of_reg r |> Array.to_list + +(* -------------------------------------------------------------------- *) +let pp_reg_ ~(size : int) (fmt : Format.formatter) (r : bool array) = + assert (Array.length r mod (size * 4) = 0); + + let r = explode ~size:(size * 4) r in +(* let r = explode ~size:(size * 4) r in *) + let r = Array.map int32_of_bools r in + + Format.fprintf fmt "%a" + (fun fmt arr -> Array.iteri (fun i x -> + Format.fprintf fmt "%0.8lx" x; + if i < Array.length arr - 1 then + Format.fprintf fmt "_" + ) arr) + r + +let pp_reg ~(size: int) (fmt: Format.formatter) (r: reg) = + assert (size mod 4 = 0); + pp_reg_ ~size:(size / 4) fmt (bools_of_reg r) + +(* ==================================================================== *) +let bit ~(position : int) (v : int) : bool = + (v lsr position) land 0b1 <> 0 + +(* -------------------------------------------------------------------- *) +let bit32 ~(position : int) (v : int32) : bool = + let open Int32 in + logand (shift_right v position) 0b1l <> 0l + +(* -------------------------------------------------------------------- *) +let bit64 ~(position : int) (v : int64) : bool = + let open Int64 in + logand (shift_right v position) 0b1L <> 0L + +(* ==================================================================== *) +let of_int ~(size : int) (v : int) : reg = + Array.init size (fun i -> constant (bit ~position:i v)) + +(* -------------------------------------------------------------------- *) +let of_int32 (v : int32) : reg = + Array.init 32 (fun i -> constant (bit32 ~position:i v)) + +(* -------------------------------------------------------------------- *) +let of_int64 (v : int64) : reg = + Array.init 64 (fun i -> constant (bit64 ~position:i v)) + +(* -------------------------------------------------------------------- *) +let of_int32s (vs : int32 array) : reg = + Array.reduce Array.append (Array.map of_int32 vs) + +(* -------------------------------------------------------------------- *) +let of_bigint ~(size : int) (v : Z.t) : reg = + assert (0 <= Z.compare v Z.zero); + assert (Z.numbits v <= size); + Array.init size (fun i -> constant (Z.testbit v i)) + +(* -------------------------------------------------------------------- *) +let of_string ~(size : int) (s : string) : reg = + of_bigint ~size (Z.of_string s) + +(* ==================================================================== *) +let w8 (i : int) : reg = + of_int ~size:8 i + +(* -------------------------------------------------------------------- *) +let w16 (i : int) : reg = + of_int ~size:16 i + +(* -------------------------------------------------------------------- *) +let w32 (i : int32) : reg = + of_int32 i + +(* -------------------------------------------------------------------- *) +let w64 (i : int64) : reg = + of_int64 i + +(* -------------------------------------------------------------------- *) +let w128 (s : string) : reg = + of_string ~size:128 s + +(* -------------------------------------------------------------------- *) +let w256 (s : string) : reg = + of_string ~size:256 s + +(* ==================================================================== *) +let reg ~(size : int) ~(name : int) : reg = + Array.init size (fun i -> input (name, i)) + +(* ==================================================================== *) +let split_msb (r : reg) : node * reg = + let n = Array.length r in + let msb = r.(n-1) in + let r = Array.sub r 0 (n-1) in + msb, r + +(* ==================================================================== *) +let lnot_ (r : reg) : reg = + Array.map neg r + +(* -------------------------------------------------------------------- *) +let lor_ (r1 : reg) (r2 : reg) : reg = + Array.map2 or_ r1 r2 + +(* -------------------------------------------------------------------- *) +let lxor_ (r1 : reg) (r2 : reg) : reg = + Array.map2 xor r1 r2 + +(* -------------------------------------------------------------------- *) +let lxnor_ (r1 : reg) (r2 : reg) : reg = + Array.map2 xnor r1 r2 + +(* -------------------------------------------------------------------- *) +let land_ (r1 : reg) (r2 : reg) : reg = + Array.map2 and_ r1 r2 + +(* -------------------------------------------------------------------- *) +let ors (r : node array) : node = + Array.fold_left or_ false_ r + +(* -------------------------------------------------------------------- *) +let ands (r : node array) : node = + Array.fold_left and_ true_ r + +(* -------------------------------------------------------------------- *) +let lshift ~(offset : int) (r : reg) : reg = + Array.append (Array.make offset false_) r + +(* -------------------------------------------------------------------- *) +let uextend ~(size : int) (r : reg) : reg = + Array.append r @@ Array.make (max 0 (size - Array.length r)) false_ + +(* -------------------------------------------------------------------- *) +let sextend ~(size : int) (r : reg) : reg = + let lr = Array.length r in + + if size > lr then + match Array.length r with + | 0 -> + Array.make size false_ + | _ -> + Array.append r (Array.make (size - lr) (r.(lr - 1))) + else + r + +(* -------------------------------------------------------------------- *) +let trunc ~(size: int) (r: reg) : reg = + Array.sub r 0 size + +(* -------------------------------------------------------------------- *) +let mux2 (n1 : node) (n2 : node) (c : node) = + or_ (and_ (neg c) n1) (and_ c n2) + +(* -------------------------------------------------------------------- *) +let mux2_reg (r1 : reg) (r2 : reg) (c : node) = + assert (Array.length r1 = Array.length r2); + Array.map2 (fun n1 n2 -> mux2 n1 n2 c) r1 r2 + +(* -------------------------------------------------------------------- *) +let mux2_2 + ~(k00 : node) + ~(k01 : node) + ~(k10 : node) + ~(k11 : node) + ((c1, c2) : node * node) += + mux2 + (mux2 k00 k01 c2) + (mux2 k10 k11 c2) + c1 + +(* -------------------------------------------------------------------- *) +let mux2_2reg + ~(k00 : reg) + ~(k01 : reg) + ~(k10 : reg) + ~(k11 : reg) + ((c1, c2) : node * node) += + mux2_reg + (mux2_reg k00 k01 c2) + (mux2_reg k10 k11 c2) + c1 + +(* -------------------------------------------------------------------- *) +let mux_reg (cr : (node * reg) array) (r : reg) : reg = + Array.fold_right (fun (c, r) s -> mux2_reg s r c) cr r + +(* -------------------------------------------------------------------- *) +let ite (c : node) (t : reg) (f : reg) : reg = + mux2_reg f t c + +(* -------------------------------------------------------------------- *) +let c_rshift ~(lg2o : int) ~(sign : node) (c : node) (r : reg) = + let len = Array.length r in + let clamp = log2 len in + let s = + if lg2o > clamp then + Array.make len sign + else + let offset = 1 lsl lg2o in + Array.append (Array.sub r (min offset len) (len - (min offset len))) (Array.make (min offset len) sign) + in + Array.map2 (fun r1 s1 -> mux2 r1 s1 c) r s + +(* TODO: change array appends into inits *) + +(* -------------------------------------------------------------------- *) +let arshift ~(offset : int) (r : reg) = + let sign = if Array.length r = 0 then false_ else r.(Array.length r - 1) in + let l = Array.length r in + Array.append (Array.sub r (min offset l) (l - (min offset l))) (Array.make (min offset l) sign) + +(* -------------------------------------------------------------------- *) +let lsr_ (r as r0 : reg) (s : reg) : reg = + let _, r = + Array.fold_left (fun (i, r) c -> + (i+1, c_rshift ~lg2o:i ~sign:false_ c r) + ) (0, r) s + in assert (Array.length r = Array.length r0); r + +(* -------------------------------------------------------------------- *) +let lsl_ (r : reg) (s : reg) : reg = + Array.rev (lsr_ (Array.rev r) s) + +(* -------------------------------------------------------------------- *) +let asl_ (r : reg) (s : reg) : reg = + lsl_ r s + +(* -------------------------------------------------------------------- *) +let asr_ (r : reg) (s : reg) : reg = + let sign = + if Array.length r = 0 then false_ else r.(Array.length r - 1) + in + let _, r = + Array.fold_left (fun (i, r) c -> + (i+1, c_rshift ~lg2o:i ~sign c r) + ) (0, r) s + in r + +(* -------------------------------------------------------------------- *) +let shift ~(side : [`L | `R]) ~(sign : [`L | `A]) = + match side, sign with + | `L, `L -> lsl_ + | `R, `L -> lsr_ + | `L, `A -> asl_ + | `R, `A -> asr_ + + +(* -------------------------------------------------------------------- *) +let halfadder (a : node) (b : node) : node * node = + (and_ a b, xor a b) + +(* -------------------------------------------------------------------- *) +let incr (r : reg) : node * reg = + Array.fold_left_map halfadder true_ r + +(* -------------------------------------------------------------------- *) +let incrc (r : reg) : reg = + let c, r = incr r in Array.append r [|c|] + +(* -------------------------------------------------------------------- *) +let incr_dropc (r : reg) : reg = + snd (Array.fold_left_map halfadder true_ r) + +(* -------------------------------------------------------------------- *) +let opp (r : reg) : reg = + incr_dropc (lnot_ r) + +(* -------------------------------------------------------------------- *) +let fulladder (c : node) (a : node) (b : node) : node * node = + let c1, s = halfadder a b in + let c2, s = halfadder c s in + (or_ c1 c2, s) + +(* -------------------------------------------------------------------- *) +let addsub (m : node) (r1 : reg) (r2 : reg) : node * reg = + assert(Array.length r1 = Array.length r2); + + Array.fold_left_map + (fun carry (a, b) -> fulladder carry a (xor b m)) + m (Array.combine r1 r2) + +(* -------------------------------------------------------------------- *) +let add (r1 : reg) (r2 : reg) : node * reg = + addsub false_ r1 r2 + +(* -------------------------------------------------------------------- *) +let addc (r1 : reg) (r2 : reg) : reg = + let c, r = add r1 r2 in Array.append r [|c|] + +(* -------------------------------------------------------------------- *) +let add_dropc (r1 : reg) (r2 : reg) : reg = + snd (add r1 r2) + +(* -------------------------------------------------------------------- *) +let sub (r1 : reg) (r2 : reg) : node * reg = + addsub true_ r1 r2 + +(* -------------------------------------------------------------------- *) +let sub_dropc (r1 : reg) (r2 : reg) : reg = + snd (sub r1 r2) + +(* -------------------------------------------------------------------- *) +let bmul (n : node) (r : reg) : reg = + Array.map (fun n' -> and_ n n') r + +(* -------------------------------------------------------------------- *) +let umul (r1 : reg) (r2 : reg) : reg = + let n1 = Array.length r1 in + let n2 = Array.length r2 in + + let prods = Array.mapi (fun i n -> lshift ~offset:i (bmul n r2)) r1 in + + let out = Array.fold_left addc (Array.make n2 false_) prods in + let out = Array.sub out 0 (n1 + n2) in + + out + +(* -------------------------------------------------------------------- *) +let umul_ (r1 : reg) (r2 : reg) : reg * reg = + let n = Array.length r2 in + let r = umul r1 r2 in + + split_at_arr n r + +(* -------------------------------------------------------------------- *) +let umull (r1 : reg) (r2 : reg) : reg = + fst (umul_ r1 r2) + +(* -------------------------------------------------------------------- *) +let umulh (r1 : reg) (r2 : reg) : reg = + snd (umul_ r1 r2) + +(* -------------------------------------------------------------------- *) +let smul (r1 : reg) (r2 : reg) : reg = + let nm, (r1, r2) = + let n1 = Array.length r1 in + let n2 = Array.length r2 in + let nm = max n1 n2 in + + let r1 = sextend ~size:nm r1 in + let r2 = sextend ~size:nm r2 in + + (nm, (r1, r2)) in + + let sbmul_r2 (n : node) = + Array.mapi (fun i n' -> + let out = and_ n n' in + if i+1 = nm then neg out else out + ) r2 in + + let prods = Array.mapi (fun i n -> + let out = sbmul_r2 n in + let out = + match () with + | _ when i = 0 -> Array.append out [|true_|] + | _ when i+1 = nm -> Array.append (lnot_ out) [|true_|] + | _ -> Array.append out [|false_|] + in + lshift ~offset:i out + ) r1 in + + let out = Array.fold_left addc (Array.make (nm+1) false_) prods in + + Array.left out (2 * nm) + +(* -------------------------------------------------------------------- *) +let smul_ (r1 : reg) (r2 : reg) : reg * reg = + let nm = max (Array.length r1) (Array.length r2) in + let s = smul r1 r2 in + split_at_arr nm s + +(* -------------------------------------------------------------------- *) +let smull (r1 : reg) (r2 : reg) : reg = + fst (smul_ r1 r2) + +(* -------------------------------------------------------------------- *) +let smulh (r1 : reg) (r2 : reg) : reg = + snd (smul_ r1 r2) + +(* -------------------------------------------------------------------- *) +let ssat ~(size : int) (r : reg) : reg = + assert (0 < size); + assert (size < Array.length r); + + let rl, rh = split_at_arr (size - 1) r in + let rh, msb = Array.sub rh 0 (Array.length rh - 1), rh.(Array.length rh - 1) in + + let rm = Array.append (Array.make (size - 1) false_) [|true_ |] in + let rM = Array.append (Array.make (size - 1) true_ ) [|false_|] in + let ro = Array.append rl [|msb|] in + + let cm = and_ msb (neg (ands rh)) in + let cM = and_ (neg msb) (ors rh) in + + mux_reg [|(cm, rm); (cM, rM)|] ro + +(* -------------------------------------------------------------------- *) +let usat ~(size : int) (r : reg) : reg = + assert (size < Array.length r); + + let rl, rh = split_at_arr size r in + let rh, msb = Array.left rh (Array.length rh - 1), rh.(Array.length rh - 1) in + + let rm = Array.make size false_ in + let rM = Array.make size true_ in + let ro = rl in + + let cm = msb in + let cM = and_ (neg msb) (ors rh) in + + mux_reg [|(cm, rm); (cM, rM)|] ro + +(* -------------------------------------------------------------------- *) +let sat ~(signed : bool) ~(size : int) (r : reg) : reg = + match signed with + | true -> ssat ~size r + | false -> usat ~size r + +(* -------------------------------------------------------------------- *) +let ssadd (r1 : reg) (r2 : reg) : reg = + let n1 = Array.length r1 in + let n2 = Array.length r2 in + let n = max n1 n2 in + + let r1 = sextend ~size:(n+1) r1 in + let r2 = sextend ~size:(n+1) r2 in + + ssat ~size:n (add_dropc r1 r2) + +(* -------------------------------------------------------------------- *) +let usadd (r1 : reg) (r2 : reg) : reg = + let r = addc r1 r2 in + usat ~size:(Array.length r - 1) r + +(* -------------------------------------------------------------------- *) +let usmul (r1 : reg) (r2 : reg) : reg = + let n1 = Array.length r1 in + let n2 = Array.length r2 in + let nm = max n1 n2 in + + let r1 = uextend ~size:(2*nm) r1 in + let r2 = sextend ~size:(2*nm) r2 in + + smull r1 r2 + +(* -------------------------------------------------------------------- *) +let ugte (eq : node) (r1 : reg) (r2 : reg) : node = + let n1 = Array.length r1 in + let n2 = Array.length r2 in + let n = max n1 n2 in + let r1 = uextend ~size:n r1 in + let r2 = uextend ~size:n r2 in + + Array.fold_left (fun ct (c1, c2) -> + mux2_2 (c1, c2) + ~k00:ct + ~k01:Aig.false_ + ~k10:Aig.true_ + ~k11:ct + ) eq (Array.combine r1 r2) + +(* -------------------------------------------------------------------- *) +let sgte (eq : node) (r1 : reg) (r2 : reg) : node = + let msb1, r1 = split_msb r1 in + let msb2, r2 = split_msb r2 in + + mux2_2 (msb1, msb2) + ~k00:(ugte eq r1 r2) + ~k01:Aig.true_ + ~k10:Aig.false_ + ~k11:(ugte eq r1 r2) + +(* -------------------------------------------------------------------- *) +let bvueq (r1 : reg) (r2 : reg) : node = + let n1 = Array.length r1 in + let n2 = Array.length r2 in + let n = max n1 n2 in + let r1 = uextend ~size:n r1 in + let r2 = uextend ~size:n r2 in + + Array.fold_left (fun ct (c1, c2) -> + mux2_2 (c1, c2) + ~k00:ct + ~k01:Aig.false_ + ~k10:Aig.false_ + ~k11:ct + ) Aig.true_ (Array.combine r1 r2) + +(* -------------------------------------------------------------------- *) +let bvseq (r1 : reg) (r2 : reg) : node = + let n1 = Array.length r1 in + let n2 = Array.length r2 in + let n = max n1 n2 in + let r1 = sextend ~size:n r1 in + let r2 = sextend ~size:n r2 in + + Array.fold_left (fun ct (c1, c2) -> + mux2_2 (c1, c2) + ~k00:ct + ~k01:Aig.false_ + ~k10:Aig.false_ + ~k11:ct + ) Aig.true_ (Array.combine r1 r2) + +(* -------------------------------------------------------------------- *) +let ugt (r1 : reg) (r2 : reg) : node = + ugte Aig.false_ r1 r2 + +(* -------------------------------------------------------------------- *) +let uge (r1 : reg) (r2 : reg) : node = + ugte Aig.true_ r1 r2 + +(* -------------------------------------------------------------------- *) +let ult (r1: reg) (r2 : reg) : node = + ugt r2 r1 + +(* -------------------------------------------------------------------- *) +let ule (r1 : reg) (r2 : reg) : node = + uge r2 r1 + +(* -------------------------------------------------------------------- *) +let sgt (r1 : reg) (r2 : reg) : node = + sgte Aig.false_ r1 r2 + +(* -------------------------------------------------------------------- *) +let sge (r1 : reg) (r2 : reg) : node = + sgte Aig.true_ r1 r2 + +(* -------------------------------------------------------------------- *) +let slt (r1 : reg) (r2 : reg) : node = + sgt r2 r1 + +(* -------------------------------------------------------------------- *) +let sle (r1 : reg) (r2 : reg) : node = + sge r2 r1 + +(* -------------------------------------------------------------------- *) +let iszero (r : reg) : node = + bvueq r (Array.map (fun _ -> false_) r) + +(* -------------------------------------------------------------------- *) +let abs (a : reg) : reg = + let msb_a, _ = split_msb a in + ite (msb_a) (opp a) a + +(* -------------------------------------------------------------------- *) +let udiv_ (a : reg) (b : reg) : reg * reg = + assert (Array.length a >= Array.length b); + + let n = Array.length b in + + let pu (a : node) (b : node) (cin : node) : node * (node -> node) = + let cout, s = fulladder cin (neg b) a in + let out (cc : node) = mux2 a s cc in + (cout, out) + in + + let create_line (i : int) (d : node) (a : reg) : node * reg = + let a = Array.append [|d|] (if i = n then a else snd (split_msb a)) in + let b = if i < n then b else Array.append b [|Aig.false_|] in + + let c, pus = + Array.fold_left_map + (fun c (a, b) -> pu a b c) + Aig.true_ (Array.combine a b) + in (c, Array.map (fun pu -> pu c) pus) + in + + Array.fold_lefti (fun (q, a) i d -> + let q', a = create_line i d a in (Array.append [|q'|] q, a) + ) ([||], Array.make n false_) (Array.rev a) + +(* -------------------------------------------------------------------- *) +let udiv (a : reg) (b : reg) : reg = + let m = max (Array.length a) (Array.length b) in + let a = uextend ~size:m a in + let b = uextend ~size:m b in + ite (iszero b) a (fst (udiv_ a b)) + +(* -------------------------------------------------------------------- *) +let sdiv (s : reg) (t : reg) : reg = + let msb_s, _ = split_msb s in + let msb_t, _ = split_msb t in + + mux2_2reg + ~k00:( (udiv ( s) ( t))) + ~k10:(opp (udiv (opp s) ( t))) + ~k01:(opp (udiv ( s) (opp t))) + ~k11:( (udiv (opp s) (opp t))) + (msb_s, msb_t) + +(* -------------------------------------------------------------------- *) +let umod (a : reg) (b : reg) : reg = + let m = max (Array.length a) (Array.length b) in + let a = uextend ~size:m a in + let b = uextend ~size:m b in + + ite + (iszero b) + (Array.map (fun _ -> false_) b) + (uextend ~size:m (snd (udiv_ a b))) + +(* -------------------------------------------------------------------- *) +let srem (s : reg) (t : reg) : reg = + let msb_s, _ = split_msb s in + let msb_t, _ = split_msb t in + + mux2_2reg + ~k00:( (umod ( s) ( t))) + ~k10:(opp (umod (opp s) ( t))) + ~k01:(opp (umod ( s) (opp t))) + ~k11:( (umod (opp s) (opp t))) + (msb_s, msb_t) + +(* -------------------------------------------------------------------- *) +let smod (s : reg) (t : reg) : reg = + ite (iszero t) s @@ + let msb_s, _ = split_msb s in + let msb_t, _ = split_msb t in + + let u = umod (abs s) (abs t) in + + ite (iszero u) + u + (mux2_2reg + ~k00:( u ) + ~k10:(add_dropc (opp u) t) + ~k01:(add_dropc ( u) t) + ~k11:( (opp u) ) + (msb_s, msb_t)) + +(* -------------------------------------------------------------------- *) +let rol (r: reg) (s: reg) : reg = + let size = Array.length r in + let s = umod s (of_int ~size size) in (* so 0 <= s < size *) + let s = Array.left s size |> uextend ~size in (* by above, ln s < size *) + lor_ (shift ~side:`L ~sign:`L r s) (shift ~side:`R ~sign:`L r (sub_dropc (of_int ~size size) s)) + +(* -------------------------------------------------------------------- *) +let ror (r: reg) (s: reg) : reg = + let size = Array.length r in + let s = umod s (of_int ~size size) in (* so 0 <= s < size *) + let s = Array.left s size |> uextend ~size in (* by above, ln s < size *) + lor_ (shift ~side:`R ~sign:`L r s) (shift ~side:`L ~sign:`L r (sub_dropc (of_int ~size size) s)) + +(* -------------------------------------------------------------------- *) +let popcount ~(size : int) (r : reg) : reg = + Array.fold_left (fun aout node -> + ite node (incr_dropc aout) aout + ) (Array.make size Aig.false_) r + +(* -------------------------------------------------------------------- *) +(* FIXME: redo this *) +let of_bigint_all ~(size : int) (v : Z.t) : reg = + let mod_ = Z.(lsl) Z.one (size) in + let v = Z.rem v mod_ in + let v = if Z.sign v < 0 then Z.add mod_ v else v in + of_bigint ~size v + +(* Assumes input is array of 16 bit words *) +let compute ?(input_block_size = 16) ?(output_block_size = 16) (r: reg) (inp: int array) : int array = + assert (input_block_size <= 32); + let m = (1 lsl input_block_size) - 1 in + let inp = Array.map (fun i -> i land m) inp in + let inp = Array.map (of_int ~size:input_block_size) inp |> Array.reduce Array.append in + maps (function + | (0, i) -> Some (inp.(i)) + | _ -> None) r |> bools_of_reg |> explode ~size:output_block_size |> Array.map (uint_of_bools) + diff --git a/libs/lospecs/circuit.mli b/libs/lospecs/circuit.mli new file mode 100644 index 0000000000..6f923966a3 --- /dev/null +++ b/libs/lospecs/circuit.mli @@ -0,0 +1,171 @@ +(* ==================================================================== *) +open Aig + +(* ==================================================================== *) +val log2 : int -> int + +(* ==================================================================== *) +val explode : size:int -> 'a array -> 'a array array + +(* ==================================================================== *) +val sint_of_bools : bool array -> int + +val uint_of_bools : bool array -> int + +val bytes_of_bools : bool array -> bytes + +val ubigint_of_bools : bool array -> Z.t + +val sbigint_of_bools : bool array -> Z.t + +val bools_of_reg : reg -> bool array + +val bool_list_of_reg : reg -> bool list + +(* ==================================================================== *) +val of_int : size:int -> int -> reg + +val of_bigint : size:int -> Z.t -> reg + +val of_int32s : int32 array -> reg + +(* ==================================================================== *) +val w8 : int -> reg + +val w16 : int -> reg + +val w32 : int32 -> reg + +val w64 : int64 -> reg + +val w128 : string -> reg + +val w256 : string -> reg + +(* ==================================================================== *) +val mux2 : node -> node -> node -> node + +val mux2_reg : reg -> reg -> node -> reg + +val mux_reg : (node * reg) array -> reg -> reg + +val ite : node -> reg -> reg -> reg + +(* ==================================================================== *) +val reg : size:int -> name:int -> reg + +(* ==================================================================== *) +val uextend : size:int -> reg -> reg + +val sextend : size:int -> reg -> reg + +(* ==================================================================== *) +val lnot_ : reg -> reg + +val lor_ : reg -> reg -> reg + +val land_ : reg -> reg -> reg + +val lxor_ : reg -> reg -> reg + +val lxnor_ : reg -> reg -> reg + +val ors : node array -> node + +val ands : node array -> node + +(* ==================================================================== *) +val arshift : offset:int -> reg -> reg + +val lsl_ : reg -> reg -> reg + +val lsr_ : reg -> reg -> reg + +val asl_ : reg -> reg -> reg + +val asr_ : reg -> reg -> reg + +val shift : side:[`L | `R] -> sign:[`L | `A] -> reg -> reg -> reg + +val rol : reg -> reg -> reg + +val ror : reg -> reg -> reg + +(* ==================================================================== *) +val incr : reg -> node * reg + +val incr_dropc : reg -> reg + +val incrc : reg -> reg + +(* ==================================================================== *) +val add : reg -> reg -> node * reg + +val addc : reg -> reg -> reg + +val add_dropc : reg -> reg -> reg + +val ssadd : reg -> reg -> reg + +val usadd : reg -> reg -> reg + +(* ==================================================================== *) +val opp : reg -> reg + +val sub : reg -> reg -> node * reg + +val sub_dropc : reg -> reg -> reg + +(* ==================================================================== *) +val umul : reg -> reg -> reg + +val umull : reg -> reg -> reg + +val umulh : reg -> reg -> reg + +val smul : reg -> reg -> reg + +val smull : reg -> reg -> reg + +val smulh : reg -> reg -> reg + +val usmul : reg -> reg -> reg + +(* ==================================================================== *) +val ugte : node -> reg -> reg -> node + +val ugt : reg -> reg -> node + +val uge : reg -> reg -> node + +val sgte : node -> reg -> reg -> node + +val sgt : reg -> reg -> node + +val sge : reg -> reg -> node + +val bvueq : reg -> reg -> node + +val bvseq : reg -> reg -> node + +(* ==================================================================== *) +val sat : signed:bool -> size:int -> reg -> reg + +val udiv_ : reg -> reg -> reg * reg + +val udiv : reg -> reg -> reg + +val umod : reg -> reg -> reg + +val sdiv : reg -> reg -> reg + +val srem : reg -> reg -> reg + +val smod : reg -> reg -> reg + +(* ==================================================================== *) +val popcount : size:int -> reg -> reg + +val of_bigint_all : size:int -> Z.t -> reg + +val compute : ?input_block_size:int -> ?output_block_size:int -> reg -> int array -> int array diff --git a/libs/lospecs/circuit_spec.ml b/libs/lospecs/circuit_spec.ml new file mode 100644 index 0000000000..fca87e49b0 --- /dev/null +++ b/libs/lospecs/circuit_spec.ml @@ -0,0 +1,279 @@ +(* ==================================================================== *) +open Ast +open Aig + +(* ==================================================================== *) +let load_from_file ~(filename : string) = + let specs = File.with_file_in filename (Io.parse filename) in + let specs = Typing.tt_program Typing.Env.empty specs in + specs + +(* FIXME: Duplicated from circuit.ml *) +let split_at_arr (type t) (n: int) (r: t array) : t array * t array = + Array.sub r 0 n, Array.right r (Array.length r - n) + +exception CircuitSpecError of symbol (* FIXME PR: Rename? *) + +(* ==================================================================== *) +module Env : sig + type env + + val empty : env + + module Fun : sig + val get : env -> ident -> aargs * aexpr + + val bind : env -> ident -> aargs * aexpr -> env + end + + module Var : sig + val get : env -> ident -> reg + + val bind : env -> ident -> reg -> env + + val bindall : env -> (ident * reg) list -> env + end +end = struct + type binding = Var of reg | Fun of aargs * aexpr + + type env = (ident, binding) Map.t + + let empty : env = + Map.empty + + module Fun = struct + let get (env : env) (x : ident) = + match Map.find_opt x env with + | Some (Fun (a, f)) -> (a, f) + | _ -> raise Not_found + + let bind (env : env) (x : ident) ((a, f): aargs * aexpr) : env = + Map.add x (Fun (a, f)) env + end + + module Var = struct + let get (env : env) (x : ident) = + match Map.find_opt x env with + | Some (Var r) -> r + | _ -> raise Not_found + + let bind (env : env) (x : ident) (r: reg) : env = + Map.add x (Var r) env + + let bindall (env : env) (xr : (ident * reg) list) : env = + List.fold_left (fun env (x, r) -> bind env x r) env xr + end +end + +type env = Env.env + +(* ==================================================================== *) +let circuit_of_specification (rs : reg list) (p : adef) : reg = + assert (List.length rs = List.length p.arguments); + assert (List.for_all2 (fun r (_, `W n) -> Array.length r = n) rs p.arguments); + + let rec of_expr_ (env : env) (e : aexpr) : reg = + match e.node with + | EIncr (_, e) -> + Circuit.incr_dropc (of_expr env e) + + | EAdd (_, c, (e1, e2)) -> begin + let e1 = of_expr env e1 in + let e2 = of_expr env e2 in + match c with + | `Word -> Circuit.add_dropc e1 e2 + | `Sat `S -> Circuit.ssadd e1 e2 + | `Sat `U -> Circuit.usadd e1 e2 + end + + | ESub (_, (e1, e2)) -> + let e1 = of_expr env e1 in + let e2 = of_expr env e2 in + Circuit.sub_dropc e1 e2 + + | EMul (k, _, (e1, e2)) -> begin + let e1 = of_expr env e1 in + let e2 = of_expr env e2 in + + match k with + | `U `D -> Circuit.umul e1 e2 + | `U `H -> Circuit.umulh e1 e2 + | `U `L -> Circuit.umull e1 e2 + | `S `D -> Circuit.smul e1 e2 + | `S `H -> Circuit.smulh e1 e2 + | `S `L -> Circuit.umull e1 e2 + | `US -> Circuit.usmul e1 e2 + end + + | ECmp (`W _, us, k, (e1, e2)) -> + let e1 = of_expr env e1 in + let e2 = of_expr env e2 in + let c = + match us, k with + | `S, `Gt -> Circuit.sgt e1 e2 + | `S, `Ge -> Circuit.sge e1 e2 + | `U, `Gt -> Circuit.ugt e1 e2 + | `U, `Ge -> Circuit.uge e1 e2 + in [|c|] + + | ENot (_, e) -> + Circuit.lnot_ (of_expr env e) + + | EOr (_, (e1, e2)) -> + let e1 = of_expr env e1 in + let e2 = of_expr env e2 in + Circuit.lor_ e1 e2 + + | EXor (_, (e1, e2)) -> + let e1 = of_expr env e1 in + let e2 = of_expr env e2 in + Circuit.lxor_ e1 e2 + + | EAnd (_, (e1, e2)) -> + let e1 = of_expr env e1 in + let e2 = of_expr env e2 in + Circuit.land_ e1 e2 + + | EShift (lr, la, (e1, e2)) -> + let e1 = of_expr env e1 in + let e2 = of_expr env e2 in + Circuit.shift ~side:lr ~sign:la e1 e2 + + | ESat (us, `W size, e) -> begin + let e = of_expr env e in + match us with + | `U -> Circuit.sat ~signed:false ~size e + | `S -> Circuit.sat ~signed:true ~size e + end + + | EExtend (us, `W size, e) -> begin + let e = of_expr env e in + match us with + | `U -> Circuit.uextend ~size e + | `S -> Circuit.sextend ~size e + end + + | EPopCount (size, e) -> + Circuit.popcount ~size:(get_size size) (of_expr env e) + + | ESlice (e, ({ node = EInt offset }, size, scale)) -> + let e = of_expr env e in + let offset = offset * scale in + let size = size * scale in + Array.sub e offset size + + | ESlice (e, (offset, size, scale)) -> + let lgscale = Circuit.log2 scale in + assert (1 lsl lgscale = scale); + + let e = of_expr env e in + let offset = of_expr env offset in + + let offset = Array.append (Array.make lgscale Aig.false_) offset in + let size = size * scale in + + Array.left (Circuit.lsr_ e offset) size + + | EAssign (e, ({ node = EInt offset }, size, scale), v) -> + let e = of_expr env e in + let v = of_expr env v in + let offset = offset * scale in + let size = size * scale in + let pre, e = split_at_arr offset e in + let e, post = split_at_arr size e in + Array.append pre (Array.append v post) + + | EAssign (e, (offset, size, scale), v) -> + let esz = atype_as_aword e.type_ in + + let lgscale = Circuit.log2 scale in + assert (1 lsl lgscale = scale); + + let e = of_expr env e in + let offset = of_expr env offset in + let v = of_expr env v in + + let offset = Array.append (Array.make lgscale Aig.false_) offset in + let size = size * scale in + + let m = Array.make size Aig.true_ in + let m = Circuit.uextend ~size:esz m in + let m = Circuit.lnot_ (Circuit.lsl_ m offset) in + + let v = Circuit.uextend ~size:esz v in + let v = Circuit.lsl_ v offset in + + Circuit.lor_ (Circuit.land_ e m) v + + | EConcat (_, es) -> + Array.reduce Array.append (List.map (of_expr env) es |> Array.of_list) + + | ERepeat (_, (e, n)) -> + Array.reduce Array.append (Array.make n (of_expr env e)) + + | EMap ((`W n, _), (a, f), es) -> + let anames = List.map fst a in + let es = List.map (of_expr env) es in + let es = List.map (Circuit.explode ~size:n %> Array.to_list) es in + let es = List.transpose es |> Array.of_list in + + let es = es |> Array.map (fun es -> + let env = Env.Var.bindall env (List.combine anames es) in + of_expr env f + ) + + in Array.reduce Array.append es + + | EApp (f, args) -> + let a, f = Env.Fun.get env f in + let anames = List.map fst a in + let args = List.map (of_expr env) args in + let env = Env.Var.bindall env (List.combine anames args) in + of_expr env f + + | ELet ((x, None, v), e) -> + let v = of_expr env v in + of_expr (Env.Var.bind env x v) e + + | ELet ((x, Some a, v), e) -> + let env = Env.Fun.bind env x (a, v) in + of_expr env e + + | ECond (c, (e1, e2)) -> + let c = of_expr env c in + let e1 = of_expr env e1 in + let e2 = of_expr env e2 in + + Circuit.mux2_reg e2 e1 (Circuit.ors c) + + | EVar x -> + Env.Var.get env x + + | EInt i -> begin + match e.type_ with + | `W n -> Circuit.of_int ~size:n i + | _ -> raise (CircuitSpecError (Format.asprintf "Expected int got %a" pp_atype e.type_)) + end + + and of_expr (env : env) (e : aexpr) : reg = + let r = of_expr_ env e in + + begin + match e.type_ with + | `W n -> + if Array.length r <> n then begin + Format.eprintf "%d %d@." (Array.length r) n; + Format.eprintf "%a@." + (Yojson.Safe.pretty_print ~std:true) + (Ast.aexpr_to_yojson e); + raise (CircuitSpecError (Format.asprintf "Bitstring length mismatch (expected %d, got %d)" n (Array.length r))) + end + | _ -> () + end; r + in + + let env = + let bindings = List.combine (List.map fst p.arguments) rs in + Env.Var.bindall Env.empty bindings in + + of_expr env p.body diff --git a/libs/lospecs/circuit_spec.mli b/libs/lospecs/circuit_spec.mli new file mode 100644 index 0000000000..89a558c677 --- /dev/null +++ b/libs/lospecs/circuit_spec.mli @@ -0,0 +1,3 @@ +(* ==================================================================== *) +val circuit_of_specification : Aig.reg list -> Ast.adef -> Aig.reg +val load_from_file : filename:string -> (string * Ast.adef) list diff --git a/libs/lospecs/deps.ml b/libs/lospecs/deps.ml new file mode 100644 index 0000000000..e9a77fe708 --- /dev/null +++ b/libs/lospecs/deps.ml @@ -0,0 +1,196 @@ +(* -------------------------------------------------------------------- *) +open Ast + +(* -------------------------------------------------------------------- *) +type symbol = string + +(* -------------------------------------------------------------------- *) +type dep1 = Set.Int.t IdentMap.t +type deps = dep1 Map.Int.t + +(* -------------------------------------------------------------------- *) +let eq_dep1 (d1 : dep1) (d2 : dep1) : bool = + IdentMap.equal Set.Int.equal d1 d2 + +(* -------------------------------------------------------------------- *) +let eq_deps (d1 : deps) (d2 : deps) : bool = Map.Int.equal eq_dep1 d1 d2 + +(* -------------------------------------------------------------------- *) +let empty ~(size : int) : deps = + 0 --^ size |> Enum.map (fun i -> (i, IdentMap.empty)) |> Map.Int.of_enum + +(* -------------------------------------------------------------------- *) +let enlarge ~(min : int) ~(max : int) (d : deps) : deps = + let change = function None -> Some IdentMap.empty | Some _ as v -> v in + + min --^ max |> Enum.fold (fun d i -> Map.Int.modify_opt i change d) d + +(* -------------------------------------------------------------------- *) +let clearout ~(min : int) ~(max : int) (d : deps) : deps = + Map.Int.filter_map + (fun i d1 -> Some (if min <= i && i < max then d1 else IdentMap.empty)) + d + +(* -------------------------------------------------------------------- *) +let restrict ~(min : int) ~(max : int) (d : deps) : deps = + Map.Int.filter (fun i _ -> min <= i && i < max) d + +(* -------------------------------------------------------------------- *) +let recast ~(min : int) ~(max : int) (d : deps) : deps = + d |> restrict ~min ~max |> enlarge ~min ~max + +(* -------------------------------------------------------------------- *) +let merge1 (d1 : dep1) (d2 : dep1) : dep1 = + IdentMap.merge + (fun _ i1 i2 -> + Some (Set.Int.union (i1 |? Set.Int.empty) (i2 |? Set.Int.empty))) + d1 d2 + +(* -------------------------------------------------------------------- *) +let merge (d1 : deps) (d2 : deps) : deps = + Map.Int.merge + (fun _ m1 m2 -> + Some (merge1 (m1 |? IdentMap.empty) (m2 |? IdentMap.empty))) + d1 d2 + +(* -------------------------------------------------------------------- *) +let merge1_all (ds : dep1 Enum.t) : dep1 = Enum.reduce merge1 ds + +(* -------------------------------------------------------------------- *) +let merge_all (ds : deps Enum.t) : deps = Enum.reduce merge ds + +(* -------------------------------------------------------------------- *) +let copy ~(offset : int) ~(size : int) (x : ident) : deps = + 0 --^ size + |> Enum.map (fun i -> + let di = IdentMap.singleton x (Set.Int.singleton (i + offset)) in + (i, di)) + |> Map.Int.of_enum + +(* -------------------------------------------------------------------- *) +let chunk ~(csize : int) ~(count : int) (d : deps) : deps = + 0 --^ count + |> Enum.map (fun ci -> + let d1 = + 0 --^ csize + |> Enum.map (fun i -> i + (ci * csize)) + |> Enum.map (fun i -> Map.Int.find_opt i d |> Option.default IdentMap.empty) + |> merge1_all + in + 0 --^ csize |> Enum.map (fun i -> (i + (ci * csize), d1))) + |> Enum.flatten |> Map.Int.of_enum + +(* -------------------------------------------------------------------- *) +let perm ~(csize : int) ~(perm : int list) (d : deps) : deps = + List.enum perm + |> Enum.mapi (fun ci x -> + Enum.map + (fun i -> (i + (ci * csize), Map.Int.find_opt (i + (x * csize)) d |> Option.default IdentMap.empty)) + (0 --^ csize)) + |> Enum.flatten |> Map.Int.of_enum + +(* -------------------------------------------------------------------- *) +let collapse ~(csize : int) ~(count : int) (d : deps) : deps = + 0 --^ count + |> Enum.map (fun ci -> + let d1 = + 0 --^ csize + |> Enum.map (fun i -> i + (ci * csize)) + |> Enum.map (fun i -> Map.Int.find_opt i d |> Option.default IdentMap.empty) + |> merge1_all + in + (ci, d1)) + |> Map.Int.of_enum + +(* -------------------------------------------------------------------- *) +let merge_all_deps (d : deps) : dep1 = + Map.Int.enum d |> Enum.map snd |> merge1_all + +(* -------------------------------------------------------------------- *) +let constant ~(size : int) (d : dep1) : deps = + 0 --^ size |> Enum.map (fun i -> (i, d)) |> Map.Int.of_enum + +(* -------------------------------------------------------------------- *) +let offset ~(offset : int) (d : deps) : deps = + Map.Int.enum d |> Enum.map (fun (i, x) -> (i + offset, x)) |> Map.Int.of_enum + +(* -------------------------------------------------------------------- *) +let split ~(csize : int) ~(count : int) (d : deps) : deps Enum.t = + 0 --^ count + |> Enum.map (fun i -> + Map.Int.filter (fun x _ -> csize * i <= x && x < csize * (i + 1)) d + |> offset ~offset:(-i * csize)) + +(* -------------------------------------------------------------------- *) +let aggregate ~(csize : int) (ds : deps Enum.t) = + Enum.foldi + (fun i d1 d -> merge (offset ~offset:(i * csize) d1) d) + (empty ~size:0) ds + +(* ==================================================================== *) +type 'a pp = Format.formatter -> 'a -> unit + +(* -------------------------------------------------------------------- *) +let pp_bitset (fmt : Format.formatter) (d : Set.Int.t) = + Format.fprintf fmt "{%a}" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") + Format.pp_print_int) + (Set.Int.elements d) + +(* -------------------------------------------------------------------- *) +let pp_bitintv (fmt : Format.formatter) (d : (int * int) list) = + Format.fprintf fmt "%a" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") + (fun fmt (i, j) -> Format.fprintf fmt "[%d..%d](%d)" i j (j - i + 1))) + d + +(* -------------------------------------------------------------------- *) +let bitintv_of_bitset (d : Set.Int.t) = + let aout = ref [] in + let current = ref None in + + d + |> Set.Int.iter (fun i -> + match !current with + | None -> current := Some (i, i) + | Some (v1, v2) -> + if i = v2 + 1 then current := Some (v1, i) + else ( + aout := (v1, v2) :: !aout; + current := Some (i, i))); + + Option.may (fun (v1, v2) -> aout := (v1, v2) :: !aout) !current; + + List.rev !aout + +(* -------------------------------------------------------------------- *) +let pp_dep1 (fmt : Format.formatter) (d : dep1) = + IdentMap.iter + (fun x bits -> + Format.fprintf fmt "%s.%d -> %a@\n" (Ident.name x) (Ident.id x) pp_bitintv (bitintv_of_bitset bits)) + d + +(* -------------------------------------------------------------------- *) +let pp_deps (fmt : Format.formatter) (d : deps) = + let display (v1, v2, d) = + Format.fprintf fmt "[%d..%d](%d) -> @[@\n%a@]@\n" v1 v2 + (v2 - v1 + 1) + pp_dep1 d + in + + let current = ref None in + + Map.Int.iter + (fun i d -> + match !current with + | None -> current := Some (i, i, d) + | Some (v1, v2, d') -> + if i = v2 + 1 && eq_dep1 d d' then current := Some (v1, i, d') + else ( + display (v1, v2, d'); + current := Some (i, i, d))) + d; + + Option.may display !current diff --git a/libs/lospecs/deps.mli b/libs/lospecs/deps.mli new file mode 100644 index 0000000000..7bdad64d48 --- /dev/null +++ b/libs/lospecs/deps.mli @@ -0,0 +1,35 @@ +open Ast + +(* -------------------------------------------------------------------- *) +type symbol = string +type dep1 = Set.Int.t IdentMap.t +type deps = dep1 Map.Int.t + +(* -------------------------------------------------------------------- *) +val empty : size:int -> deps +val enlarge : min:int -> max:int -> deps -> deps +val clearout : min:int -> max:int -> deps -> deps +val restrict : min:int -> max:int -> deps -> deps +val recast : min:int -> max:int -> deps -> deps +val merge1 : dep1 -> dep1 -> dep1 +val merge : deps -> deps -> deps +val merge1_all : dep1 Enum.t -> dep1 +val merge_all : deps Enum.t -> deps +val copy : offset:int -> size:int -> ident -> deps +val chunk : csize:int -> count:int -> deps -> deps +val perm : csize:int -> perm:int list -> deps -> deps +val collapse : csize:int -> count:int -> deps -> deps +val merge_all_deps : deps -> dep1 +val constant : size:int -> dep1 -> deps +val offset : offset:int -> deps -> deps +val split : csize:int -> count:int -> deps -> deps Enum.t +val aggregate : csize:int -> deps Enum.t -> deps + +(* -------------------------------------------------------------------- *) +type 'a pp = Format.formatter -> 'a -> unit + +val bitintv_of_bitset : Set.Int.t -> (int * int) list +val pp_bitset : Set.Int.t pp +val pp_bitintv : (int * int) list pp +val pp_dep1 : dep1 pp +val pp_deps : deps pp diff --git a/libs/lospecs/dune b/libs/lospecs/dune new file mode 100644 index 0000000000..a723995e61 --- /dev/null +++ b/libs/lospecs/dune @@ -0,0 +1,15 @@ +(library + (name lospecs) + (public_name easycrypt.lospecs) + (flags + (:standard -open Batteries)) + (preprocess + (pps ppx_deriving_yojson)) + (libraries batteries bitwuzla menhirLib zarith)) + +(ocamllex lexer) + +(menhir + (modules parser) + (explain true) + (flags --table)) diff --git a/libs/lospecs/hlaig.ml b/libs/lospecs/hlaig.ml new file mode 100644 index 0000000000..bdec654433 --- /dev/null +++ b/libs/lospecs/hlaig.ml @@ -0,0 +1,510 @@ +type node = Aig.node +type reg = Aig.reg + +module Hashtbl = Batteries.Hashtbl + +module type SMTInstance = sig + type bvterm + + exception SMTError + + (* Expected params: sort, value *) + val bvterm_of_int : int -> int -> bvterm + + (* Expected params: sort, name *) + val bvterm_of_name : int -> string -> bvterm + + (* argument must be of size 1, assert it true *) + (* Should affect internal state of SMT *) + val assert' : bvterm -> unit + + (* Check satisfiability of current asserts *) + val check_sat : unit -> bool + + (* equality over bitvectors, res is a size 1 bitvector *) + val bvterm_equal : bvterm -> bvterm -> bvterm + + (* bvterm concat, res sort is sum of sorts *) + val bvterm_concat : bvterm -> bvterm -> bvterm + + (* bvand *) + val lognot : bvterm -> bvterm + + (* bvnot *) + val logand : bvterm -> bvterm -> bvterm + + val get_value : bvterm -> bvterm + + val pp_term : Format.formatter -> bvterm -> unit +end + +module type SMTInterface = sig + val circ_equiv : ?inps:(int * int) list -> reg -> reg -> node -> bool + + val circ_sat : ?inps:(int * int) list -> node -> bool + + val circ_taut : ?inps:(int * int) list -> node -> bool +end + +(* TODO Add model printing for circ_sat and circ_taut *) +(* Assumes circuit inputs have already been appropriately renamed *) +module MakeSMTInterface(SMT: SMTInstance) : SMTInterface = struct + let circ_equiv ?(inps: (int * int) list option) (r1 : Aig.reg) (r2 : Aig.reg) (pcond : Aig.node) : bool = + if not ((Array.length r1 > 0) && (Array.length r2 > 0)) then + (Format.eprintf "Sizes differ in circ_equiv"; false) + else + let bvvars : SMT.bvterm Map.String.t ref = ref Map.String.empty in + + let rec bvterm_of_node : Aig.node -> SMT.bvterm = + let cache = Hashtbl.create 0 in + + let rec doit (n : Aig.node) = + let mn = + match Hashtbl.find_option cache (Int.abs n.id) with + | None -> + let mn = doit_r n.gate in + Hashtbl.add cache (Int.abs n.id) mn; + mn + | Some mn -> + mn + in + if 0 < n.id then mn else SMT.lognot mn + + and doit_r (n : Aig.node_r) = + match n with + | False -> SMT.bvterm_of_int 1 0 + | Input v -> let name = ("BV_" ^ (fst v |> string_of_int) ^ "_" ^ (Printf.sprintf "%X" (snd v))) in + begin + match Map.String.find_opt name !bvvars with + | None -> + bvvars := Map.String.add name (SMT.bvterm_of_name 1 name) !bvvars; + Map.String.find name !bvvars + | Some t -> t + end + | And (n1, n2) -> SMT.logand (doit n1) (doit n2) + + in fun n -> doit n + in + + let bvterm_of_reg (r: Aig.reg) : _ = + Array.map bvterm_of_node r |> Array.reduce (fun acc b -> SMT.bvterm_concat b acc) + in + + let bvinpt1 = (bvterm_of_reg r1) in + let bvinpt2 = (bvterm_of_reg r2) in + let formula = SMT.bvterm_equal bvinpt1 bvinpt2 in + let pcond = (bvterm_of_node pcond) in + let inps = Option.bind inps (fun l -> + if List.is_empty l then None + else Some l + ) in + + let inps = Option.map (fun inps -> + List.map (fun (id,sz) -> + List.init sz (fun i -> ("BV_" ^ (id |> string_of_int) ^ "_" ^ (Printf.sprintf "%X" (i))))) inps + ) inps in + let inps = Option.map (fun inps -> + List.map (List.map (fun name -> match Map.String.find_opt name !bvvars with + | Some bv -> bv + | None -> SMT.bvterm_of_name 1 name)) inps) inps + in + let bvinp = Option.map (fun inps -> + List.map (fun i -> List.reduce (SMT.bvterm_concat) i) inps) inps + in + + begin + SMT.assert' @@ SMT.logand pcond (SMT.lognot formula); + if SMT.check_sat () = false then true + else begin + Format.eprintf "bvout1: %a@." SMT.pp_term (SMT.get_value bvinpt1); + Format.eprintf "bvout2: %a@." SMT.pp_term (SMT.get_value bvinpt2); + Format.eprintf "Terms in formula: "; + List.iter (Format.eprintf "%s ") (List.of_enum @@ Map.String.keys !bvvars); + Format.eprintf "@\n"; + Option.may (fun bvinp -> + List.iteri (fun i bv -> + Format.eprintf "input[%d]: %a@." i SMT.pp_term (SMT.get_value bv) + ) bvinp) bvinp; + false + end + end + + + (* TODO: better encoding of smt terms ? *) + let circ_sat ?(inps: (int * int) list option) (n : Aig.node) : bool = + let bvvars : SMT.bvterm Map.String.t ref = ref Map.String.empty in + + begin match inps with + | None -> () + | Some inps -> List.iter (fun (id, sz) -> + List.iter (fun i -> + let name = ("BV_" ^ (string_of_int id) ^ "_" ^ (Printf.sprintf "%05X" i)) in + bvvars := Map.String.add name (SMT.bvterm_of_name 1 name) !bvvars) + (List.init sz identity)) inps + end; + + let rec bvterm_of_node : Aig.node -> SMT.bvterm = + let cache = Hashtbl.create 0 in + + let rec doit (n : Aig.node) = + let mn = + match Hashtbl.find_option cache (Int.abs n.id) with + | None -> + let mn = doit_r n.gate in + Hashtbl.add cache (Int.abs n.id) mn; + mn + | Some mn -> + mn + in + if 0 < n.id then mn else SMT.lognot mn + + and doit_r (n : Aig.node_r) = + match n with + | False -> SMT.bvterm_of_int 1 0 + | Input v -> let name = ("BV_" ^ (fst v |> string_of_int) ^ "_" ^ (Printf.sprintf "%05X" (snd v))) in + begin + match Map.String.find_opt name !bvvars with + | None -> + bvvars := Map.String.add name (SMT.bvterm_of_name 1 name) !bvvars; + Map.String.find name !bvvars + | Some t -> t + end + | And (n1, n2) -> SMT.logand (doit n1) (doit n2) + + in fun n -> doit n + in + + let form = bvterm_of_node n in + + let inps = Option.bind inps (fun l -> + if List.is_empty l then None + else Some l + ) in + + let inps = Option.map (fun inps -> + List.map (fun (id,sz) -> + List.init sz (fun i -> ("BV_" ^ (id |> string_of_int) ^ "_" ^ (Printf.sprintf "%05X" (i))))) inps + ) inps in + let inps = Option.map (fun inps -> + List.map (List.map (fun name -> match Map.String.find_opt name !bvvars with + | Some bv -> bv + | None -> SMT.bvterm_of_name 1 name)) inps) inps + in + let bvinp = Option.map (fun inps -> + List.map (fun i -> List.reduce (SMT.bvterm_concat) i) inps) inps + in + + begin + SMT.assert' @@ form; + if SMT.check_sat () = true then + begin + Format.eprintf "Input BVVars: "; + let () = Enum.iter (Format.eprintf "%s, ") (Map.String.keys !bvvars) in + Format.eprintf "@."; + Option.may (fun bvinp -> List.iteri (fun i bv -> + Format.eprintf "input[%d]: %a@." i SMT.pp_term (SMT.get_value bv) + ) bvinp) bvinp; + true + end + else false + end + + let circ_taut ?inps (n: Aig.node) : bool = + not @@ circ_sat ?inps (Aig.neg n) + +end + + +let makeBWZinstance () : (module SMTInstance) = + let module B = Bitwuzla.Once () in + let open B in + + (module struct + type bvterm = Term.Bv.t + + exception SMTError + + let bvterm_of_int (sort: int) (v: int) : bvterm = + Term.Bv.of_int (Sort.bv sort) v + + + let bvterm_of_name (sort: int) (name: string) : bvterm = + Term.const (Sort.bv sort) name + + + let assert' (f: bvterm) : unit = + assert' f + + + let check_sat () : bool = + match check_sat () with + | Sat -> true + | Unsat -> false + | Unknown -> raise SMTError + + + let bvterm_equal (bv1: bvterm) (bv2: bvterm) : bvterm = + Term.equal bv1 bv2 + + + let bvterm_concat (bv1: bvterm) (bv2: bvterm) : bvterm = + Term.Bv.concat [|bv1; bv2|] + + + let lognot (bv: bvterm) : bvterm = + Term.Bv.lognot bv + + + let logand (bv1: bvterm) (bv2: bvterm) : bvterm = + Term.Bv.logand bv1 bv2 + + + let get_value (bv: bvterm) : bvterm = + (get_value bv :> bvterm) + + + let pp_term (fmt: Format.formatter) (bv: bvterm) : unit = + Term.pp fmt bv + + end : SMTInstance) + + +let makeBWZinterface () : (module SMTInterface) = + (module MakeSMTInterface ((val makeBWZinstance () : SMTInstance))) + + +let of_int (i:int) : reg = + (* Number of bits the integer occupies *) + let rec log2up (i: int) : int = + match i with + | 0 | 1 -> 1 + | _ -> 1 + log2up (i/2) + in + Circuit.of_int ~size:(log2up i) i + +(* ------------------------------------------------------------------------------- *) +(* FIXME: CHECK THIS *) +let rec inputs_of_node : _ -> Aig.var Set.t = + let cache : (int, Aig.var Set.t) Hashtbl.t = Hashtbl.create 0 in + + let rec doit (n : Aig.node) : Aig.var Set.t = + match Hashtbl.find_option cache (Int.abs n.id) with + | None -> + let mn = doit_r n.gate in + Hashtbl.add cache (Int.abs n.id) mn; + mn + | Some mn -> + mn + + and doit_r (n : Aig.node_r) = + match n with + | False -> Set.empty + | Input v -> Set.singleton v + | And (n1, n2) -> Set.union (doit n1) (doit n2) + + in fun n -> doit n + +(* ------------------------------------------------------------------------------- *) +let inputs_of_reg (r : Aig.reg) : Aig.var Set.t = + Array.fold_left (fun acc x -> Set.union acc (inputs_of_node x)) Set.empty r + +module Deps = struct + (* tdeps : int -> int set ; dependency for a single output bit + i |-> {j | output depends on bit j of var i }*) + type tdeps = (int, int Set.t) Map.t + (* tdblock (n, d) = merged dependencies as above for n bits + aka, the tdep represents dependencies for n bits rather than 1 + *) + type tdblock = (int * tdeps) + + + let cache : (int, tdeps) Hashtbl.t = Hashtbl.create 5003 + + let reset_state : unit -> unit = fun () -> Hashtbl.reset cache + + (* ==================================================================== *) + let rec dep : _ -> tdeps = + let cache : (int, tdeps) Hashtbl.t = Hashtbl.create 0 in + + let rec doit (n: Aig.node) : tdeps = + match Hashtbl.find_option cache (Int.abs n.id) with + | None -> let mn = doit_r n.gate in + Hashtbl.add cache (Int.abs n.id) mn; + mn + | Some mn -> + mn + + and doit_r (n: Aig.node_r) = + match n with + | False -> Map.empty + | Input (v, i) -> Map.add v (Set.add i (Set.empty)) Map.empty + | And (n1, n2) -> Map.union_stdlib (fun k s1 s2 -> Some (Set.union s1 s2)) (doit n1) (doit n2) + + in (fun n -> + let res = doit n in + Hashtbl.clear cache; + res) + + let deps (n: reg) : tdeps array = + Array.map dep n + + let block_deps (d: tdeps array) : tdblock list = + let drop_while_count (f: 'a -> bool) (l: 'a list) : int * ('a list) = + let rec doit (n: int) (l: 'a list) = + match l with + | [] -> (n, []) + | a::l' -> if f a then doit (n+1) l' else (n, l) + in + let n, tl = doit 0 l in + (n, tl) + in + let rec decompose (l: tdeps list) : tdblock list = + match l with + | [] -> [] + | h::_ -> let n, l' = + (drop_while_count (fun a -> Map.equal (Set.equal) h a) l) in + (n, h)::(decompose l') + in + decompose (Array.to_list d) + + let blocks_indep ((_,b):tdblock) ((_,d):tdblock) : bool = + let keys = Set.intersect (Set.of_enum @@ Map.keys b) (Set.of_enum @@ Map.keys d) in + let intersects = Set.map (fun k -> + let b1 = Map.find k b in + let d1 = Map.find k d in + (Set.cardinal @@ Set.intersect b1 d1) = 0 + ) keys in + Set.fold (&&) intersects true + + let block_list_indep (bs: tdblock list) : bool = + let rec doit (bs: tdblock list) (acc: tdblock list) : bool = + match bs with + | [] -> true + | b::bs -> List.for_all (blocks_indep b) acc && doit bs (b::acc) + in + doit bs [] + + let merge_deps (d1: tdeps) (d2: tdeps) : tdeps = + Map.union_stdlib (fun _ a b -> Option.some (Set.union a b)) d1 d2 + + let split_deps (n: int) (d: tdeps array) : tdblock list = + assert (Array.length d mod n = 0); + let combine (d: tdeps list) : tdeps = + List.reduce merge_deps d + in + let rec aggregate (acc: tdblock list) (d: tdeps array) : tdblock list = + match d with + | [| |] -> acc + | _ -> (aggregate ((n, combine (Array.head d n |> Array.to_list))::acc) (Array.tail d n)) + in + List.rev @@ aggregate [] d + + let check_dep_width ?(eq=false) (n: int) (d: tdeps) : bool = + Map.fold (fun s acc -> let m = (Set.cardinal s) in + if eq then + acc && (n = m) + else + acc && (m <= n) + ) d true + + (* maybe optimize this? *) + let tdblock_of_tdeps (d: tdeps list) : tdblock = + (List.length d, List.reduce merge_deps d) + + (* + Take a list of blocks and drop all but the first block if the + sizes are the same and the dependecy amounts are the same + *) + let compare_dep_size (a: tdeps) (b: tdeps) : bool = + (Map.fold (fun s acc -> acc + (Set.cardinal s)) a 0) = + (Map.fold (fun s acc -> acc + (Set.cardinal s)) b 0) + + let compare_tdblocks ((na, da): tdblock) ((nb, db): tdblock) : bool = + (na = nb) && compare_dep_size da db + + let collapse_blocks (d: tdblock list) : tdblock option = + match d with + | [] -> None + | h::t -> + List.fold_left + (fun a b -> + match a with + | None -> None + | Some a -> if compare_tdblocks a b + then Some a else None) + (Some h) t + + (* -------------------------------------------------------------------- *) + (* Uses dependency analysis to realign inputs to start at 0 *) + (* Corresponds to taking the relevant subcircuit to this output *) + (* Assumes that inputs are contiguous FIXME *) + let realign_inputs ?(renamings: (int -> int option) option) (n: node) : node * (int, int * int) Map.t = + let d = dep n in + let shifts = Map.map (fun s -> + Set.min_elt_opt s |> Option.default 0, + Set.max_elt_opt s |> Option.default 0 + ) d in + let map_ = + match renamings with + | Some renamings -> begin fun (v, i) -> + let v' = renamings v |> Option.default v in + match Map.find_opt v shifts with + | None -> None + | Some (k, _) -> Some (Aig.input (v', i-k)) + end + | None -> begin fun (v, i) -> + match Map.find_opt v shifts with + | None -> None + | Some (k, _) -> Some (Aig.input (v, i-k)) + end + in + let shifts = match renamings with + | None -> shifts + | Some renamings -> + Map.to_seq shifts |> Seq.map (fun (k, v) -> + Option.default k (renamings k), v) |> Map.of_seq + in + Aig.map map_ n, shifts + + + let pp_dep ?(namer = string_of_int) (fmt : Format.formatter) (d: tdeps) : unit = + let print_set fmt s = Set.iter (Format.fprintf fmt "%d ") s in + Map.iter (fun id ints -> Format.fprintf fmt "%s: %a@." (namer id) print_set ints) d + + let pp_deps ?(namer = string_of_int) (fmt: Format.formatter) (ds: tdeps list) : unit = + List.iteri (fun i d -> Format.fprintf fmt "Output #%d:@.%a@." i (pp_dep ~namer) d) ds + + let pp_bdep ?(start_index = 0) ?(oname="") ?(namer=string_of_int) (fmt: Format.formatter) ((n, d): tdblock) : unit = + Format.fprintf fmt "[%d-%d]%s:@." start_index (start_index+n-1) oname; + pp_dep ~namer fmt d + + let pp_bdeps ?(oname="") ?(namer=string_of_int) (fmt: Format.formatter) (bs: tdblock list) : unit = + List.fold_left (fun acc (n,d) -> (pp_bdep ~start_index:acc ~oname ~namer fmt (n,d)); acc + n) 0 bs |> ignore +end + +(* -------------------------------------------------------------------- *) +let rec pp_node ?(namer=string_of_int) (fmt : Format.formatter) (n : node) = + let pp_node = pp_node ~namer in + match n with + | { gate = False; id } when 0 < id -> + Format.fprintf fmt "⊥" + + | { gate = False; } -> + Format.fprintf fmt "⊤" + + | { gate = Input (v, i); id; } -> + let s = namer v in + Format.fprintf fmt "%s%s#%0.4x" + (if 0 < id then "" else "¬") s i + + | { gate = And (n1, n2); id; } when 0 < id -> + Format.fprintf fmt "(%a) ∧ (%a)" pp_node n1 pp_node n2 + + | { gate = And (n1, n2); } -> + Format.fprintf fmt "¬((%a) ∧ (%a))" pp_node n1 pp_node n2 + +(* -------------------------------------------------------------------- *) +let zpad (n: int) (r: Aig.reg) = + if Array.length r < n then + Array.append r (Array.init (n - (Array.length r)) (fun _ -> Aig.false_)) + else r diff --git a/libs/lospecs/io.ml b/libs/lospecs/io.ml new file mode 100644 index 0000000000..dff9a0193c --- /dev/null +++ b/libs/lospecs/io.ml @@ -0,0 +1,38 @@ +(* -------------------------------------------------------------------- *) +open Ptree + +(* -------------------------------------------------------------------- *) +let parse (name : string) (input : IO.input) : Ptree.pprogram = + let lexbuf = Lexing.from_channel input in + Lexing.set_filename lexbuf name; + Parser.program Lexer.main lexbuf + +(* -------------------------------------------------------------------- *) +let print_source_for_range (fmt : Format.formatter) (range : range) (name : string) = + let lines = File.lines_of name in + let nlines = Enum.count lines in + + let begin_ = fst range.rg_begin - 1 in + let end_ = fst range.rg_end in + + let ctxt = 2 in + let ctxt_s = max 0 (begin_ - ctxt) in + let ctxt_e = min nlines (end_ + ctxt) in + + let lines = Enum.skip ctxt_s lines in + let lines = Enum.take (ctxt_e - ctxt_s) lines in + + let sz = int_of_float (ceil (log10 (float_of_int end_ +. 1.))) in + + begin + let doline (i : int) = Format.sprintf "%d---------" i in + Format.fprintf fmt "%*s | %s@." + sz "" + (String.concat "" (List.map doline (List.init 7 identity))); + end; + Enum.iteri + (fun i line -> + let lineno = ctxt_s + i in + let mark = if begin_ <= lineno && lineno < end_ then ">" else " " in + Format.fprintf fmt "%*d %s| %s@." sz (lineno + 1) mark line) + lines diff --git a/libs/lospecs/lexer.mll b/libs/lospecs/lexer.mll new file mode 100644 index 0000000000..21346fa2ad --- /dev/null +++ b/libs/lospecs/lexer.mll @@ -0,0 +1,75 @@ +{ + open Parser + + let keywords = [ + ("fun" , FUN ); + ("let" , LET ); + ("in" , IN ); + ] + + let keywords = + let table = Hashtbl.create 0 in + List.iter (fun (x, k) -> Hashtbl.add table x k) keywords; + table +} + +let lower = ['a'-'z'] +let upper = ['A'-'Z'] +let alpha = lower | upper +let digit = ['0'-'9'] +let hexdigit = digit | ['a'-'f'] | ['A'-'F'] +let alnum = alpha | digit + +let ident = (alpha | '_') (alnum | '_')* + +let decnum = digit+ +let hexnum = "0x" hexdigit+ + +let whitespace = [' ' '\t' '\r'] + +rule main = parse + | '<' { LT } + | '>' { GT } + | '(' { LPAREN } + | ')' { RPAREN } + | '[' { LBRACKET } + | ']' { RBRACKET } + | '@' { AT } + | "<-" { LARROW } + | "->" { RARROW } + | ',' { COMMA } + | '=' { EQUAL } + | ':' { COLON } + | '.' { DOT } + | '|' { PIPE } + | '?' { QMARK } + + | ident as x + { Hashtbl.find_default keywords x (IDENT x) } + + | decnum as d + { NUMBER (int_of_string d) } + + | hexnum as d + { NUMBER (int_of_string d) } + + | whitespace+ + { main lexbuf } + + | '\n' + { Lexing.new_line lexbuf; main lexbuf } + + | '#' [^'\n']* + { main lexbuf } + +(* DEBUG FEATURE: for binary searching for syntax errors + to be switched for better error output *) + | '^' _* + { main lexbuf } + + | eof + { EOF } + + | _ { + raise (Ptree.ParseError (Ptree.Lc.of_lexbuf lexbuf)) + } diff --git a/libs/lospecs/parser.mly b/libs/lospecs/parser.mly new file mode 100644 index 0000000000..dad00ee271 --- /dev/null +++ b/libs/lospecs/parser.mly @@ -0,0 +1,148 @@ +%{ + open Ptree + + let string_of_position ((p1, p2) : Lexing.position * Lexing.position) = + Format.sprintf "%d.%d:%d.%d" + p1.pos_lnum (p1.pos_cnum - p1.pos_bol + 1) + p2.pos_lnum (p2.pos_cnum - p2.pos_bol + 1) +%} + +%token AT +%token COLON +%token COMMA +%token DOT +%token EOF +%token EQUAL +%token FUN +%token GT +%token LARROW +%token LBRACKET +%token LET +%token LPAREN +%token LT +%token IN +%token PIPE +%token QMARK +%token RARROW +%token RBRACKET +%token RPAREN + +%token IDENT +%token NUMBER + +%type program + +%start program + +%nonassoc below_TERNARY +%left QMARK +%left COLON + +%% + +%inline vname: +| x=loc(IDENT) + { x } + +%inline wname: +| x=vname t=wtype + { (x, t) } + +%inline wtype_: +| AT x=NUMBER + { `W x } + +%inline wtype: +| w=loc(wtype_) { w } + +fname_: +| f=loc(IDENT) + { (f, None) } + +| f=loc(IDENT) p=angled(list0(loc(NUMBER), COMMA)) + { (f, Some (List.map (Lc.map (fun x -> `W x)) p)) } + +%inline fname: +| f=loc(fname_) { f } + +sexpr_: +| f=fname + { PEFName f } + +| f=fname args=parens(list0(loc(earg), COMMA)) + { PEApp (f, args) } + +| e=parens(expr) + { PEParens e } + +| i=NUMBER + { PEInt (i, None) } + +| i=NUMBER w=wtype + { PEInt (i, Some w) } + +%inline sexpr: +| e=loc(sexpr_) { e } + +expr_: +| e=sexpr_ + { e } + +| FUN args=wname* DOT body=expr %prec below_TERNARY + { PEFun (args, body) } + +| LET x=loc(IDENT) args=parens(list0(wname, COMMA))? EQUAL e1=expr IN e2=expr %prec below_TERNARY + { PELet ((x, args, e1), e2) } + +| e=sexpr LBRACKET + s=ioption(AT s=expr PIPE { s }) i=expr j=prefix(COLON, expr)? + RBRACKET + { PESlice (e, (i, j, s)) } + +| e=sexpr LBRACKET + s=ioption(AT s=expr PIPE { s }) i=expr j=prefix(COLON, expr)? + LARROW r=expr + RBRACKET + { PEAssign (e, (i, j, s), r) } + +| c=expr QMARK e1=expr COLON e2=expr + { PECond (c, (e1, e2)) } + +%inline expr: +| e=loc(expr_) { e } + +earg: +| DOT + { None } + +| e=expr + { Some e } + +def: +| name=IDENT args=parens(list0(wname, COMMA)) RARROW rty=wtype EQUAL body=expr + { { name; args; rty; body; } } + +program: +| defs=def* EOF + { defs } + +| error + { raise (ParseError (Lc.of_positions (fst $loc) (snd $loc))) } + +%inline parens(X): +| LPAREN x=X RPAREN { x } + +%inline angled(X): +| LT x=X GT { x } + +%inline list0(X, S): +| x=separated_list(S, X) { x } + +%inline prefix(S, X): +| S x=X { x } + +%inline loc(X): +| data=X { + let range = Lc.of_positions $startpos $endpos in + { range; data; } + } diff --git a/libs/lospecs/ptree.ml b/libs/lospecs/ptree.ml new file mode 100644 index 0000000000..6c1da94ce7 --- /dev/null +++ b/libs/lospecs/ptree.ml @@ -0,0 +1,107 @@ +(* -------------------------------------------------------------------- *) +open Lexing + +(* -------------------------------------------------------------------- *) +type range = { + rg_fname : string; + rg_begin : int * int; + rg_end : int * int; +} [@@deriving yojson] + +type 'a loced = { range : range; data : 'a; } [@@deriving yojson] + +(* -------------------------------------------------------------------- *) +module Lc = struct + let of_positions (p1 : position) (p2 : position) : range = + assert (p1.pos_fname = p2.pos_fname); + + let mk_range (p : position) = + (p.pos_lnum, p.pos_cnum - p.pos_bol) in + + { rg_fname = p1.pos_fname; rg_begin = mk_range p1; rg_end = mk_range p2; } + + let of_lexbuf (lx : Lexing.lexbuf) : range = + let p1 = Lexing.lexeme_start_p lx in + let p2 = Lexing.lexeme_end_p lx in + of_positions p1 p2 + + let merge (p1 : range) (p2 : range) = + assert (p1.rg_fname = p2.rg_fname); + { rg_fname = p1.rg_fname; + rg_begin = min p1.rg_begin p2.rg_begin; + rg_end = max p1.rg_end p2.rg_end; } + + (* Dead code? FIXME PR *) + let mergeall (p : range list) = + match p with + | [] -> assert false + | t :: ts -> List.fold_left merge t ts + + let unloc (x : 'a loced) : 'a = + x.data + + let range (x : 'a loced) : range = + x.range + + let mk (range : range) (data : 'a) : 'a loced = + { range; data; } + + let map (f : 'a -> 'b) (x : 'a loced) : 'b loced = + { x with data = f x.data } + + let string_of_range (range : range) = + let spos = + if range.rg_begin = range.rg_end then + Printf.sprintf "line %d (%d)" + (fst range.rg_begin) (snd range.rg_begin + 1) + else if fst range.rg_begin = fst range.rg_end then + Printf.sprintf "line %d (%d-%d)" + (fst range.rg_begin) (snd range.rg_begin + 1) (snd range.rg_end + 1) + else + Printf.sprintf "line %d (%d) to line %d (%d)" + (fst range.rg_begin) (snd range.rg_begin + 1) + (fst range.rg_end ) (snd range.rg_end + 1) + in + Printf.sprintf "%s: %s" range.rg_fname spos + + let pp_range (fmt : Format.formatter) (range : range) = + Format.fprintf fmt "%s" (string_of_range range) +end + +(* -------------------------------------------------------------------- *) +exception ParseError of range + +(* -------------------------------------------------------------------- *) +type symbol = string [@@deriving yojson] +type word = [ `W of int ] [@@deriving yojson] +type type_ = [ `Unsigned | `Signed | word ] [@@deriving yojson] + +(* -------------------------------------------------------------------- *) +type psymbol = symbol loced [@@deriving yojson] +type pword = word loced [@@deriving yojson] +type ptype = type_ loced [@@deriving yojson] +type parg = psymbol * pword [@@deriving yojson] +type pargs = parg list [@@deriving yojson] +type pfname = (psymbol * pword list option) loced [@@deriving yojson] + +(* -------------------------------------------------------------------- *) +type pexpr_ = + | PEParens of pexpr + | PEFName of pfname + | PEInt of int * pword option + | PECond of pexpr * (pexpr * pexpr) + | PEFun of pargs * pexpr + | PELet of (psymbol * pargs option * pexpr) * pexpr + | PESlice of pexpr * pslice + | PEAssign of pexpr * pslice * pexpr + | PEApp of pfname * pexpr option loced list +[@@deriving yojson] + +and pexpr = pexpr_ loced [@@deriving yojson] + +and pslice = (pexpr * pexpr option * pexpr option) [@@deriving yojson] + +type pdef = { name : symbol; args : pargs; rty : pword; body : pexpr } +[@@deriving yojson] + +type pprogram = pdef list [@@deriving yojson] diff --git a/libs/lospecs/tests/avx2.ml b/libs/lospecs/tests/avx2.ml new file mode 100644 index 0000000000..d17d7f26e4 --- /dev/null +++ b/libs/lospecs/tests/avx2.ml @@ -0,0 +1,259 @@ +(* -------------------------------------------------------------- *) +type 'a pair = 'a * 'a +type 'a quad = 'a * 'a * 'a * 'a + +(* -------------------------------------------------------------- *) +type m64x2 = int64 pair +type m64x4 = int64 quad +type m32x4 = int32 pair pair +type m32x8 = int32 pair quad +type m16x8 = int pair pair pair +type m16x16 = int pair pair quad +type m8x16 = char pair pair pair pair +type m8x32 = char pair pair pair quad + +(* -------------------------------------------------------------- *) +type m128 = m64x2 +type m256 = m64x4 + +(* -------------------------------------------------------------- *) +type endianess = [`Little | `Big] + +(* -------------------------------------------------------------- *) +type size = [`U8 | `U16 | `U32 | `U64] + +let width_of_size (s : size) : int = + match s with + | `U8 -> 8 + | `U16 -> 16 + | `U32 -> 32 + | `U64 -> 64 + +(* -------------------------------------------------------------- *) +let pp_bytes + ~(size : size) + (fmt : Format.formatter) + (v : bytes) += + let w = width_of_size size / 8 in + + v |> Bytes.iteri (fun i b -> + if i <> 0 && i mod w = 0 then + Format.fprintf fmt "_"; + Format.fprintf fmt "%02x" (Char.code b) + ) + +(* -------------------------------------------------------------- *) +let map_quad (type a) (type b) + (f : a -> b) + ((x0, x1, x2, x3) : a quad) += + (f x0, f x1, f x2, f x3) + +(* -------------------------------------------------------------- *) +let map_pair (type a) (type b) (f : a -> b) ((x, y) : a pair) = + (f x, f y) + +(* -------------------------------------------------------------- *) +external m64_to_32x2 : int64 -> int32 pair = "m64_to_32x2" +external m32_to_16x2 : int32 -> int pair = "m32_to_16x2" +external m16_to_8x2 : int -> char pair = "m16_to_8x2" + +(* -------------------------------------------------------------- *) +external m64_of_32x2 : int32 pair -> int64 = "m64_of_32x2" +external m32_of_16x2 : int pair -> int32 = "m32_of_16x2" +external m16_of_8x2 : char pair -> int = "m16_of_8x2" + +(* -------------------------------------------------------------- *) +module M256 = struct + (* ------------------------------------------------------------ *) + external oftuple_64 : m64x4 -> m256 = "%identity" + external totuple_64 : m256 -> m64x4 = "%identity" + + (* ------------------------------------------------------------ *) + let to_bytes ~(endianess : endianess) (v : m256) : bytes = + let w0, w1, w2, w3 = totuple_64 v in + let b = Buffer.create 32 in + + let () = + match endianess with + | `Little -> + Buffer.add_int64_le b w0; + Buffer.add_int64_le b w1; + Buffer.add_int64_le b w2; + Buffer.add_int64_le b w3; + + | `Big -> + Buffer.add_int64_be b w3; + Buffer.add_int64_be b w2; + Buffer.add_int64_be b w1; + Buffer.add_int64_be b w0 + + in Buffer.to_bytes b + + (* ------------------------------------------------------------ *) + let of_bytes ~(endianess : endianess) (v : bytes) : m256 = + assert (Bytes.length v = 32); + + let w0, w1, w2, w3 = + match endianess with + | `Big -> ( + Bytes.get_int64_be v 24, + Bytes.get_int64_be v 16, + Bytes.get_int64_be v 8, + Bytes.get_int64_be v 0 + ) + | `Little -> ( + Bytes.get_int64_le v 0, + Bytes.get_int64_le v 8, + Bytes.get_int64_le v 16, + Bytes.get_int64_le v 24 + ) + + in oftuple_64 (w0, w1, w2, w3) + + (* ------------------------------------------------------------ *) + let pp + ~(size : size) + ~(endianess : endianess) + (fmt : Format.formatter) + (v : m256) + = + Format.fprintf fmt "%a" (pp_bytes ~size) (to_bytes ~endianess v) + + (* ------------------------------------------------------------ *) + let oftuple_32 (v : m32x8) : m256 = + oftuple_64 (map_quad m64_of_32x2 v) + + let totuple_32 (v : m256) : m32x8 = + map_quad m64_to_32x2 (totuple_64 v) + + (* ------------------------------------------------------------ *) + let oftuple_16 (v : m16x16) : m256 = + oftuple_32 (map_quad (map_pair m32_of_16x2) v) + + let totuple_16 (v : m256) : m16x16 = + map_quad (map_pair m32_to_16x2) (totuple_32 v) + + (* ------------------------------------------------------------ *) + let oftuple_8 (v : m8x32) : m256 = + oftuple_16 (map_quad (map_pair (map_pair m16_of_8x2)) v) + + let totuple_8 (v : m256) : m8x32 = + map_quad (map_pair (map_pair m16_to_8x2)) (totuple_16 v) + + (* ------------------------------------------------------------ *) + let random () : m256 = + let w0 = Random.bits64() in + let w1 = Random.bits64() in + let w2 = Random.bits64() in + let w3 = Random.bits64() in + oftuple_64 (w0, w1, w2, w3) +end + +(* -------------------------------------------------------------- *) +module M128 = struct + (* ------------------------------------------------------------ *) + external oftuple_64 : m64x2 -> m128 = "%identity" + external totuple_64 : m128 -> m64x2 = "%identity" + + (* ------------------------------------------------------------ *) + let to_bytes ~(endianess : endianess) (v : m128) : bytes = + let w0, w1 = totuple_64 v in + let b = Buffer.create 32 in + + let () = + match endianess with + | `Little -> + Buffer.add_int64_le b w0; + Buffer.add_int64_le b w1 + + | `Big -> + Buffer.add_int64_be b w1; + Buffer.add_int64_be b w0 + + in Buffer.to_bytes b + + (* ------------------------------------------------------------ *) + let of_bytes ~(endianess : endianess) (v : bytes) : m128 = + assert (Bytes.length v = 16); + + let w0, w1 = + match endianess with + | `Big -> ( + Bytes.get_int64_be v 8, + Bytes.get_int64_be v 0 + ) + | `Little -> ( + Bytes.get_int64_le v 0, + Bytes.get_int64_le v 8 + ) + + in oftuple_64 (w0, w1) + + (* ------------------------------------------------------------ *) + let pp + ~(size : size) + ~(endianess : endianess) + (fmt : Format.formatter) + (v : m128) + = + Format.fprintf fmt "%a" (pp_bytes ~size) (to_bytes ~endianess v) + + (* ------------------------------------------------------------ *) + let oftuple_32 (v : m32x4) : m128 = + oftuple_64 (map_pair m64_of_32x2 v) + + let totuple_32 (v : m128) : m32x4 = + map_pair m64_to_32x2 (totuple_64 v) + + (* ------------------------------------------------------------ *) + let oftuple_16 (v : m16x8) : m128 = + oftuple_32 (map_pair (map_pair m32_of_16x2) v) + + let totuple_16 (v : m128) : m16x8 = + map_pair (map_pair m32_to_16x2) (totuple_32 v) + + (* ------------------------------------------------------------ *) + let oftuple_8 (v : m8x16) : m128 = + oftuple_16 (map_pair (map_pair (map_pair m16_of_8x2)) v) + + let totuple_8 (v : m128) : m8x16 = + map_pair (map_pair (map_pair m16_to_8x2)) (totuple_16 v) + + (* ------------------------------------------------------------ *) + let random () : m128 = + let w0 = Random.bits64() in + let w1 = Random.bits64() in + oftuple_64 (w0, w1) +end + +(* -------------------------------------------------------------- *) +external mm256_and_si256 : m256 -> m256 -> m256 = "caml_simde_mm256_and_si256" +external mm256_andnot_si256 : m256 -> m256 -> m256 = "caml_simde_mm256_andnot_si256" +external mm256_add_epi8 : m256 -> m256 -> m256 = "caml_simde_mm256_add_epi8" +external mm256_add_epi16 : m256 -> m256 -> m256 = "caml_simde_mm256_add_epi16" +external mm256_sub_epi8 : m256 -> m256 -> m256 = "caml_simde_mm256_sub_epi8" +external mm256_sub_epi16 : m256 -> m256 -> m256 = "caml_simde_mm256_sub_epi16" +external mm256_mulhi_epi16 : m256 -> m256 -> m256 = "caml_simde_mm256_mulhi_epi16" +external mm256_mulhi_epu16 : m256 -> m256 -> m256 = "caml_simde_mm256_mulhi_epu16" +external mm256_mulhrs_epi16 : m256 -> m256 -> m256 = "caml_simde_mm256_mulhrs_epi16" +external mm256_packus_epi16 : m256 -> m256 -> m256 = "caml_simde_mm256_packus_epi16" +external mm256_packs_epi16 : m256 -> m256 -> m256 = "caml_simde_mm256_packs_epi16" +external mm256_maddubs_epi16 : m256 -> m256 -> m256 = "caml_simde_mm256_maddubs_epi16" +external mm256_permutevar8x32_epi32 : m256 -> m256 -> m256 = "caml_simde_mm256_permutevar8x32_epi32" +external mm256_permute4x64_epi64 : m256 -> int -> m256 = "caml_simde_mm256_permute4x64_epi64_dyn" +external mm256_permute2x128_si256 : m256 -> m256 -> int -> m256 = "caml_simde_mm256_permute2x128_si256_dyn" +external mm256_shuffle_epi8 : m256 -> m256 -> m256 = "caml_simde_mm256_shuffle_epi8" +external mm256_srai_epi16 : m256 -> int -> m256 = "caml_simde_mm256_srai_epi16" +external mm256_srli_epi16 : m256 -> int -> m256 = "caml_simde_mm256_srli_epi16" +external mm256_cmpgt_epi16 : m256 -> m256 -> m256 = "caml_simde_mm256_cmpgt_epi16" +external mm256_movemask_epi8 : m256 -> int32 = "caml_simde_mm256_movemask_epi8" +external mm256_unpacklo_epi8 : m256 -> m256 -> m256 = "caml_simde_mm256_unpacklo_epi8" +external mm256_unpacklo_epi64 : m256 -> m256 -> m256 = "caml_simde_mm256_unpacklo_epi64" +external mm256_unpackhi_epi64 : m256 -> m256 -> m256 = "caml_simde_mm256_unpackhi_epi64" +external mm256_blend_epi16 : m256 -> m256 -> int -> m256 = "caml_simde_mm256_blend_epi16_dyn" +external mm256_blend_epi32 : m256 -> m256 -> int -> m256 = "caml_simde_mm256_blend_epi32_dyn" +external mm256_moveldup_ps : m256 -> m256 = "caml_simde_mm256_moveldup_ps_dyn" +external mm256_inserti128_si256 : m256 -> m128 -> int -> m256 = "caml_simde_mm256_inserti128_si256_dyn" +external mm256_extracti128_si256 : m256 -> int -> m128 = "caml_simde_mm256_extracti128_si256_dyn" diff --git a/libs/lospecs/tests/avx2_runtime.cpp b/libs/lospecs/tests/avx2_runtime.cpp new file mode 100644 index 0000000000..0cbd3c7979 --- /dev/null +++ b/libs/lospecs/tests/avx2_runtime.cpp @@ -0,0 +1,534 @@ +/* ==================================================================== */ +#include +#include "avx2_runtime.h" + +/* -------------------------------------------------------------------- */ +#include + +/* -------------------------------------------------------------------- */ +#include +#include +#include +#include +#include + +/* ==================================================================== */ +extern "C" CAMLprim value m64_of_32x2(value lohi) { + CAMLparam1(lohi); + + const uint32_t lo = (uint32_t) Int32_val(Field(lohi, 0)); + const uint32_t hi = (uint32_t) Int32_val(Field(lohi, 1)); + + const uint64_t out = ((uint64_t) lo) | (((uint64_t) hi) << 32); + + CAMLreturn(caml_copy_int64((int64_t) out)); +} + +/* -------------------------------------------------------------------- */ +extern "C" CAMLprim value m64_to_32x2(value lohi) { + CAMLparam1(lohi); + CAMLlocal1(out); + + const uint64_t v = (uint64_t) Int64_val(lohi); + + const uint32_t lo = (v >> 0) & 0xffffffff; + const uint32_t hi = (v >> 32) & 0xffffffff; + + out = caml_alloc_tuple(2); + Field(out, 0) = caml_copy_int32(lo); + Field(out, 1) = caml_copy_int32(hi); + + CAMLreturn(out); +} + +/* -------------------------------------------------------------------- */ +extern "C" CAMLprim value m32_of_16x2(value lohi) { + CAMLparam1(lohi); + + const uint16_t lo = (uint16_t) Int_val(Field(lohi, 0)); + const uint16_t hi = (uint16_t) Int_val(Field(lohi, 1)); + + const uint32_t out = ((uint32_t) lo) | (((uint32_t) hi) << 16); + + CAMLreturn(caml_copy_int32((int32_t) out)); +} + +/* -------------------------------------------------------------------- */ +extern "C" CAMLprim value m32_to_16x2(value lohi) { + CAMLparam1(lohi); + CAMLlocal1(out); + + const uint32_t v = (uint32_t) Int32_val(lohi); + + const uint16_t lo = (v >> 0) & 0xffff; + const uint16_t hi = (v >> 16) & 0xffff; + + out = caml_alloc_tuple(2); + Field(out, 0) = Val_int(lo); + Field(out, 1) = Val_int(hi); + + CAMLreturn(out); +} + +/* -------------------------------------------------------------------- */ +extern "C" CAMLprim value m16_of_8x2(value lohi) { + CAMLparam1(lohi); + + const uint8_t lo = (uint8_t) Int_val(Field(lohi, 0)); + const uint8_t hi = (uint8_t) Int_val(Field(lohi, 1)); + + const uint16_t out = ((uint16_t) lo) | (((uint16_t) hi) << 8); + + CAMLreturn(Val_int(out)); +} + +/* -------------------------------------------------------------------- */ + extern "C" CAMLprim value m16_to_8x2(value lohi) { + CAMLparam1(lohi); + CAMLlocal1(out); + + const uint16_t v = (uint16_t) Int_val(lohi); + + const uint8_t lo = (v >> 0) & 0xff; + const uint8_t hi = (v >> 8) & 0xff; + + out = caml_alloc_tuple(2); + Field(out, 0) = Val_int(lo); + Field(out, 1) = Val_int(hi); + + CAMLreturn(out); +} + +/* ==================================================================== */ +#if defined(HAS_AVX) +/* -------------------------------------------------------------------- */ +value value_of_w256(simde__m256i x) { + CAMLparam0(); + CAMLlocal1(out); + + out = caml_alloc_tuple(4); + Store_field(out, 0, caml_copy_int64(simde_mm256_extract_epi64(x, 0))); + Store_field(out, 1, caml_copy_int64(simde_mm256_extract_epi64(x, 1))); + Store_field(out, 2, caml_copy_int64(simde_mm256_extract_epi64(x, 2))); + Store_field(out, 3, caml_copy_int64(simde_mm256_extract_epi64(x, 3))); + + CAMLreturn(out); +} + +/* -------------------------------------------------------------------- */ +simde__m256i w256_of_value(value x) { + CAMLparam1(x); + + simde__m256i out = simde_mm256_set_epi64x( + Int64_val(Field(x, 3)), + Int64_val(Field(x, 2)), + Int64_val(Field(x, 1)), + Int64_val(Field(x, 0)) + ); + + CAMLreturnT(simde__m256i, out); +} + +/* -------------------------------------------------------------------- */ +value value_of_w128(simde__m128i x) { + CAMLparam0(); + CAMLlocal1(out); + + out = caml_alloc_tuple(2); + Store_field(out, 0, caml_copy_int64(simde_mm_extract_epi64(x, 0))); + Store_field(out, 1, caml_copy_int64(simde_mm_extract_epi64(x, 1))); + + CAMLreturn(out); +} + +/* -------------------------------------------------------------------- */ +simde__m128i w128_of_value(value x) { + CAMLparam1(x); + + simde__m128i out = simde_mm_set_epi64x( + Int64_val(Field(x, 1)), + Int64_val(Field(x, 0)) + ); + + CAMLreturnT(simde__m128i, out); +} + +/* -------------------------------------------------------------------- */ +simde__m256i simde_mm256_inserti128_si256_dyn(simde__m256i a, simde__m128i b, const int imm8) { + switch (imm8 & 0x01) { + case 0: + return simde_mm256_inserti128_si256(a, b, 0); + case 1: + return simde_mm256_inserti128_si256(a, b, 1); + } + abort(); +} + +/* -------------------------------------------------------------------- */ +simde__m128i simde_mm256_extracti128_si256_dyn(simde__m256i a, const int imm8) { + switch (imm8 & 0x01) { + case 0: + return simde_mm256_extracti128_si256(a, 0); + case 1: + return simde_mm256_extracti128_si256(a, 1); + } + abort(); +} + +/* -------------------------------------------------------------------- */ +simde__m256i simde_mm256_blend_epi16_dyn(simde__m256i a, simde__m256i b, const int imm8) { +#define CASE(I) case I: return simde_mm256_blend_epi16(a, b, I) + + /* + * for i in range(0, 256, 4): + * print('; '.join(f'CASE(0x{j:02x})' for j in range(i, i+4)) + ';') + */ + switch (imm8 & 0xff) { + CASE(0x00); CASE(0x01); CASE(0x02); CASE(0x03); + CASE(0x04); CASE(0x05); CASE(0x06); CASE(0x07); + CASE(0x08); CASE(0x09); CASE(0x0a); CASE(0x0b); + CASE(0x0c); CASE(0x0d); CASE(0x0e); CASE(0x0f); + CASE(0x10); CASE(0x11); CASE(0x12); CASE(0x13); + CASE(0x14); CASE(0x15); CASE(0x16); CASE(0x17); + CASE(0x18); CASE(0x19); CASE(0x1a); CASE(0x1b); + CASE(0x1c); CASE(0x1d); CASE(0x1e); CASE(0x1f); + CASE(0x20); CASE(0x21); CASE(0x22); CASE(0x23); + CASE(0x24); CASE(0x25); CASE(0x26); CASE(0x27); + CASE(0x28); CASE(0x29); CASE(0x2a); CASE(0x2b); + CASE(0x2c); CASE(0x2d); CASE(0x2e); CASE(0x2f); + CASE(0x30); CASE(0x31); CASE(0x32); CASE(0x33); + CASE(0x34); CASE(0x35); CASE(0x36); CASE(0x37); + CASE(0x38); CASE(0x39); CASE(0x3a); CASE(0x3b); + CASE(0x3c); CASE(0x3d); CASE(0x3e); CASE(0x3f); + CASE(0x40); CASE(0x41); CASE(0x42); CASE(0x43); + CASE(0x44); CASE(0x45); CASE(0x46); CASE(0x47); + CASE(0x48); CASE(0x49); CASE(0x4a); CASE(0x4b); + CASE(0x4c); CASE(0x4d); CASE(0x4e); CASE(0x4f); + CASE(0x50); CASE(0x51); CASE(0x52); CASE(0x53); + CASE(0x54); CASE(0x55); CASE(0x56); CASE(0x57); + CASE(0x58); CASE(0x59); CASE(0x5a); CASE(0x5b); + CASE(0x5c); CASE(0x5d); CASE(0x5e); CASE(0x5f); + CASE(0x60); CASE(0x61); CASE(0x62); CASE(0x63); + CASE(0x64); CASE(0x65); CASE(0x66); CASE(0x67); + CASE(0x68); CASE(0x69); CASE(0x6a); CASE(0x6b); + CASE(0x6c); CASE(0x6d); CASE(0x6e); CASE(0x6f); + CASE(0x70); CASE(0x71); CASE(0x72); CASE(0x73); + CASE(0x74); CASE(0x75); CASE(0x76); CASE(0x77); + CASE(0x78); CASE(0x79); CASE(0x7a); CASE(0x7b); + CASE(0x7c); CASE(0x7d); CASE(0x7e); CASE(0x7f); + CASE(0x80); CASE(0x81); CASE(0x82); CASE(0x83); + CASE(0x84); CASE(0x85); CASE(0x86); CASE(0x87); + CASE(0x88); CASE(0x89); CASE(0x8a); CASE(0x8b); + CASE(0x8c); CASE(0x8d); CASE(0x8e); CASE(0x8f); + CASE(0x90); CASE(0x91); CASE(0x92); CASE(0x93); + CASE(0x94); CASE(0x95); CASE(0x96); CASE(0x97); + CASE(0x98); CASE(0x99); CASE(0x9a); CASE(0x9b); + CASE(0x9c); CASE(0x9d); CASE(0x9e); CASE(0x9f); + CASE(0xa0); CASE(0xa1); CASE(0xa2); CASE(0xa3); + CASE(0xa4); CASE(0xa5); CASE(0xa6); CASE(0xa7); + CASE(0xa8); CASE(0xa9); CASE(0xaa); CASE(0xab); + CASE(0xac); CASE(0xad); CASE(0xae); CASE(0xaf); + CASE(0xb0); CASE(0xb1); CASE(0xb2); CASE(0xb3); + CASE(0xb4); CASE(0xb5); CASE(0xb6); CASE(0xb7); + CASE(0xb8); CASE(0xb9); CASE(0xba); CASE(0xbb); + CASE(0xbc); CASE(0xbd); CASE(0xbe); CASE(0xbf); + CASE(0xc0); CASE(0xc1); CASE(0xc2); CASE(0xc3); + CASE(0xc4); CASE(0xc5); CASE(0xc6); CASE(0xc7); + CASE(0xc8); CASE(0xc9); CASE(0xca); CASE(0xcb); + CASE(0xcc); CASE(0xcd); CASE(0xce); CASE(0xcf); + CASE(0xd0); CASE(0xd1); CASE(0xd2); CASE(0xd3); + CASE(0xd4); CASE(0xd5); CASE(0xd6); CASE(0xd7); + CASE(0xd8); CASE(0xd9); CASE(0xda); CASE(0xdb); + CASE(0xdc); CASE(0xdd); CASE(0xde); CASE(0xdf); + CASE(0xe0); CASE(0xe1); CASE(0xe2); CASE(0xe3); + CASE(0xe4); CASE(0xe5); CASE(0xe6); CASE(0xe7); + CASE(0xe8); CASE(0xe9); CASE(0xea); CASE(0xeb); + CASE(0xec); CASE(0xed); CASE(0xee); CASE(0xef); + CASE(0xf0); CASE(0xf1); CASE(0xf2); CASE(0xf3); + CASE(0xf4); CASE(0xf5); CASE(0xf6); CASE(0xf7); + CASE(0xf8); CASE(0xf9); CASE(0xfa); CASE(0xfb); + CASE(0xfc); CASE(0xfd); CASE(0xfe); CASE(0xff); + } + abort(); +#undef CASE +} + +/* -------------------------------------------------------------------- */ +simde__m256i simde_mm256_blend_epi32_dyn(simde__m256i a, simde__m256i b, const int imm8) { +#define CASE(I) case I: return simde_mm256_blend_epi32(a, b, I) + + /* + * for i in range(0, 256, 4): + * print('; '.join(f'CASE(0x{j:02x})' for j in range(i, i+4)) + ';') + */ + switch (imm8 & 0xff) { + CASE(0x00); CASE(0x01); CASE(0x02); CASE(0x03); + CASE(0x04); CASE(0x05); CASE(0x06); CASE(0x07); + CASE(0x08); CASE(0x09); CASE(0x0a); CASE(0x0b); + CASE(0x0c); CASE(0x0d); CASE(0x0e); CASE(0x0f); + CASE(0x10); CASE(0x11); CASE(0x12); CASE(0x13); + CASE(0x14); CASE(0x15); CASE(0x16); CASE(0x17); + CASE(0x18); CASE(0x19); CASE(0x1a); CASE(0x1b); + CASE(0x1c); CASE(0x1d); CASE(0x1e); CASE(0x1f); + CASE(0x20); CASE(0x21); CASE(0x22); CASE(0x23); + CASE(0x24); CASE(0x25); CASE(0x26); CASE(0x27); + CASE(0x28); CASE(0x29); CASE(0x2a); CASE(0x2b); + CASE(0x2c); CASE(0x2d); CASE(0x2e); CASE(0x2f); + CASE(0x30); CASE(0x31); CASE(0x32); CASE(0x33); + CASE(0x34); CASE(0x35); CASE(0x36); CASE(0x37); + CASE(0x38); CASE(0x39); CASE(0x3a); CASE(0x3b); + CASE(0x3c); CASE(0x3d); CASE(0x3e); CASE(0x3f); + CASE(0x40); CASE(0x41); CASE(0x42); CASE(0x43); + CASE(0x44); CASE(0x45); CASE(0x46); CASE(0x47); + CASE(0x48); CASE(0x49); CASE(0x4a); CASE(0x4b); + CASE(0x4c); CASE(0x4d); CASE(0x4e); CASE(0x4f); + CASE(0x50); CASE(0x51); CASE(0x52); CASE(0x53); + CASE(0x54); CASE(0x55); CASE(0x56); CASE(0x57); + CASE(0x58); CASE(0x59); CASE(0x5a); CASE(0x5b); + CASE(0x5c); CASE(0x5d); CASE(0x5e); CASE(0x5f); + CASE(0x60); CASE(0x61); CASE(0x62); CASE(0x63); + CASE(0x64); CASE(0x65); CASE(0x66); CASE(0x67); + CASE(0x68); CASE(0x69); CASE(0x6a); CASE(0x6b); + CASE(0x6c); CASE(0x6d); CASE(0x6e); CASE(0x6f); + CASE(0x70); CASE(0x71); CASE(0x72); CASE(0x73); + CASE(0x74); CASE(0x75); CASE(0x76); CASE(0x77); + CASE(0x78); CASE(0x79); CASE(0x7a); CASE(0x7b); + CASE(0x7c); CASE(0x7d); CASE(0x7e); CASE(0x7f); + CASE(0x80); CASE(0x81); CASE(0x82); CASE(0x83); + CASE(0x84); CASE(0x85); CASE(0x86); CASE(0x87); + CASE(0x88); CASE(0x89); CASE(0x8a); CASE(0x8b); + CASE(0x8c); CASE(0x8d); CASE(0x8e); CASE(0x8f); + CASE(0x90); CASE(0x91); CASE(0x92); CASE(0x93); + CASE(0x94); CASE(0x95); CASE(0x96); CASE(0x97); + CASE(0x98); CASE(0x99); CASE(0x9a); CASE(0x9b); + CASE(0x9c); CASE(0x9d); CASE(0x9e); CASE(0x9f); + CASE(0xa0); CASE(0xa1); CASE(0xa2); CASE(0xa3); + CASE(0xa4); CASE(0xa5); CASE(0xa6); CASE(0xa7); + CASE(0xa8); CASE(0xa9); CASE(0xaa); CASE(0xab); + CASE(0xac); CASE(0xad); CASE(0xae); CASE(0xaf); + CASE(0xb0); CASE(0xb1); CASE(0xb2); CASE(0xb3); + CASE(0xb4); CASE(0xb5); CASE(0xb6); CASE(0xb7); + CASE(0xb8); CASE(0xb9); CASE(0xba); CASE(0xbb); + CASE(0xbc); CASE(0xbd); CASE(0xbe); CASE(0xbf); + CASE(0xc0); CASE(0xc1); CASE(0xc2); CASE(0xc3); + CASE(0xc4); CASE(0xc5); CASE(0xc6); CASE(0xc7); + CASE(0xc8); CASE(0xc9); CASE(0xca); CASE(0xcb); + CASE(0xcc); CASE(0xcd); CASE(0xce); CASE(0xcf); + CASE(0xd0); CASE(0xd1); CASE(0xd2); CASE(0xd3); + CASE(0xd4); CASE(0xd5); CASE(0xd6); CASE(0xd7); + CASE(0xd8); CASE(0xd9); CASE(0xda); CASE(0xdb); + CASE(0xdc); CASE(0xdd); CASE(0xde); CASE(0xdf); + CASE(0xe0); CASE(0xe1); CASE(0xe2); CASE(0xe3); + CASE(0xe4); CASE(0xe5); CASE(0xe6); CASE(0xe7); + CASE(0xe8); CASE(0xe9); CASE(0xea); CASE(0xeb); + CASE(0xec); CASE(0xed); CASE(0xee); CASE(0xef); + CASE(0xf0); CASE(0xf1); CASE(0xf2); CASE(0xf3); + CASE(0xf4); CASE(0xf5); CASE(0xf6); CASE(0xf7); + CASE(0xf8); CASE(0xf9); CASE(0xfa); CASE(0xfb); + CASE(0xfc); CASE(0xfd); CASE(0xfe); CASE(0xff); + } + abort(); +#undef CASE +} + +/* -------------------------------------------------------------------- */ +simde__m256i simde_mm256_permute4x64_epi64_dyn(simde__m256i a, const int imm8) { +#define CASE(I) case I: return simde_mm256_permute4x64_epi64(a, I) + + /* + * for i in range(0, 256, 4): + * print('; '.join(f'CASE(0x{j:02x})' for j in range(i, i+4)) + ';') + */ + switch (imm8 & 0xff) { + CASE(0x00); CASE(0x01); CASE(0x02); CASE(0x03); + CASE(0x04); CASE(0x05); CASE(0x06); CASE(0x07); + CASE(0x08); CASE(0x09); CASE(0x0a); CASE(0x0b); + CASE(0x0c); CASE(0x0d); CASE(0x0e); CASE(0x0f); + CASE(0x10); CASE(0x11); CASE(0x12); CASE(0x13); + CASE(0x14); CASE(0x15); CASE(0x16); CASE(0x17); + CASE(0x18); CASE(0x19); CASE(0x1a); CASE(0x1b); + CASE(0x1c); CASE(0x1d); CASE(0x1e); CASE(0x1f); + CASE(0x20); CASE(0x21); CASE(0x22); CASE(0x23); + CASE(0x24); CASE(0x25); CASE(0x26); CASE(0x27); + CASE(0x28); CASE(0x29); CASE(0x2a); CASE(0x2b); + CASE(0x2c); CASE(0x2d); CASE(0x2e); CASE(0x2f); + CASE(0x30); CASE(0x31); CASE(0x32); CASE(0x33); + CASE(0x34); CASE(0x35); CASE(0x36); CASE(0x37); + CASE(0x38); CASE(0x39); CASE(0x3a); CASE(0x3b); + CASE(0x3c); CASE(0x3d); CASE(0x3e); CASE(0x3f); + CASE(0x40); CASE(0x41); CASE(0x42); CASE(0x43); + CASE(0x44); CASE(0x45); CASE(0x46); CASE(0x47); + CASE(0x48); CASE(0x49); CASE(0x4a); CASE(0x4b); + CASE(0x4c); CASE(0x4d); CASE(0x4e); CASE(0x4f); + CASE(0x50); CASE(0x51); CASE(0x52); CASE(0x53); + CASE(0x54); CASE(0x55); CASE(0x56); CASE(0x57); + CASE(0x58); CASE(0x59); CASE(0x5a); CASE(0x5b); + CASE(0x5c); CASE(0x5d); CASE(0x5e); CASE(0x5f); + CASE(0x60); CASE(0x61); CASE(0x62); CASE(0x63); + CASE(0x64); CASE(0x65); CASE(0x66); CASE(0x67); + CASE(0x68); CASE(0x69); CASE(0x6a); CASE(0x6b); + CASE(0x6c); CASE(0x6d); CASE(0x6e); CASE(0x6f); + CASE(0x70); CASE(0x71); CASE(0x72); CASE(0x73); + CASE(0x74); CASE(0x75); CASE(0x76); CASE(0x77); + CASE(0x78); CASE(0x79); CASE(0x7a); CASE(0x7b); + CASE(0x7c); CASE(0x7d); CASE(0x7e); CASE(0x7f); + CASE(0x80); CASE(0x81); CASE(0x82); CASE(0x83); + CASE(0x84); CASE(0x85); CASE(0x86); CASE(0x87); + CASE(0x88); CASE(0x89); CASE(0x8a); CASE(0x8b); + CASE(0x8c); CASE(0x8d); CASE(0x8e); CASE(0x8f); + CASE(0x90); CASE(0x91); CASE(0x92); CASE(0x93); + CASE(0x94); CASE(0x95); CASE(0x96); CASE(0x97); + CASE(0x98); CASE(0x99); CASE(0x9a); CASE(0x9b); + CASE(0x9c); CASE(0x9d); CASE(0x9e); CASE(0x9f); + CASE(0xa0); CASE(0xa1); CASE(0xa2); CASE(0xa3); + CASE(0xa4); CASE(0xa5); CASE(0xa6); CASE(0xa7); + CASE(0xa8); CASE(0xa9); CASE(0xaa); CASE(0xab); + CASE(0xac); CASE(0xad); CASE(0xae); CASE(0xaf); + CASE(0xb0); CASE(0xb1); CASE(0xb2); CASE(0xb3); + CASE(0xb4); CASE(0xb5); CASE(0xb6); CASE(0xb7); + CASE(0xb8); CASE(0xb9); CASE(0xba); CASE(0xbb); + CASE(0xbc); CASE(0xbd); CASE(0xbe); CASE(0xbf); + CASE(0xc0); CASE(0xc1); CASE(0xc2); CASE(0xc3); + CASE(0xc4); CASE(0xc5); CASE(0xc6); CASE(0xc7); + CASE(0xc8); CASE(0xc9); CASE(0xca); CASE(0xcb); + CASE(0xcc); CASE(0xcd); CASE(0xce); CASE(0xcf); + CASE(0xd0); CASE(0xd1); CASE(0xd2); CASE(0xd3); + CASE(0xd4); CASE(0xd5); CASE(0xd6); CASE(0xd7); + CASE(0xd8); CASE(0xd9); CASE(0xda); CASE(0xdb); + CASE(0xdc); CASE(0xdd); CASE(0xde); CASE(0xdf); + CASE(0xe0); CASE(0xe1); CASE(0xe2); CASE(0xe3); + CASE(0xe4); CASE(0xe5); CASE(0xe6); CASE(0xe7); + CASE(0xe8); CASE(0xe9); CASE(0xea); CASE(0xeb); + CASE(0xec); CASE(0xed); CASE(0xee); CASE(0xef); + CASE(0xf0); CASE(0xf1); CASE(0xf2); CASE(0xf3); + CASE(0xf4); CASE(0xf5); CASE(0xf6); CASE(0xf7); + CASE(0xf8); CASE(0xf9); CASE(0xfa); CASE(0xfb); + CASE(0xfc); CASE(0xfd); CASE(0xfe); CASE(0xff); + } + abort(); +#undef CASE +} + +/* -------------------------------------------------------------------- */ +simde__m256i simde_mm256_permute2x128_si256_dyn(simde__m256i a, simde__m256i b, const int imm8) { +#define CASE(I) case I: return simde_mm256_permute2x128_si256(a, b, I) + + /* + * for i in range(0, 256, 4): + * print('; '.join(f'CASE(0x{j:02x})' for j in range(i, i+4)) + ';') + */ + switch (imm8 & 0xff) { + CASE(0x00); CASE(0x01); CASE(0x02); CASE(0x03); + CASE(0x04); CASE(0x05); CASE(0x06); CASE(0x07); + CASE(0x08); CASE(0x09); CASE(0x0a); CASE(0x0b); + CASE(0x0c); CASE(0x0d); CASE(0x0e); CASE(0x0f); + CASE(0x10); CASE(0x11); CASE(0x12); CASE(0x13); + CASE(0x14); CASE(0x15); CASE(0x16); CASE(0x17); + CASE(0x18); CASE(0x19); CASE(0x1a); CASE(0x1b); + CASE(0x1c); CASE(0x1d); CASE(0x1e); CASE(0x1f); + CASE(0x20); CASE(0x21); CASE(0x22); CASE(0x23); + CASE(0x24); CASE(0x25); CASE(0x26); CASE(0x27); + CASE(0x28); CASE(0x29); CASE(0x2a); CASE(0x2b); + CASE(0x2c); CASE(0x2d); CASE(0x2e); CASE(0x2f); + CASE(0x30); CASE(0x31); CASE(0x32); CASE(0x33); + CASE(0x34); CASE(0x35); CASE(0x36); CASE(0x37); + CASE(0x38); CASE(0x39); CASE(0x3a); CASE(0x3b); + CASE(0x3c); CASE(0x3d); CASE(0x3e); CASE(0x3f); + CASE(0x40); CASE(0x41); CASE(0x42); CASE(0x43); + CASE(0x44); CASE(0x45); CASE(0x46); CASE(0x47); + CASE(0x48); CASE(0x49); CASE(0x4a); CASE(0x4b); + CASE(0x4c); CASE(0x4d); CASE(0x4e); CASE(0x4f); + CASE(0x50); CASE(0x51); CASE(0x52); CASE(0x53); + CASE(0x54); CASE(0x55); CASE(0x56); CASE(0x57); + CASE(0x58); CASE(0x59); CASE(0x5a); CASE(0x5b); + CASE(0x5c); CASE(0x5d); CASE(0x5e); CASE(0x5f); + CASE(0x60); CASE(0x61); CASE(0x62); CASE(0x63); + CASE(0x64); CASE(0x65); CASE(0x66); CASE(0x67); + CASE(0x68); CASE(0x69); CASE(0x6a); CASE(0x6b); + CASE(0x6c); CASE(0x6d); CASE(0x6e); CASE(0x6f); + CASE(0x70); CASE(0x71); CASE(0x72); CASE(0x73); + CASE(0x74); CASE(0x75); CASE(0x76); CASE(0x77); + CASE(0x78); CASE(0x79); CASE(0x7a); CASE(0x7b); + CASE(0x7c); CASE(0x7d); CASE(0x7e); CASE(0x7f); + CASE(0x80); CASE(0x81); CASE(0x82); CASE(0x83); + CASE(0x84); CASE(0x85); CASE(0x86); CASE(0x87); + CASE(0x88); CASE(0x89); CASE(0x8a); CASE(0x8b); + CASE(0x8c); CASE(0x8d); CASE(0x8e); CASE(0x8f); + CASE(0x90); CASE(0x91); CASE(0x92); CASE(0x93); + CASE(0x94); CASE(0x95); CASE(0x96); CASE(0x97); + CASE(0x98); CASE(0x99); CASE(0x9a); CASE(0x9b); + CASE(0x9c); CASE(0x9d); CASE(0x9e); CASE(0x9f); + CASE(0xa0); CASE(0xa1); CASE(0xa2); CASE(0xa3); + CASE(0xa4); CASE(0xa5); CASE(0xa6); CASE(0xa7); + CASE(0xa8); CASE(0xa9); CASE(0xaa); CASE(0xab); + CASE(0xac); CASE(0xad); CASE(0xae); CASE(0xaf); + CASE(0xb0); CASE(0xb1); CASE(0xb2); CASE(0xb3); + CASE(0xb4); CASE(0xb5); CASE(0xb6); CASE(0xb7); + CASE(0xb8); CASE(0xb9); CASE(0xba); CASE(0xbb); + CASE(0xbc); CASE(0xbd); CASE(0xbe); CASE(0xbf); + CASE(0xc0); CASE(0xc1); CASE(0xc2); CASE(0xc3); + CASE(0xc4); CASE(0xc5); CASE(0xc6); CASE(0xc7); + CASE(0xc8); CASE(0xc9); CASE(0xca); CASE(0xcb); + CASE(0xcc); CASE(0xcd); CASE(0xce); CASE(0xcf); + CASE(0xd0); CASE(0xd1); CASE(0xd2); CASE(0xd3); + CASE(0xd4); CASE(0xd5); CASE(0xd6); CASE(0xd7); + CASE(0xd8); CASE(0xd9); CASE(0xda); CASE(0xdb); + CASE(0xdc); CASE(0xdd); CASE(0xde); CASE(0xdf); + CASE(0xe0); CASE(0xe1); CASE(0xe2); CASE(0xe3); + CASE(0xe4); CASE(0xe5); CASE(0xe6); CASE(0xe7); + CASE(0xe8); CASE(0xe9); CASE(0xea); CASE(0xeb); + CASE(0xec); CASE(0xed); CASE(0xee); CASE(0xef); + CASE(0xf0); CASE(0xf1); CASE(0xf2); CASE(0xf3); + CASE(0xf4); CASE(0xf5); CASE(0xf6); CASE(0xf7); + CASE(0xf8); CASE(0xf9); CASE(0xfa); CASE(0xfb); + CASE(0xfc); CASE(0xfd); CASE(0xfe); CASE(0xff); + } + abort(); +#undef CASE +} + +/* -------------------------------------------------------------------- */ +simde__m256i simde_mm256_moveldup_ps_dyn(simde__m256i a) { + return (simde__m256i)simde_mm256_moveldup_ps((simde__m256)a); +} + + +#endif + +extern "C" { +BIND_256x2_256(simde_mm256_permutevar8x32_epi32); +BIND2(simde_mm256_permute4x64_epi64_dyn, M256i, M256i, Long); +BIND3(simde_mm256_permute2x128_si256_dyn, M256i, M256i, M256i, Long); + +BIND_256x2_256(simde_mm256_and_si256); +BIND_256x2_256(simde_mm256_andnot_si256); +BIND_256x2_256(simde_mm256_add_epi8); +BIND_256x2_256(simde_mm256_add_epi16); +BIND_256x2_256(simde_mm256_sub_epi8); +BIND_256x2_256(simde_mm256_sub_epi16); +BIND_256x2_256(simde_mm256_maddubs_epi16); +BIND_256x2_256(simde_mm256_packus_epi16); +BIND_256x2_256(simde_mm256_packs_epi16); +BIND_256x2_256(simde_mm256_mulhi_epi16); +BIND_256x2_256(simde_mm256_mulhi_epu16); +BIND_256x2_256(simde_mm256_mulhrs_epi16); + +BIND_256x2_256(simde_mm256_shuffle_epi8); +BIND_256x2_256(simde_mm256_cmpgt_epi16); +BIND_256x2_256(simde_mm256_unpacklo_epi8); +BIND_256x2_256(simde_mm256_unpacklo_epi64); +BIND_256x2_256(simde_mm256_unpackhi_epi64); + +BIND2(simde_mm256_srai_epi16, M256i, M256i, Long); +BIND2(simde_mm256_srli_epi16, M256i, M256i, Long); + +BIND1(simde_mm256_movemask_epi8, Int32, M256i); +BIND1(simde_mm256_moveldup_ps_dyn, M256i, M256i); + +BIND3(simde_mm256_blend_epi16_dyn, M256i, M256i, M256i, Long); +BIND3(simde_mm256_blend_epi32_dyn, M256i, M256i, M256i, Long); + + +BIND3(simde_mm256_inserti128_si256_dyn, M256i, M256i, M128i, Long); +BIND2(simde_mm256_extracti128_si256_dyn, M128i, M256i, Long); +} diff --git a/libs/lospecs/tests/avx2_runtime.h b/libs/lospecs/tests/avx2_runtime.h new file mode 100644 index 0000000000..e5dd028584 --- /dev/null +++ b/libs/lospecs/tests/avx2_runtime.h @@ -0,0 +1,210 @@ +/* ==================================================================== */ +#if !defined(AVX2_RUNTIME__) +# define AVX2_RUNTIME__ 1 + +#if defined(__x86_64__) || defined(_M_X64) +# define HAS_AVX 1 +# include +#endif + +#define HAS_AVX + +/* -------------------------------------------------------------------- */ +#include +#include +#include +#include +#include + +/* -------------------------------------------------------------------- */ +extern "C" { +CAMLprim value caml_simde_mm256_permutevar8x32_epi32(value, value); +CAMLprim value caml_simde_mm256_permute4x64_epi64_dyn(value, value); +CAMLprim value caml_simde_mm256_permute2x128_si256_dyn(value, value, value); +CAMLprim value m64_of_32x2(value); +CAMLprim value m64_to_32x2(value); +CAMLprim value m32_of_16x2(value lohi); +CAMLprim value m32_to_16x2(value lohi); +CAMLprim value m16_of_8x2(value lohi); +CAMLprim value m16_to_8x2(value lohi); + +CAMLprim value caml_simde_mm256_and_si256(value, value); +CAMLprim value caml_simde_mm256_andnot_si256(value, value); +CAMLprim value caml_simde_mm256_add_epi8(value, value); +CAMLprim value caml_simde_mm256_add_epi16(value, value); +CAMLprim value caml_simde_mm256_sub_epi8(value, value); +CAMLprim value caml_simde_mm256_sub_epi16(value, value); +CAMLprim value caml_simde_mm256_maddubs_epi16(value, value); +CAMLprim value caml_simde_mm256_packus_epi16(value, value); +CAMLprim value caml_simde_mm256_packs_epi16(value, value); +CAMLprim value caml_simde_mm256_mulhi_epu16(value, value); +CAMLprim value caml_simde_mm256_mulhrs_epi16(value, value); +CAMLprim value caml_simde_mm256_shuffle_epi8(value, value); +CAMLprim value caml_simde_mm256_srai_epi16(value, value); +CAMLprim value caml_simde_mm256_srli_epi16(value, value); +CAMLprim value caml_simde_mm256_cmpgt_epi16(value, value); +CAMLprim value caml_simde_mm256_movemask_epi8(value); +CAMLprim value caml_simde_mm256_unpacklo_epi8(value, value); +CAMLprim value caml_simde_mm256_unpacklo_epi64(value, value); +CAMLprim value caml_simde_mm256_unpackhi_epi64(value, value); +CAMLprim value caml_simde_mm256_inserti128_si256_dyn(value, value, value); +CAMLprim value caml_simde_mm256_extracti128_si256_dyn(value, value); +CAMLprim value caml_simde_mm256_blend_epi16_dyn(value, value, value); +CAMLprim value caml_simde_mm256_blend_epi32_dyn(value, value, value); +CAMLprim value caml_simde_mm256_moveldup_ps(value); +} + +/* ==================================================================== */ +#if defined(HAS_AVX) + +/* -------------------------------------------------------------------- */ +extern value value_of_w256(simde__m256i x); +extern simde__m256i w256_of_value(value x); + +/* -------------------------------------------------------------------- */ +extern value value_of_w128(simde__m128i x); +extern simde__m128i w128_of_value(value x); + +/* -------------------------------------------------------------------- */ +struct M256i { + typedef simde__m256i type; + + static inline type ofocaml(value v) { + return w256_of_value(v); + } + + static inline value toocaml(type v) { + return value_of_w256(v); + } +}; + +/* -------------------------------------------------------------------- */ +struct M128i { + typedef simde__m128i type; + + static inline type ofocaml(value v) { + return w128_of_value(v); + } + + static inline value toocaml(type v) { + return value_of_w128(v); + } +}; + +/* -------------------------------------------------------------------- */ +struct Long { + typedef long type; + + static inline type ofocaml(value v) { + return Long_val(v); + } + + static inline value toocaml(type v) { + return Val_long(v); + } +}; + +/* -------------------------------------------------------------------- */ +struct Int32 { + typedef long type; + + static inline type ofocaml(value v) { + return Int32_val(v); + } + + static inline value toocaml(type v) { + return caml_copy_int32(v); + } +}; + +/* -------------------------------------------------------------------- */ +struct Int64 { + typedef long type; + + static inline type ofocaml(value v) { + return Int64_val(v); + } + + static inline value toocaml(type v) { + return caml_copy_int64(v); + } +}; + +/* -------------------------------------------------------------------- */ +template +static value bind(value arg) { + CAMLparam1(arg); + typename T::type varg = T::ofocaml(arg); + CAMLreturn(U::toocaml(F(varg))); +} + +/* -------------------------------------------------------------------- */ +template +static value bind(value arg1, value arg2) { + CAMLparam2(arg1, arg2); + typename T1::type varg1 = T1::ofocaml(arg1); + typename T2::type varg2 = T2::ofocaml(arg2); + CAMLreturn(U::toocaml(F(varg1, varg2))); +} + +/* -------------------------------------------------------------------- */ +template +static value bind(value arg1, value arg2, value arg3) { + CAMLparam3(arg1, arg2, arg3); + typename T1::type varg1 = T1::ofocaml(arg1); + typename T2::type varg2 = T2::ofocaml(arg2); + typename T3::type varg3 = T3::ofocaml(arg3); + CAMLreturn(U::toocaml(F(varg1, varg2, varg3))); +} + +/* -------------------------------------------------------------------- */ +# define BIND1(F, U, T) \ +CAMLprim value caml_##F(value a) { \ + return bind(a); \ +} + +/* -------------------------------------------------------------------- */ +# define BIND2(F, U, T1, T2) \ +CAMLprim value caml_##F(value a, value b) { \ + return bind(a, b); \ +} + +/* -------------------------------------------------------------------- */ +# define BIND3(F, U, T1, T2, T3) \ +CAMLprim value caml_##F(value a, value b, value c) { \ + return bind(a, b, c); \ +} + +/* ==================================================================== */ +#else + +/* -------------------------------------------------------------------- */ +# define BIND1(F, U, T) \ +CAMLprim value caml_##F(value a) { \ + CAMLparam1(a); \ + caml_failwith("not implemented: " #F); \ + CAMLreturn(Val_unit); \ +} + +/* -------------------------------------------------------------------- */ +# define BIND2(F, U, T1, T2) \ +CAMLprim value caml_##F(value a, value b) { \ + CAMLparam2(a, b); \ + caml_failwith("not implemented: " #F); \ + CAMLreturn(Val_unit); \ +} + +/* -------------------------------------------------------------------- */ +# define BIND3(F, U, T1, T2, T3) \ +CAMLprim value caml_##F(value a, value b, value c) { \ + CAMLparam3(a, b, c); \ + caml_failwith("not implemented: " #F); \ + CAMLreturn(Val_unit); \ +} + +#endif /* defined(HAS_AVX) */ + +#define BIND_256x2_256(F) BIND2(F, M256i, M256i, M256i) +#define BIND_256x3_256(F) BIND3(F, M256i, M256i, M256i, M256i) + +#endif /* !AVX2_RUNTIME__ */ diff --git a/libs/lospecs/tests/circuit_avx2.ml b/libs/lospecs/tests/circuit_avx2.ml new file mode 100644 index 0000000000..8792be5e0b --- /dev/null +++ b/libs/lospecs/tests/circuit_avx2.ml @@ -0,0 +1,265 @@ +(* ==================================================================== *) +open Lospecs +open Aig + +type symbol = string + +(* ==================================================================== *) +module type S = sig + val vpermd : reg -> reg -> reg + val vpermq : reg -> int -> reg + val vperm2i128 : reg -> reg -> int -> reg + val vpbroadcast_16u16 : reg -> reg + val vpadd_16u16 : reg -> reg -> reg + val vpadd_32u8 : reg -> reg -> reg + val vpsub_16u16 : reg -> reg -> reg + val vpsub_32u8 : reg -> reg -> reg + val vpand_256 : reg -> reg -> reg + val vpmaddubsw_256 : reg -> reg -> reg + val vpmulh_16u16 : reg -> reg -> reg + val vpmulhu_16u16 : reg -> reg -> reg + val vpmulhrs_16u16 : reg -> reg -> reg + val vpsra_16u16 : reg -> int -> reg + val vpsrl_16u16 : reg -> int -> reg + val vpsrl_4u64 : reg -> int -> reg + val vpsll_4u64 : reg -> int -> reg + val vpackus_16u16 : reg -> reg -> reg + val vpackss_16u16 : reg -> reg -> reg + val vpshufb_256 : reg -> reg -> reg + val vpcmpgt_16u16 : reg -> reg -> reg + val vpmovmskb_u256u64 : reg -> reg + val vpunpckl_32u8 : reg -> reg -> reg + val vpunpckl_4u64 : reg -> reg -> reg + val vpunpckh_4u64 : reg -> reg -> reg + val vpextracti128 : reg -> int -> reg + val vpinserti128 : reg -> reg -> int -> reg + val vpblend_16u16 : reg -> reg -> int -> reg + val vpblend_8u32 : reg -> reg -> int -> reg + val vpslldq_256 : reg -> int -> reg + val vpsrldq_256 : reg -> int -> reg + val vpslldq_128 : reg -> int -> reg + val vpsrldq_128 : reg -> int -> reg + val vmovsldup_256 : reg -> reg +end + +(* ==================================================================== *) +module FromSpec () : S = struct + (* ------------------------------------------------------------------ *) + let specs = + let specs = match Sys.getenv_opt "EC_AVX2_SPEC_FILE_PATH" with + | Some s -> s + | None -> Format.eprintf "Path to avx2 spec file not set, please set env var EC_AVX2_SPEC_FILE_PATH with the correct path to the file@."; + exit 1 + in + let specs = Circuit_spec.load_from_file ~filename:specs in + let specs = BatMap.of_seq (List.to_seq specs) in + specs + + let get_specification (name : symbol) : Ast.adef option = + BatMap.find_opt name specs + + (* ------------------------------------------------------------------ *) + let vpermd = Option.get (get_specification "VPERMD") + + let vpermd (r1 : reg) (r2 : reg) : reg = + Circuit_spec.circuit_of_specification [r1; r2] vpermd + + (* ------------------------------------------------------------------ *) + let vpermq = Option.get (get_specification "VPERMQ") + + let vpermq (r : reg) (i : int) : reg = + Circuit_spec.circuit_of_specification [r; Circuit.w8 i] vpermq + + (* ------------------------------------------------------------------ *) + let vperm2i128 = Option.get (get_specification "VPERM2I128") + + let vperm2i128 (r1 : reg) (r2 : reg) (i : int) : reg = + Circuit_spec.circuit_of_specification [r1; r2; Circuit.w8 i] vperm2i128 + + (* ------------------------------------------------------------------ *) + let vpbroadcast_16u16 = Option.get (get_specification "VPBROADCAST_16u16") + + let vpbroadcast_16u16 (r : reg) : reg = + Circuit_spec.circuit_of_specification [r] vpbroadcast_16u16 + + (* ------------------------------------------------------------------ *) + let vpadd_16u16 = Option.get (get_specification "VPADD_16u16") + + let vpadd_16u16 (r1 : reg) (r2 : reg) : reg = + Circuit_spec.circuit_of_specification [r1; r2] vpadd_16u16 + + (* ------------------------------------------------------------------ *) + let vpadd_32u8 = Option.get (get_specification "VPADD_32u8") + + let vpadd_32u8 (r1 : reg) (r2 : reg) : reg = + Circuit_spec.circuit_of_specification [r1; r2] vpadd_32u8 + + (* ----------------------------------------------------------------- *) + let vpsub_16u16 = Option.get (get_specification "VPSUB_16u16") + + let vpsub_16u16 (r1 : reg) (r2 : reg) : reg = + Circuit_spec.circuit_of_specification [r1; r2] vpsub_16u16 + + (* ------------------------------------------------------------------ *) + let vpsub_32u8 = Option.get (get_specification "VPSUB_32u8") + + let vpsub_32u8 (r1 : reg) (r2 : reg) : reg = + Circuit_spec.circuit_of_specification [r1; r2] vpsub_32u8 + + (* ------------------------------------------------------------------ *) + let vpand_256 = Option.get (get_specification "VPAND_256") + + let vpand_256 (r1 : reg) (r2 : reg) : reg = + Circuit_spec.circuit_of_specification [r1; r2] vpand_256 + + (* ------------------------------------------------------------------ *) + let vpmaddubsw_256 = Option.get (get_specification "VPMADDUBSW_256") + + let vpmaddubsw_256 (r1 : reg) (r2 : reg) : reg = + Circuit_spec.circuit_of_specification [r1; r2] vpmaddubsw_256 + + (* ------------------------------------------------------------------ *) + let vpmulh_16u16 = Option.get (get_specification "VPMULH_16u16") + + let vpmulh_16u16 (r1 : reg) (r2 : reg) : reg = + Circuit_spec.circuit_of_specification [r1; r2] vpmulh_16u16 + + (* ------------------------------------------------------------------ *) + let vpmulhu_16u16 = Option.get (get_specification "VPMULHU_16u16") + + let vpmulhu_16u16 (r1 : reg) (r2 : reg) : reg = + Circuit_spec.circuit_of_specification [r1; r2] vpmulhu_16u16 + + (* ------------------------------------------------------------------ *) + let vpmulhrs_16u16 = Option.get (get_specification "VPMULHRS_16u16") + + let vpmulhrs_16u16 (r1 : reg) (r2 : reg) : reg = + Circuit_spec.circuit_of_specification [r1; r2] vpmulhrs_16u16 + + (* ------------------------------------------------------------------ *) + let vpsra_16u16 = Option.get (get_specification "VPSRA_16u16") + + let vpsra_16u16 (r : reg) (n : int) : reg = + Circuit_spec.circuit_of_specification [r; Circuit.w128 (string_of_int n)] vpsra_16u16 + + (* ------------------------------------------------------------------ *) + let vpsrl_16u16 = Option.get (get_specification "VPSRL_16u16") + + let vpsrl_16u16 (r : reg) (n : int) : reg = + Circuit_spec.circuit_of_specification [r; Circuit.w128 (string_of_int n)] vpsrl_16u16 + + (* ------------------------------------------------------------------ *) + let vpsrl_4u64 = Option.get (get_specification "VPSRL_4u64") + + let vpsrl_4u64 (r : reg) (n : int) : reg = + Circuit_spec.circuit_of_specification [r; Circuit.w128 (string_of_int n)] vpsrl_4u64 + + (* ------------------------------------------------------------------ *) + let vpsll_4u64 = Option.get (get_specification "VPSLL_4u64") + + let vpsll_4u64 (r : reg) (n : int) : reg = + Circuit_spec.circuit_of_specification [r; Circuit.w128 (string_of_int n)] vpsll_4u64 + + (* ------------------------------------------------------------------ *) + let vpslldq_256 = Option.get (get_specification "VPSLLDQ_256") + + let vpslldq_256 (r : reg) (n : int) : reg = + Circuit_spec.circuit_of_specification [r; Circuit.w8 (8 * n)] vpslldq_256 + + (* ------------------------------------------------------------------ *) + let vpsrldq_256 = Option.get (get_specification "VPSRLDQ_256") + + let vpsrldq_256 (r : reg) (n : int) : reg = + Circuit_spec.circuit_of_specification [r; Circuit.w8 (8 * n)] vpsrldq_256 + + (* ------------------------------------------------------------------ *) + let vpslldq_128 = Option.get (get_specification "VPSLLDQ_128") + + let vpslldq_128 (r : reg) (n : int) : reg = + Circuit_spec.circuit_of_specification [r; Circuit.w8 (8 * n)] vpslldq_128 + + (* ------------------------------------------------------------------ *) + let vpsrldq_128 = Option.get (get_specification "VPSRLDQ_128") + + let vpsrldq_128 (r : reg) (n : int) : reg = + Circuit_spec.circuit_of_specification [r; Circuit.w8 (8 * n)] vpsrldq_128 + + (* ------------------------------------------------------------------ *) + let vpackus_16u16 = Option.get (get_specification "VPACKUS_16u16") + + let vpackus_16u16 (r1 : reg) (r2 : reg) : reg = + Circuit_spec.circuit_of_specification [r1; r2] vpackus_16u16 + + (* ------------------------------------------------------------------ *) + let vpackss_16u16 = Option.get (get_specification "VPACKSS_16u16") + + let vpackss_16u16 (r1 : reg) (r2 : reg) : reg = + Circuit_spec.circuit_of_specification [r1; r2] vpackss_16u16 + + (* ------------------------------------------------------------------ *) + let vpshufb_256 = Option.get (get_specification "VPSHUFB_256") + + let vpshufb_256 (r1 : reg) (r2 : reg) : reg = + Circuit_spec.circuit_of_specification [r1; r2] vpshufb_256 + + (* ------------------------------------------------------------------ *) + let vpcmpgt_16u16 = Option.get (get_specification "VPCMPGT_16u16") + + let vpcmpgt_16u16 (r1 : reg) (r2 : reg) : reg = + Circuit_spec.circuit_of_specification [r1; r2] vpcmpgt_16u16 + + (* ------------------------------------------------------------------ *) + let vpmovmskb_u256u64 = Option.get (get_specification "VPMOVMSKB_u256u64") + + let vpmovmskb_u256u64 (r : reg) : reg = + Circuit_spec.circuit_of_specification [r] vpmovmskb_u256u64 + + (* ------------------------------------------------------------------ *) + let vpunpckl_32u8 = Option.get (get_specification "VPUNPCKL_32u8") + + let vpunpckl_32u8 (r1 : reg) (r2 : reg): reg = + Circuit_spec.circuit_of_specification [r1; r2] vpunpckl_32u8 + + (* ------------------------------------------------------------------ *) + let vpunpckl_4u64 = Option.get (get_specification "VPUNPCKL_4u64") + + let vpunpckl_4u64 (r1 : reg) (r2 : reg): reg = + Circuit_spec.circuit_of_specification [r1; r2] vpunpckl_4u64 + + (* ------------------------------------------------------------------ *) + let vpunpckh_4u64 = Option.get (get_specification "VPUNPCKH_4u64") + + let vpunpckh_4u64 (r1 : reg) (r2 : reg): reg = + Circuit_spec.circuit_of_specification [r1; r2] vpunpckh_4u64 + + (* ------------------------------------------------------------------ *) + let vpextracti128 = Option.get (get_specification "VPEXTRACTI128") + + let vpextracti128 (r : reg) (i : int): reg = + Circuit_spec.circuit_of_specification [r; Circuit.w8 i] vpextracti128 + + (* ------------------------------------------------------------------ *) + let vpinserti128 = Option.get (get_specification "VPINSERTI128") + + let vpinserti128 (r1 : reg) (r2 : reg) (i : int): reg = + Circuit_spec.circuit_of_specification [r1; r2; Circuit.w8 i] vpinserti128 + + (* ------------------------------------------------------------------ *) + let vpblend_16u16 = Option.get (get_specification "VPBLEND_16u16") + + let vpblend_16u16 (r1 : reg) (r2 : reg) (i : int): reg = + Circuit_spec.circuit_of_specification [r1; r2; Circuit.w8 i] vpblend_16u16 + + (* ------------------------------------------------------------------ *) + let vpblend_8u32 = Option.get (get_specification "VPBLEND_8u32") + + let vpblend_8u32 (r1 : reg) (r2 : reg) (i : int): reg = + Circuit_spec.circuit_of_specification [r1; r2; Circuit.w8 i] vpblend_8u32 + + (* ------------------------------------------------------------------ *) + let vmovsldup_256 = Option.get (get_specification "VMOVSLDUP_256") + + let vmovsldup_256 (r : reg) : reg = + Circuit_spec.circuit_of_specification [r] vmovsldup_256 + +end diff --git a/libs/lospecs/tests/circuit_test.ml b/libs/lospecs/tests/circuit_test.ml new file mode 100644 index 0000000000..a9b205d9e3 --- /dev/null +++ b/libs/lospecs/tests/circuit_test.ml @@ -0,0 +1,1109 @@ +(* -------------------------------------------------------------------- *) +open Lospecs + +(* -------------------------------------------------------------------- *) +module C = struct + include Lospecs.Aig + include Lospecs.Circuit + include Circuit_avx2.FromSpec () +end + +(* -------------------------------------------------------------------- *) +let sign (i : int) = + match i with + | _ when i < 0 -> -1 + | _ when i > 0 -> 1 + | _ -> 0 + +(* -------------------------------------------------------------------- *) +let as_seq1 (type t) (xs : t list) = + match xs with [x] -> x | _ -> assert false + +(* -------------------------------------------------------------------- *) +let as_seq2 (type t) (xs : t list) = + match xs with [x; y] -> (x, y) | _ -> assert false + +(* -------------------------------------------------------------------- *) +let pp_bytes (fmt : Format.formatter) (b : bytes) = + Bytes.iter + (fun b -> Format.fprintf fmt "%02x" (Char.code b)) + b + +(* -------------------------------------------------------------------- *) +let srange_ (i : int) = + assert (0 < i && i <= Sys.int_size); + let v = (1 lsl (i - 1)) in + (-v, v-1) + +(* -------------------------------------------------------------------- *) +let srange (i : int) = + let vm, vM = srange_ i in Iter.(--) vm vM + +(* -------------------------------------------------------------------- *) +let urange_ (i : int) = + assert (0 < i && i <= Sys.int_size - 1); + (0, (1 lsl i) - 1) + +(* -------------------------------------------------------------------- *) +let urange (i : int) = + let vm, vM = urange_ i in Iter.(--) vm vM + +(* -------------------------------------------------------------------- *) +let product (type t) (s : t Iter.t list) = + let rec doit (s : t Iter.t list) : t list Iter.t = + match s with + | [] -> + Iter.of_list [[]] + | s1 :: s -> + Iter.map (fun (x, xs) -> x :: xs) (Iter.product s1 (doit s)) + in doit s + +(* -------------------------------------------------------------------- *) +type op = { + name : string; + args : (int * [`U | `S]) list; + out : [`U | `S]; + mk : C.reg list -> C.reg; + reff : int list -> int; +} + +(* -------------------------------------------------------------------- *) +let bar (name : string) (total : int) = + let open Progress.Line in + list [ + spinner ~color:(Progress.Color.ansi `green) () + ; rpad (max 20 (String.length name)) (const name) + ; bar total + ; lpad (2 * 7 + 1) (count_to total) + ] + +(* -------------------------------------------------------------------- *) +let test (op : op) = + let rs, vs = + let reg_of_arg (name : int) ((sz, s) : int * [`U | `S]) = + let r = C.reg ~size:sz ~name in + let v = match s with `U -> urange sz | `S -> srange sz in + (r, v) + in List.split (List.mapi reg_of_arg op.args) + in + + let sz = List.sum (List.map fst op.args) in + + assert (sz <= Sys.int_size - 1); + + let total = 1 lsl sz in + let bar = bar op.name total in + + let circuit = op.mk rs in + + let test (vs : int list) = + let vsa = Array.of_list vs in + let env ((n, k) : C.var) = (vsa.(n) lsr k) land 0b1 <> 0 in + let out = Array.map (C.eval env) circuit in + let out = + match op.out with + | `S -> C.sint_of_bools out + | `U -> C.uint_of_bools out in + let exp = op.reff vs in + + if out <> exp then begin + Progress.interject_with (fun () -> + Format.eprintf "%s(%a) = out: %d / exp: %d@." + op.name + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.pp_print_string fmt ", ") + Format.pp_print_int) + vs + out + exp + ); + assert false + end + in + + Progress.with_reporter bar (fun f -> + Iter.iter + (fun vs -> test vs; f 1) + (product vs) + ) + +(* -------------------------------------------------------------------- *) +let test_uextend () = + let op (isize : int) (osize : int) : op = + { name = (Format.sprintf "uextend<%d,%d>" isize osize) + ; args = [(isize, `U)] + ; out = `U + ; mk = (fun rs -> C.uextend ~size:osize (as_seq1 rs)) + ; reff = (fun vs -> as_seq1 vs) + } + + in test (op 8 16) + +(* -------------------------------------------------------------------- *) +let test_ite () = + let op () : op = + { name = (Format.sprintf "ite") + ; args = [(1, `U)] + ; out = `U + ; mk = (fun rs -> C.ite ((as_seq1 rs).(0)) [|C.true_|] [|C.false_|]) + ; reff = (fun vs -> as_seq1 vs) + } + + in test (op ()) + +(* -------------------------------------------------------------------- *) +let test_sextend () = + let op (isize : int) (osize : int) : op = + { name = (Format.sprintf "sextend<%d,%d>" isize osize) + ; args = [(isize, `S)] + ; out = `S + ; mk = (fun rs -> C.sextend ~size:osize (as_seq1 rs)) + ; reff = (fun vs -> as_seq1 vs) + } + + in test (op 8 16) + +(* -------------------------------------------------------------------- *) +let test_shift ~(side : [`L | `R]) ~(sign : [`U | `S]) = + let str_side = match side with `L -> "left" | `R -> "right" in + let str_sign = match sign with `U -> "u" | `S -> "s" in + + let op (size : int) : op = + let module M = (val Word.word ~sign ~size) in + + let sim (v : int) (i : int) = + M.to_int (match side with + | `L -> M.shiftl (M.of_int v) i + | `R -> M.shiftr (M.of_int v) i + ) + in + + let asign = match sign with `U -> `L | `S -> `A in + + { name = (Format.sprintf "shift<%s,%s,%d>" str_side str_sign size) + ; args = [(size, sign); (4, `U)] + ; out = sign + ; mk = (fun rs -> let x, y = as_seq2 rs in C.shift ~side ~sign:asign x y) + ; reff = (fun vs -> let x, y = as_seq2 vs in sim x y) + } + + in + + for i = 1 to 14 do + test (op i) + done + +(* -------------------------------------------------------------------- *) +let test_rot ~(side : [`L | `R]) = + let str_side = match side with `L -> "left" | `R -> "right" in + + let op (size : int) : op = + let module M = (val Word.word ~sign:`U ~size) in + + let sim (v : int) (i : int) = + let i = i mod size in + let m = (1 lsl size) - 1 in + let v = v land m in + match side with + | `L -> ((v lsl i) lor (v lsr (size - i))) land m + | `R -> ((v lsr i) lor (v lsl (size - i))) land m + in + + { name = (Format.sprintf "rot<%s,%d>" str_side size) + ; args = [(size, `U); (4, `U)] + ; out = `U + ; mk = (fun rs -> let x, y = as_seq2 rs in match side with + | `L -> C.rol x y + | `R -> C.ror x y + ) + ; reff = (fun vs -> let x, y = as_seq2 vs in sim x y) + } + + in + + for i = 1 to 14 do + test (op i) + done + +(* -------------------------------------------------------------------- *) +let test_opp () = + let op (size : int) : op = + let module M = (val Word.sword ~size) in + + let sim (x : int) : int = + M.to_int (M.neg (M.of_int x)) in + + { name = (Format.sprintf "opp<%d>" size) + ; args = [(size, `S)] + ; out = `S + ; mk = (fun rs -> C.opp (as_seq1 rs)) + ; reff = (fun vs -> sim (as_seq1 vs)) + } + + in test (op 13) + +(* -------------------------------------------------------------------- *) +let test_add () = + let op (size : int) : op = + let module M = (val Word.sword ~size) in + + let sim (x : int) (y : int) : int = + M.to_int (M.add (M.of_int x) (M.of_int y)) in + + { name = (Format.sprintf "add<%d>" size) + ; args = List.make 2 (size, `S) + ; out = `S + ; mk = (fun rs -> let x, y = as_seq2 rs in C.add_dropc x y) + ; reff = (fun vs -> let x, y = as_seq2 vs in sim x y) + } + + in test (op 9) + +(* -------------------------------------------------------------------- *) +let test_incr () = + let op (size : int) : op = + let module M = (val Word.uword ~size) in + + let sim (x : int) : int = + M.to_int (M.add (M.of_int x) M.one) in + + { name = (Format.sprintf "incr<%d>" size) + ; args = [(size, `U)] + ; out = `U + ; mk = (fun rs -> C.incr_dropc (as_seq1 rs)) + ; reff = (fun vs -> sim (as_seq1 vs)); + } + + in test (op 11) + +(* -------------------------------------------------------------------- *) +let test_sub () = + let op (size : int) : op = + let module M = (val Word.sword ~size) in + + let sim (x : int) (y : int) : int = + M.to_int (M.sub (M.of_int x) (M.of_int y)) in + + { name = (Format.sprintf "sub<%d>" size) + ; args = List.make 2 (size, `S) + ; out = `S + ; mk = (fun rs -> let x, y = as_seq2 rs in C.sub_dropc x y) + ; reff = (fun vs -> let x, y = as_seq2 vs in sim x y) + } + + in test (op 9) + +(* -------------------------------------------------------------------- *) +let test_umul () = + let op (sz1 : int) (sz2 : int) : op = { + name = (Format.sprintf "umul<%d,%d>" sz1 sz2); + args = [(sz1, `U); (sz2, `U)]; + out = `U; + mk = (fun rs -> let x, y = as_seq2 rs in C.umul x y); + reff = (fun vs -> let x, y = as_seq2 vs in (x * y)); + } in + + test (op 10 8) + +(* -------------------------------------------------------------------- *) +let test_smul () = + let op (sz1 : int) (sz2 : int) : op = { + name = (Format.sprintf "smul<%d,%d>" sz1 sz2); + args = [(sz1, `S); (sz2, `S)]; + out = `S; + mk = (fun rs -> let x, y = as_seq2 rs in C.smul x y); + reff = (fun vs -> let x, y = as_seq2 vs in (x * y)); + } in + + test (op 10 8) + +(* -------------------------------------------------------------------- *) +let test_smul_u8_s8 () = + let op () : op = { + name = "smul_u8_s8"; + args = [(8, `U); (8, `S)]; + out = `S; + mk = (fun rs -> + let x, y = as_seq2 rs in + C.smul + (C.uextend ~size:16 x) + (C.sextend ~size:16 y)); + reff = (fun vs -> let x, y = as_seq2 vs in (x * y)); + } in + + test (op ()) + +(* -------------------------------------------------------------------- *) +let test_udiv () = + let op (size : int) : op = + let sim (x : int) (y : int) : int = + if y = 0 then x else x / y + in + + { name = (Format.sprintf "udiv<%d>" size) + ; args = List.make 2 (size, `U) + ; out = `U + ; mk = (fun rs -> let x, y = as_seq2 rs in C.udiv x y) + ; reff = (fun vs -> let x, y = as_seq2 vs in sim x y) + } + + in + test (op 4); + test (op 9) + +(* -------------------------------------------------------------------- *) +let test_umod () = + let op (size : int) : op = + let sim (x : int) (y : int) : int = + if y = 0 then 0 else x mod y + in + + { name = (Format.sprintf "umod<%d>" size) + ; args = List.make 2 (size, `U) + ; out = `U + ; mk = (fun rs -> let x, y = as_seq2 rs in C.umod x y) + ; reff = (fun vs -> let x, y = as_seq2 vs in sim x y) + } + + in + test (op 4); + test (op 9) + +(* -------------------------------------------------------------------- *) +let test_sdiv () = + let op (size : int) : op = + let module M = (val Word.sword ~size) in + + let sim (x : int) (y : int) : int = + if y = 0 then x else M.to_int (M.div (M.of_int x) (M.of_int y)) + in + + { name = (Format.sprintf "sdiv<%d>" size) + ; args = List.make 2 (size, `S) + ; out = `S + ; mk = (fun rs -> let x, y = as_seq2 rs in C.sdiv x y) + ; reff = (fun vs -> let x, y = as_seq2 vs in sim x y) + } + + in + test (op 4); + test (op 9) + +(* -------------------------------------------------------------------- *) +let test_ssat () = + let op (isize : int) (osize: int) : op = + let saturate = + let vm, vM = srange_ osize in + fun (i : int) -> min vM (max vm i) + in + + { name = (Format.sprintf "ssat<%d,%d>" isize osize); + args = [(isize, `S)]; + out = `S; + mk = (fun rs -> C.sat ~signed:true ~size:osize (as_seq1 rs)); + reff = (fun vs -> saturate (as_seq1 vs)); } in + + test (op 10 4); + test (op 15 7); + test (op 17 16) + +(* -------------------------------------------------------------------- *) +let test_usat () = + let op (isize : int) (osize: int) : op = + let saturate = + let vm, vM = urange_ osize in + fun (i : int) -> min vM (max vm i) + in + + { name = (Format.sprintf "usat<%d,%d>" isize osize); + args = [(isize, `S)]; + out = `U; + mk = (fun rs -> C.sat ~signed:false ~size:osize (as_seq1 rs)); + reff = (fun vs -> saturate (as_seq1 vs)); } in + + test (op 10 4); + test (op 15 7) + +(* -------------------------------------------------------------------- *) +let test_sgt () = + let op (size : int) = + { name = Format.sprintf "sgt<%d>" size; + args = [(size, `S); (size, `S)]; + out = `U; + mk = (fun rs -> let x, y = as_seq2 rs in [|C.sgt x y|]); + reff = (fun vs -> let x, y = as_seq2 vs in if x > y then 1 else 0); } + + in + test (op 10) + +(* -------------------------------------------------------------------- *) +let test_sge () = + let op (size : int) = + { name = Format.sprintf "sge<%d>" size; + args = [(size, `S); (size, `S)]; + out = `U; + mk = (fun rs -> let x, y = as_seq2 rs in [|C.sge x y|]); + reff = (fun vs -> let x, y = as_seq2 vs in if x >= y then 1 else 0); } + + in + test (op 10) + +(* -------------------------------------------------------------------- *) +let test_ugt () = + let op (size : int) = + { name = Format.sprintf "ugt<%d>" size; + args = [(size, `U); (size, `U)]; + out = `U; + mk = (fun rs -> let x, y = as_seq2 rs in [|C.ugt x y|]); + reff = (fun vs -> let x, y = as_seq2 vs in if x > y then 1 else 0); } + + in + test (op 10) + +(* -------------------------------------------------------------------- *) +let test_uge () = + let op (size : int) = + { name = Format.sprintf "uge<%d>" size; + args = [(size, `U); (size, `U)]; + out = `U; + mk = (fun rs -> let x, y = as_seq2 rs in [|C.uge x y|]); + reff = (fun vs -> let x, y = as_seq2 vs in if x >= y then 1 else 0); } + + in + test (op 10) + +(* -------------------------------------------------------------------- *) +let test_popcount () = + let op (size : int) = + { name = Format.sprintf "popcount<%d>" size; + args = [(size, `U)]; + out = `U; + mk = (fun rs -> let x = as_seq1 rs in C.popcount ~size x); + reff = (fun vs -> let x = as_seq1 vs in Z.popcount (Z.of_int x)); } + + in + test (op 16) + +(* -------------------------------------------------------------------- *) +type mvalue = M256 of Avx2.m256 | M128 of Avx2.m128 + +module MValue : sig + type kind = [`M256 | `M128] + + val random : kind -> mvalue + + val to_bytes : endianess:Avx2.endianess -> mvalue -> bytes + + val of_bytes : endianess:Avx2.endianess -> bytes -> mvalue + + val pp : + endianess:Avx2.endianess -> + size:Avx2.size -> + Format.formatter -> + mvalue -> + unit +end = struct + type kind = [`M256 | `M128] + + let random (k : kind) = + match k with + | `M256 -> M256 (Avx2.M256.random ()) + | `M128 -> M128 (Avx2.M128.random ()) + + let to_bytes ~(endianess : Avx2.endianess) (m : mvalue) = + match m with + | M256 v -> Avx2.M256.to_bytes ~endianess:`Little v + | M128 v -> Avx2.M128.to_bytes ~endianess:`Little v + + let of_bytes ~(endianess : Avx2.endianess) (m : bytes) = + match Bytes.length m with + | 32 -> M256 (Avx2.M256.of_bytes ~endianess m) + | 16 -> M128 (Avx2.M128.of_bytes ~endianess m) + | _ -> assert false + + let pp + ~(endianess : Avx2.endianess) + ~(size : Avx2.size) + (fmt : Format.formatter) + (m : mvalue) + = + match m with + | M256 v -> Avx2.M256.pp ~endianess ~size fmt v + | M128 v -> Avx2.M128.pp ~endianess ~size fmt v +end + +(* -------------------------------------------------------------------- *) +type vpop = { + name : string; + args : MValue.kind list; + mk : C.reg list -> C.reg; + reff : mvalue list -> mvalue; +} + +(* -------------------------------------------------------------------- *) +let call_m256_m256 + (f : Avx2.m256 -> Avx2.m256) + (vs : mvalue list) + : mvalue += + match vs with + | [M256 v] -> M256 (f v) + | _ -> assert false + +(* -------------------------------------------------------------------- *) +let call_m256_m128 + (f : Avx2.m256 -> Avx2.m128) + (vs : mvalue list) + : mvalue += + match vs with + | [M256 v] -> M128 (f v) + | _ -> assert false + +(* -------------------------------------------------------------------- *) +let call_m256_m128_m256 + (f : Avx2.m256 -> Avx2.m128 -> Avx2.m256) + (vs : mvalue list) + : mvalue += + match vs with + | [M256 v1; M128 v2] -> M256 (f v1 v2) + | _ -> assert false + +(* -------------------------------------------------------------------- *) +let call_m256x2_m256 + (f : Avx2.m256 -> Avx2.m256 -> Avx2.m256) + (vs : mvalue list) + : mvalue += + match vs with + | [M256 v1; M256 v2] -> M256 (f v1 v2) + | _ -> assert false + +(* -------------------------------------------------------------------- *) +let test_vp (total : int) (op : vpop) = + let rs = op.args |> List.mapi (fun i arg -> + match arg with + | `M256 -> C.reg ~size:256 ~name:i + | `M128 -> C.reg ~size:128 ~name:i + ) in + + let circuit = op.mk rs in + + let test () = + let vs = List.map MValue.random op.args in + let avs = Array.of_list vs in + let avs = Array.map (MValue.to_bytes ~endianess:`Little) avs in + + let env ((n, i) : C.var) = C.get_bit avs.(n) i in + + let o = + match op.reff vs with + | M256 v -> Avx2.M256.to_bytes ~endianess:`Little v + | M128 v -> Avx2.M128.to_bytes ~endianess:`Little v + in + + let o' = Array.map (C.eval env) circuit in + let o' = C.bytes_of_bools o' in + + if o <> o' then begin + Progress.interject_with (fun () -> + vs |> List.iter (fun v -> + Format.eprintf "%a@." + (MValue.pp ~endianess:`Big ~size:`U32) + v + ); + Format.eprintf "%a@." + (MValue.pp ~endianess:`Big ~size:`U32) + (MValue.of_bytes ~endianess:`Little o); + Format.eprintf "%a@." + (MValue.pp ~endianess:`Big ~size:`U32) + (MValue.of_bytes ~endianess:`Little o') + ); + assert false + end + in + + Progress.with_reporter (bar op.name total) (fun f -> + Iter.iter + (fun _ -> test (); f 1) + (Iter.(--) 0 (total-1)) + ) + +(* -------------------------------------------------------------------- *) +let test_vpadd_16u16 () = + let op = { + name = "vpadd_16u16"; + args = List.make 2 `M256; + mk = (fun rs -> let x, y = as_seq2 rs in C.vpadd_16u16 x y); + reff = call_m256x2_m256 Avx2.mm256_add_epi16; + } in + + test_vp 10000 op + +(* -------------------------------------------------------------------- *) +let test_vpadd_32u8 () = + let op = { + name = "vpadd_32u8"; + args = List.make 2 `M256; + mk = (fun rs -> let x, y = as_seq2 rs in C.vpadd_32u8 x y); + reff = call_m256x2_m256 Avx2.mm256_add_epi8; + } in + + test_vp 10000 op + +(* -------------------------------------------------------------------- *) +let test_vpsub_16u16 () = + let op = { + name = "vpsub_16u16"; + args = List.make 2 `M256; + mk = (fun rs -> let x, y = as_seq2 rs in C.vpsub_16u16 x y); + reff = call_m256x2_m256 Avx2.mm256_sub_epi16; + } in + + test_vp 10000 op + +(* -------------------------------------------------------------------- *) +let test_vpsub_32u8 () = + let op = { + name = "vpsub_32u8"; + args = List.make 2 `M256; + mk = (fun rs -> let x, y = as_seq2 rs in C.vpsub_32u8 x y); + reff = call_m256x2_m256 Avx2.mm256_sub_epi8; + } in + + test_vp 10000 op + +(* -------------------------------------------------------------------- *) +let test_vpsra_16u16 () = + let op (offset : int) = { + name = Format.sprintf "vpsra_16u16<%d>" offset; + args = [`M256]; + mk = (fun rs -> C.vpsra_16u16 (as_seq1 rs) offset); + reff = call_m256_m256 (fun x -> Avx2.mm256_srai_epi16 x offset); + } in + + Iter.iter (fun i -> test_vp 10000 (op i)) (Iter.(--) 0x00 0x10) + +(* -------------------------------------------------------------------- *) +let test_vpsrl_16u16 () = + let op (offset : int) = { + name = Format.sprintf "vpsrl_16u16<%d>" offset; + args = [`M256]; + mk = (fun rs -> C.vpsrl_16u16 (as_seq1 rs) offset); + reff = call_m256_m256 (fun x -> Avx2.mm256_srli_epi16 x offset); + } in + + Iter.iter (fun i -> test_vp 10000 (op i)) (Iter.(--) 0x00 0x10) + +(* -------------------------------------------------------------------- *) +let test_vpand_256 () = + let op = { + name = "vpand_256"; + args = List.make 2 `M256; + mk = (fun rs -> let x, y = as_seq2 rs in C.vpand_256 x y); + reff = call_m256x2_m256 Avx2.mm256_and_si256; + } in + + test_vp 10000 op + +(* -------------------------------------------------------------------- *) +let test_vpmulh_16u16 () = + let op = { + name = "vpmulh_16u16"; + args = List.make 2 `M256; + mk = (fun rs -> let x, y = as_seq2 rs in C.vpmulh_16u16 x y); + reff = call_m256x2_m256 Avx2.mm256_mulhi_epi16; + } in + + test_vp 200 op + +(* -------------------------------------------------------------------- *) +let test_vpmulhu_16u16 () = + let op = { + name = "vpmulhu_16u16"; + args = List.make 2 `M256; + mk = (fun rs -> let x, y = as_seq2 rs in C.vpmulhu_16u16 x y); + reff = call_m256x2_m256 Avx2.mm256_mulhi_epu16; + } in + + test_vp 200 op + +(* -------------------------------------------------------------------- *) +let test_vpmulhrs_16u16 () = + let op = { + name = "vpmulhrs_16u16"; + args = List.make 2 `M256; + mk = (fun rs -> let x, y = as_seq2 rs in C.vpmulhrs_16u16 x y); + reff = call_m256x2_m256 Avx2.mm256_mulhrs_epi16; + } in + + test_vp 200 op + +(* -------------------------------------------------------------------- *) +let test_vpackus_16u16 () = + let op = { + name = "vpackus_16u16"; + args = List.make 2 `M256; + mk = (fun rs -> let x, y = as_seq2 rs in C.vpackus_16u16 x y); + reff = call_m256x2_m256 Avx2.mm256_packus_epi16; + } in + + test_vp 10000 op + +(* -------------------------------------------------------------------- *) +let test_vpackss_16u16 () = + let op = { + name = "vpackss_16u16"; + args = List.make 2 `M256; + mk = (fun rs -> let x, y = as_seq2 rs in C.vpackss_16u16 x y); + reff = call_m256x2_m256 Avx2.mm256_packs_epi16; + } in + + test_vp 10000 op + +(* -------------------------------------------------------------------- *) +let test_vpmaddubsw_256 () = + let op = { + name = "vpmaddubsw_256"; + args = List.make 2 `M256; + mk = (fun rs -> let x, y = as_seq2 rs in C.vpmaddubsw_256 x y); + reff = call_m256x2_m256 Avx2.mm256_maddubs_epi16; + } in + + test_vp 200 op + +(* -------------------------------------------------------------------- *) +let test_vpermd () = + let op = { + name = "vpermd"; + args = List.make 2 `M256; + mk = (fun rs -> let x, y = as_seq2 rs in C.vpermd x y); + reff = call_m256x2_m256 (fun x y -> Avx2.mm256_permutevar8x32_epi32 y x); + } in + + test_vp 10000 op + +(* -------------------------------------------------------------------- *) +let test_vpermq () = + let op (imm8 : int) = { + name = Format.sprintf "vpermq<%d>" imm8; + args = [`M256]; + mk = (fun rs -> C.vpermq (as_seq1 rs) imm8); + reff = call_m256_m256 (fun x -> Avx2.mm256_permute4x64_epi64 x imm8); + } in + + test_vp 10000 (op 0x23); + test_vp 10000 (op 0xf7) + +(* -------------------------------------------------------------------- *) +let test_vbshufb_256 () = + let op = { + name = "vbshufb_256"; + args = List.make 2 `M256; + mk = (fun rs -> let x, y = as_seq2 rs in C.vpshufb_256 x y); + reff = call_m256x2_m256 Avx2.mm256_shuffle_epi8; + } in + + test_vp 10000 op + +(* -------------------------------------------------------------------- *) +let test_vpcmpgt_16u16 () = + let op = { + name = "vpcmpgt_16u16"; + args = List.make 2 `M256; + mk = (fun rs -> let x, y = as_seq2 rs in C.vpcmpgt_16u16 x y); + reff = call_m256x2_m256 Avx2.mm256_cmpgt_epi16; + } in + + test_vp 10000 op + +(* -------------------------------------------------------------------- *) +let test_vpmovmskb_u256u64 () = + let op = { + name = "vpmovmskb_u256u64"; + args = [`M256]; + mk = (fun rs -> C.uextend ~size:256 (C.vpmovmskb_u256u64 (as_seq1 rs))); + reff = (fun vs -> + match vs with + | [M256 v] -> + let out = Avx2.mm256_movemask_epi8 v in + let out = Int64.logand (Int64.of_int32 out) 0xffffffffL in + M256 (Avx2.M256.oftuple_64 (out, 0L, 0L, 0L)) + | _ -> + assert false + ) + } in + + test_vp 10000 op + +(* -------------------------------------------------------------------- *) +let test_vpunpckl_32u8 () = + let op = { + name = "test_vpunpckl_32u8"; + args = List.make 2 `M256; + mk = (fun rs -> let x, y = as_seq2 rs in C.vpunpckl_32u8 x y); + reff = call_m256x2_m256 Avx2.mm256_unpacklo_epi8; + } in + + test_vp 10000 op + +(* -------------------------------------------------------------------- *) +let test_vpunpckl_4u64 () = + let op = { + name = "test_vpunpckl_4u64"; + args = List.make 2 `M256; + mk = (fun rs -> let x, y = as_seq2 rs in C.vpunpckl_4u64 x y); + reff = call_m256x2_m256 Avx2.mm256_unpacklo_epi64; + } in + + test_vp 10000 op + +(* -------------------------------------------------------------------- *) +let test_vpunpckh_4u64 () = + let op = { + name = "test_vpunpckh_4u64"; + args = List.make 2 `M256; + mk = (fun rs -> let x, y = as_seq2 rs in C.vpunpckh_4u64 x y); + reff = call_m256x2_m256 Avx2.mm256_unpackhi_epi64; + } in + + test_vp 10000 op + +(* -------------------------------------------------------------------- *) +let test_vmovsldup_256 () = + let op = { + name = "test_vmovsldup_256"; + args = List.make 1 `M256; + mk = (fun rs -> let x = as_seq1 rs in C.vmovsldup_256 x); + reff = call_m256_m256 Avx2.mm256_moveldup_ps; + } in + + test_vp 10000 op + +(* -------------------------------------------------------------------- *) +let test_vpblend_16u16 () = + let op (imm8 : int) = { + name = Format.sprintf "test_vpblend_16u16<%d>" imm8; + args = List.make 2 `M256; + mk = (fun rs -> let x, y = as_seq2 rs in C.vpblend_16u16 x y imm8); + reff = call_m256x2_m256 (fun x y -> Avx2.mm256_blend_epi16 x y imm8); + } in + + test_vp 10000 (op 0x00); + test_vp 10000 (op 0x3f); + test_vp 10000 (op 0xaa) + +(* -------------------------------------------------------------------- *) +let test_vpblend_8u32 () = + let op (imm8 : int) = { + name = Format.sprintf "test_vpblend_8u32<%d>" imm8; + args = List.make 2 `M256; + mk = (fun rs -> let x, y = as_seq2 rs in C.vpblend_8u32 x y imm8); + reff = call_m256x2_m256 (fun x y -> Avx2.mm256_blend_epi32 x y imm8); + } in + + test_vp 10000 (op 0xaa) + + (* -------------------------------------------------------------------- *) +let test_vperm2i128 () = + let op (imm8 : int) = { + name = Format.sprintf "test_vperm2i128<%d>" imm8; + args = List.make 2 `M256; + mk = (fun rs -> let x, y = as_seq2 rs in C.vperm2i128 x y imm8); + reff = call_m256x2_m256 (fun x y -> Avx2.mm256_permute2x128_si256 x y imm8); + } in + + test_vp 10000 (op 32); + test_vp 10000 (op 49) + +(* -------------------------------------------------------------------- *) +let test_extracti128 () = + let op (i : int) = { + name = Format.sprintf "test_extracti128<%d>" i; + args = [`M256]; + mk = (fun rs -> C.vpextracti128 (as_seq1 rs) i); + reff = call_m256_m128 (fun x -> Avx2.mm256_extracti128_si256 x i); + } in + + test_vp 10000 (op 0); + test_vp 10000 (op 1) + +(* -------------------------------------------------------------------- *) +let test_inserti128 () = + let op (i : int) = { + name = Format.sprintf "test_inserti128<%d>" i; + args = [`M256; `M128]; + mk = (fun rs -> let x, y = as_seq2 rs in C.vpinserti128 x y i); + reff = call_m256_m128_m256 (fun x y -> Avx2.mm256_inserti128_si256 x y i); + } in + + test_vp 10000 (op 0); + test_vp 10000 (op 1) + +(* -------------------------------------------------------------------- *) +let test_bvueq () = + let op (size : int) : op = + let module M = (val Word.sword ~size) in + + let sim (x : int) (y : int) : int = + if x = y then 1 else 0 + in + + { name = (Format.sprintf "bvueq<%d>" size) + ; args = List.make 2 (size, `U) + ; out = `U + ; mk = (fun rs -> let x, y = as_seq2 rs in [|C.bvueq x y|]) + ; reff = (fun vs -> let x, y = as_seq2 vs in sim x y) + } + + in test (op 9) + +(* -------------------------------------------------------------------- *) +let test_bvseq () = + let op (size : int) : op = + let module M = (val Word.sword ~size) in + + let sim (x : int) (y : int) : int = + if x = y then 1 else 0 + in + + { name = (Format.sprintf "bvseq<%d>" size) + ; args = List.make 2 (size, `S) + ; out = `U + ; mk = (fun rs -> let x, y = as_seq2 rs in [|C.bvseq x y|]) + ; reff = (fun vs -> let x, y = as_seq2 vs in sim x y) + } + + in test (op 9) + +(* -------------------------------------------------------------------- *) +let test_mod () = + let op (size : int) : op = + let module M = (val Word.uword ~size) in + + let sim (x : int) (y : int) : int = + M.to_int @@ M.mod_ (M.of_int x) (M.of_int y) + in + + { name = (Format.sprintf "mod<%d>" size) + ; args = List.make 2 (size, `U) + ; out = `U + ; mk = (fun rs -> let x, y = as_seq2 rs in C.umod x y) + ; reff = (fun vs -> let x, y = as_seq2 vs in sim x y) + } + + in test (op 9) + +(* -------------------------------------------------------------------- *) +let test_smod () = + let op (size : int) : op = + let module M = (val Word.sword ~size) in + + let sim (x : int) (y : int) : int = + M.to_int @@ M.mod_ (M.of_int x) (M.of_int y) + in + + { name = (Format.sprintf "smod<%d>" size) + ; args = List.make 2 (size, `S) + ; out = `S + ; mk = (fun rs -> let x, y = as_seq2 rs in C.smod x y) + ; reff = (fun vs -> let x, y = as_seq2 vs in sim x y) + } + + in + for i = 1 to 9 do + test (op i) + done + +(* -------------------------------------------------------------------- *) +let tests = [ +(* + ("opp" , test_opp ); + ("incr", test_incr); + ("add" , test_add ); + ("sub" , test_sub ); + ("umul", test_umul); + ("smul", test_smul); + ("ssat", test_ssat); + ("usat", test_usat); + + ("sgt", test_sgt); + ("sge", test_sge); + + ("ugt", test_ugt); + ("uge", test_uge); + + ("lsl", (fun () -> test_shift ~side:`L ~sign:`U)); + ("lsr", (fun () -> test_shift ~side:`R ~sign:`U)); + ("rol", (fun () -> test_rot ~side:`L)); + ("ror", (fun () -> test_rot ~side:`R)); + + ("asl", (fun () -> test_shift ~side:`L ~sign:`S)); + ("asr", (fun () -> test_shift ~side:`R ~sign:`S)); + + ("smul_u8_s8", test_smul_u8_s8); + + ("uextend", test_uextend); + ("sextend", test_sextend); + + ("ite", test_ite); + + ("udiv", test_udiv); + ("sdiv", test_sdiv); + + ("umod", test_umod); + ("smod", test_smod); + + ("bvueq", test_bvueq); + ("bvseq", test_bvseq); + + ("popcount", test_popcount); +*) + ("vpadd_16u16" , test_vpadd_16u16 ); + ("vpadd_32u8" , test_vpadd_32u8 ); + ("vpsub_16u16" , test_vpsub_16u16 ); + ("vpsub_32u8" , test_vpsub_32u8 ); + ("vmovsldup_256" , test_vmovsldup_256 ); + ("vpblend_8u32" , test_vpblend_8u32 ); + ("vpunpckh_4u64" , test_vpunpckh_4u64 ); + ("vpunpckl_4u64" , test_vpunpckl_4u64 ); + ("vperm2i128" , test_vperm2i128 ); + ("vpsra_16u16" , test_vpsra_16u16 ); + ("vpsrl_16u16" , test_vpsrl_16u16 ); + ("vpand_256" , test_vpand_256 ); + ("vpmulh_16u16" , test_vpmulh_16u16 ); + ("vpmulhu_16u16" , test_vpmulhu_16u16 ); + ("vpmulhrs_16u16" , test_vpmulhrs_16u16 ); + ("vpackus_16u16" , test_vpackus_16u16 ); + ("vpackss_16u16" , test_vpackss_16u16 ); + ("vpmaddubsw_256" , test_vpmaddubsw_256 ); + ("vpermd" , test_vpermd ); + ("vpermq" , test_vpermq ); + ("vbshufb_256" , test_vbshufb_256 ); + ("vpcmpgt_16u16" , test_vpcmpgt_16u16 ); + ("vpmovmskb_u256u64", test_vpmovmskb_u256u64); + ("vpunpckl_32u8" , test_vpunpckl_32u8 ); + ("vpblend_16u16" , test_vpblend_16u16 ); + ("vpextracti128" , test_extracti128 ); + ("vpinserti128" , test_inserti128 ); +] + +(* -------------------------------------------------------------------- *) +let main () = + let tests = + let n = Array.length Sys.argv in + if n <= 1 then + List.map snd tests + else + let names = Array.sub Sys.argv 1 (n - 1) in + let names = Set.of_array names in + let tests = List.filter (fun (x, _) -> Set.mem x names) tests in + List.map snd tests in + + Random.self_init (); + + List.iter (fun f -> f ()) tests + +(* -------------------------------------------------------------------- *) +let () = main () diff --git a/libs/lospecs/tests/simde b/libs/lospecs/tests/simde new file mode 160000 index 0000000000..0efee69e5c --- /dev/null +++ b/libs/lospecs/tests/simde @@ -0,0 +1 @@ +Subproject commit 0efee69e5c16185cad512aefe503b812167e15fe diff --git a/libs/lospecs/typing.ml b/libs/lospecs/typing.ml new file mode 100644 index 0000000000..9601b67f2a --- /dev/null +++ b/libs/lospecs/typing.ml @@ -0,0 +1,646 @@ +(* -------------------------------------------------------------------- *) +open Ptree +open Ast + +exception DestrError of string + +(* -------------------------------------------------------------------- *) +let as_seq1 (type t) (xs : t list) : t = + match xs with [ x ] -> x | _ -> raise (DestrError "as_seq1") + +(* -------------------------------------------------------------------- *) +let as_seq2 (type t) (xs : t list) : t * t = + match xs with [ x; y ] -> (x, y) | _ -> raise (DestrError "as_seq2") + +(* FIXME: check where used and catch error if needed *) + +(* -------------------------------------------------------------------- *) +module Env : sig + type env + + type sig_ = aword list option * atype + + val empty : env + val lookup : env -> symbol -> (ident * sig_) option + val push : env -> symbol -> sig_ -> env * ident + val export : env -> (symbol, ident * sig_) Map.t +end = struct + type sig_ = aword list option * atype + + type env = { vars : (symbol, ident * sig_) Map.t } + + let empty : env = { vars = Map.empty } + + let lookup (env : env) (x : symbol) = Map.find_opt x env.vars + + let push (env : env) (x : symbol) (sig_ : sig_) = + let idx = Ident.create x in + let env = { vars = Map.add x (idx, sig_) env.vars } in + (env, idx) + + let export (env : env) : (symbol, ident * sig_) Map.t = env.vars +end + +(* -------------------------------------------------------------------- *) +type env = Env.env + +(* -------------------------------------------------------------------- *) +let tt_pword (_ : env) ({ data = `W ty } : pword) : aword = `W ty + +(* -------------------------------------------------------------------- *) +exception TypingError of range * string + +(* -------------------------------------------------------------------- *) +let mk_tyerror_r (rg : range) (f : exn -> 'a) msg = + let buf = Buffer.create 0 in + let fbuf = Format.formatter_of_buffer buf in + Format.kfprintf + (fun _ -> + Format.pp_print_flush fbuf (); + f (TypingError (rg, Buffer.contents buf))) + fbuf msg + +(* -------------------------------------------------------------------- *) +let mk_tyerror (range : range) msg = + mk_tyerror_r range identity msg + +(* -------------------------------------------------------------------- *) +let tyerror (range : range) msg = + mk_tyerror_r range (fun e -> raise e) msg + +(* -------------------------------------------------------------------- *) +let tt_type (_ : env) (t : ptype) : atype = + (t.data :> atype) + +(* -------------------------------------------------------------------- *) +let tt_type_parameters + (env : env) + (range : range) + (who : symbol) + ~(expected : int) + (tp : pword list option) += + match tp with + | None -> tyerror range "missing type parameters annotation" + | Some tp -> + let tplen = List.length tp in + if expected <> tplen then begin + tyerror range + "invalid number of type parameters for `%s': expected %d, got %d" + who expected tplen + end; + (List.map (tt_pword env) tp) + +(* -------------------------------------------------------------------- *) +let check_arguments_count (range : range) ~(expected : int) (args : pexpr list) = + if List.length args <> expected then + tyerror range "invalid number of arguments"; + args + +(* -------------------------------------------------------------------- *) +let check_plain_arg (_ : env) (arg : pexpr option loced) = + match arg.data with + | None -> begin + tyerror + arg.range + "this argument cannot be generalized (not in a higher-order context)" + end + | Some arg -> + arg + +(* -------------------------------------------------------------------- *) +let as_int_constant (e : pexpr) : int = + match e.data with + | PEInt (i, None) -> i + | _ -> tyerror e.range "integer constant expected" + +(* -------------------------------------------------------------------- *) +type sig_ = { + s_name : string; + s_ntyparams : int; + s_argsty : aword list -> aword list; + s_retty : aword list -> aword; + s_mk : aword list -> aexpr list -> aexpr_; +} + +(* -------------------------------------------------------------------- *) +module Sigs : sig + val sla : sig_ + val sra : sig_ + val sll : sig_ + val srl : sig_ + val usat : sig_ + val ssat : sig_ + val uextend : sig_ + val sextend : sig_ + val not : sig_ + val incr : sig_ + val add : sig_ + val ssadd : sig_ + val usadd : sig_ + val sub : sig_ + val and_ : sig_ + val or_ : sig_ + val xor_ : sig_ + val umul : sig_ + val umullo : sig_ + val umulhi : sig_ + val smul : sig_ + val smullo : sig_ + val smulhi : sig_ + val usmul : sig_ + val sgt : sig_ + val sge : sig_ + val ugt : sig_ + val uge : sig_ + val popcount : sig_ +end = struct + let mk1 (f : aexpr -> aexpr_) (a : aexpr list) = + f (as_seq1 a) + + let mk2 (f : aexpr -> aexpr -> aexpr_) (a : aexpr list) = + let x, y = as_seq2 a in f x y + + let uniop ?(ret = fun x -> x) ~(name : string) mk = { + s_name = name; + s_ntyparams = 1; + s_argsty = (fun ws -> [as_seq1 ws]); + s_retty = (fun ws -> `W (ret (get_size (as_seq1 ws)))); + s_mk = fun ws -> mk1 (mk ws); + } + + let binop ?(ret = fun x -> x) ~(name : string) mk = { + s_name = name; + s_ntyparams = 1; + s_argsty = (fun ws -> List.make 2 (as_seq1 ws)); + s_retty = (fun ws -> `W (ret (get_size (as_seq1 ws)))); + s_mk = fun ws -> mk2 (mk ws); + } + + let satop ~(name : string) (k : us) = { + s_name = name; + s_ntyparams = 2; + s_argsty = (fun ws -> [fst (as_seq2 ws)]); + s_retty = (fun ws -> snd (as_seq2 ws)); + s_mk = (fun ws -> mk1 (fun x -> ESat (k, snd (as_seq2 ws), x))); + } + + let extendop ~(name : string) (k : us) = { + s_name = name; + s_ntyparams = 2; + s_argsty = (fun ws -> [fst (as_seq2 ws)]); + s_retty = (fun ws -> snd (as_seq2 ws)); + s_mk = (fun ws -> mk1 (fun x -> EExtend (k, snd (as_seq2 ws), x))); + } + + let shiftop ~(name : string) (d : lr) (k : la) = { + s_name = name; + s_ntyparams = 1; + s_argsty = (fun ws -> [as_seq1 ws; `W 8]); + s_retty = (fun ws -> as_seq1 ws); + s_mk = (fun _ -> mk2 (fun x y -> EShift (d, k, (x, y)))); + } + + let mulop ?ret ~(name : string) (k : mulk) = + let mk = fun ws x y -> + let w = as_seq1 ws in + EMul (k, w, (x, y)) + in + binop ?ret ~name mk + + let sla : sig_ = + shiftop ~name:"sla" `L `A + + let sra : sig_ = + shiftop ~name:"sra" `R `A + + let sll : sig_ = + shiftop ~name:"sll" `L `L + + let srl : sig_ = + shiftop ~name:"srl" `R `L + + let usat : sig_ = + satop ~name:"usat" `U + + let ssat : sig_ = + satop ~name:"ssat" `S + + let uextend : sig_ = + extendop ~name:"uextend" `U + + let sextend : sig_ = + extendop ~name:"sextend" `S + + let not : sig_ = + let mk = fun ws x -> ENot (as_seq1 ws, x) in + uniop ~name:"not" mk + + let incr : sig_ = + let mk = fun ws x -> EIncr (as_seq1 ws, x) in + uniop ~name:"incr" mk + + let add : sig_ = + let mk = fun ws x y -> EAdd (as_seq1 ws, `Word, (x, y)) in + binop ~name:"add" mk + + let ssadd : sig_ = + let mk = fun ws x y -> EAdd (as_seq1 ws, `Sat `S, (x, y)) in + binop ~name:"ssadd" mk + + let usadd : sig_ = + let mk = fun ws x y -> EAdd (as_seq1 ws, `Sat `U, (x, y)) in + binop ~name:"usadd" mk + + let sub : sig_ = + let mk = fun ws x y -> ESub (as_seq1 ws, (x, y)) in + binop ~name:"sub" mk + + let and_ : sig_ = + let mk = fun ws x y -> EAnd (as_seq1 ws, (x, y)) in + binop ~name:"and" mk + + let or_ : sig_ = + let mk = fun ws x y -> EOr (as_seq1 ws, (x, y)) in + binop ~name:"or" mk + + let umul : sig_ = + mulop ~ret:(fun n -> 2 * n) ~name:"umul" (`U `D) + + let umulhi : sig_ = + mulop ~name:"umulhi" (`U `H) + + let umullo : sig_ = + mulop ~name:"umullo" (`U `L) + + let smul : sig_ = + mulop ~ret:(fun n -> 2 * n) ~name:"smul" (`S `D) + + let smulhi : sig_ = + mulop ~name:"smulhi" (`S `H) + + let smullo : sig_ = + mulop ~name:"smullo" (`S `L) + + let usmul : sig_ = + mulop ~ret:(fun n -> 2 * n) ~name:"usmul" `US + + let sgt : sig_ = + let mk = fun ws x y -> ECmp (as_seq1 ws, `S, `Gt, (x, y)) in + binop ~ret:(fun _ -> 1) ~name:"sgt" mk + + let sge : sig_ = + let mk = fun ws x y -> ECmp (as_seq1 ws, `S, `Ge, (x, y)) in + binop ~ret:(fun _ -> 1) ~name:"sge" mk + + let ugt : sig_ = + let mk = fun ws x y -> ECmp (as_seq1 ws, `U, `Gt, (x, y)) in + binop ~ret:(fun _ -> 1) ~name:"ugt" mk + + let uge : sig_ = + let mk = fun ws x y -> ECmp (as_seq1 ws, `U, `Ge, (x, y)) in + binop ~ret:(fun _ -> 1) ~name:"uge" mk + + let xor_ : sig_ = + let mk = fun ws x y -> EXor (as_seq1 ws, (x, y)) in + binop ~name:"xor" mk + + let popcount = { + s_name = "popcount"; + s_ntyparams = 2; + s_argsty = (fun ws -> [fst (as_seq2 ws)]); + s_retty = (fun ws -> snd (as_seq2 ws)); + s_mk = (fun ws -> mk1 (fun x -> EPopCount (snd (as_seq2 ws), x))); + } +end + +(* -------------------------------------------------------------------- *) +let sigs : sig_ list = [ + Sigs.sla; + Sigs.sra; + Sigs.sll; + Sigs.srl; + Sigs.usat; + Sigs.ssat; + Sigs.uextend; + Sigs.sextend; + Sigs.not; + Sigs.incr; + Sigs.add; + Sigs.ssadd; + Sigs.usadd; + Sigs.sub; + Sigs.and_; + Sigs.or_; + Sigs.xor_; + Sigs.umul; + Sigs.umullo; + Sigs.umulhi; + Sigs.smul; + Sigs.smullo; + Sigs.smulhi; + Sigs.usmul; + Sigs.sgt; + Sigs.sge; + Sigs.ugt; + Sigs.uge; + Sigs.popcount; +] + +(* -------------------------------------------------------------------- *) +let get_sig_of_name (name : string) : sig_ option = + List.find_opt (fun x -> x.s_name = name) sigs + +(* -------------------------------------------------------------------- *) +let ty_compatible ~(src : atype) ~(dst : atype) : bool = + match src, dst with + | (`Signed | `Unsigned), `W _ -> true + | _, _ -> src = dst + +(* -------------------------------------------------------------------- *) +let join_types (ty1 : atype loced) (ty2 : atype loced) = + match ty1.data, ty2.data with + | `Unsigned, `W n -> `W n + | `W n, `Unsigned -> `W n + | _, _ -> + if ty1.data <> ty2.data then + tyerror + (Lc.merge ty1.range ty2.range) + "the branches of the conditional have incompatible types: %a / %a" + pp_atype ty1.data pp_atype ty2.data + else ty1.data + +(* -------------------------------------------------------------------- *) +let rec tt_expr_ (env : env) (e : pexpr) : aargs option * aexpr = + match e.data with + | PEParens e -> + (None, tt_expr env e) + + | PEInt (i, w) -> + let w = Option.map (tt_pword env) w in + let type_ = Option.default `Unsigned (w :> atype option) in + let e = { node = EInt i; type_; } in + (None, e) + + | PEFun (fargs, f) -> + let benv, args = tt_args env fargs in + (Some args, tt_expr benv f) + + | PEFName { data = (v, None) } -> begin + let (vid, (targs, vt)) = Option.get_exn + (Env.lookup env (Lc.unloc v)) + (mk_tyerror v.range "unknown variable: %s" (Lc.unloc v)) in + + match targs with + | None -> + (None, { node = EVar vid; type_ = vt; }) + + | Some targs -> + let ftargs = + List.map (fun ty -> (Ident.create "_", ty)) targs in + let args = + List.map + (fun (x, ty) -> { node = EVar x; type_ = (ty :> atype) }) + ftargs in + (Some ftargs, { node = EApp (vid, args); type_ = vt; }) + end + + | PEFName { data = (v, Some ws) } -> + let sig_ = + Option.get_exn + (get_sig_of_name (Lc.unloc v)) + (mk_tyerror v.range "unkown symbol: %s" (Lc.unloc v)) + in + + let ws = List.map (tt_pword env) ws in + let args = sig_.s_argsty ws in + let retty = sig_.s_retty ws in + let args = List.map (fun ty -> (Ident.create "_", ty)) args in + + let eargs = + List.map (fun (x, ty) -> + { node = EVar x; type_ = (ty :> atype); } + ) args + in + let node = sig_.s_mk ws eargs in + (Some args, { node; type_ = (retty :> atype); }) + + | PELet ((v, args, e1), e2) -> + let args, e1 = + let env, args = + args + |> Option.map (tt_args env) + |> Option.map (fun (e, a) -> (e, Some a)) + |> Option.default (env, None) in + let e1 = tt_expr env e1 in + (args, e1) + in + + let ebody, vid = + let targs = Option.map (List.map snd) args in + Env.push env (Lc.unloc v) (targs, e1.type_) in + + let e2 = tt_expr ebody e2 in + + let node = ELet ((vid, args, e1), e2) in + let type_ = e2.type_ in + + (None, { node; type_; }) + + | PECond (c, (pe1, pe2)) -> + let c = tt_expr env c in (* FIXME: must be a word *) + let e1 = tt_expr env pe1 in + let e2 = tt_expr env pe2 in + + let type_ = + join_types + (Lc.mk pe1.range e1.type_) + (Lc.mk pe2.range e2.type_) + in + + let e1 = { e1 with type_ } in + let e2 = { e2 with type_ } in + + let node = ECond (c, (e1, e2)) in + + (None, { node; type_; }) + + | PESlice (ev, (start, len, scale)) -> + let ev = tt_expr env ev in + let start = tt_expr env start in + let len = Option.default 1 (Option.map as_int_constant len) in + let scale = Option.default 1 (Option.map as_int_constant scale) in + let node = ESlice (ev, (start, len, scale)) + and type_ = `W (len * scale) in + (None, { node; type_; }) + + | PEAssign (ev, (start, len, scale), v) -> + let ev = tt_expr env ev in + let start = tt_expr env start in + let len = Option.default 1 (Option.map as_int_constant len) in + let scale = Option.default 1 (Option.map as_int_constant scale) in + let v = tt_expr env ~check:(`W (len * scale)) v in + let node = EAssign (ev, (start, len, scale), v) in + (None, { node; type_ = ev.type_; }) + + | PEApp ({ data = (f, None) }, args) -> + let (vid, (targs, vt)) = Option.get_exn + (Env.lookup env (Lc.unloc f)) + (mk_tyerror f.range "unknown symbol: %s" (Lc.unloc f)) in + + let targs = + Option.get_exn + targs + (mk_tyerror f.range "the symbol `%s' cannot be applied" (Lc.unloc f)) in + + if List.length args <> List.length targs then begin + tyerror e.range + "invalid number of arguments: expected %d, got %d" + (List.length targs) (List.length args) + end; + + let bds, args = List.fold_left_map (fun bds (a, ety) -> + match a.data with + | None -> + let x = Ident.create "_" in + let a = { node = EVar x; type_ = (ety :> atype); } in + ((x, ety) :: bds, a) + | Some a -> + (bds, tt_expr env ~check:(ety :> atype) a) + ) [] (List.combine args targs) + in + + let bds = if List.is_empty bds then None else Some (List.rev bds) in + let node = EApp (vid, args) in + + (bds, { node; type_ = vt; }) + + | PEApp ({ data = ({ data = "concat" as f }, w) } as fn, args) -> + let (`W w) = as_seq1 (tt_type_parameters env fn.range f ~expected:1 w) in + let args = List.map (check_plain_arg env) args in + let targs = List.map (tt_expr env ~check:(`W w)) args in + let wsz = `W (w * List.length targs) in + (None, { node = EConcat (wsz, targs); type_ = wsz; }) + + | PEApp ({ data = ({ data = "repeat" as f }, w) } as fn, args) -> + let (`W w) = as_seq1 (tt_type_parameters env fn.range f ~expected:1 w) in + let args = List.map (check_plain_arg env) args in + let e, n = as_seq2 (check_arguments_count e.range ~expected:2 args) in + let n = as_int_constant n in + let ne = tt_expr env ~check:(`W w) e in + (None, { node = ERepeat (`W (w * n), (ne, n)); type_ = `W (w * n); }) + + | PEApp ({ data = ({ data = "map" as c }, w) } as cn, args) -> + let `W w, `W n = as_seq2 (tt_type_parameters env cn.range c ~expected:2 w) in + let args = List.map (check_plain_arg env) args in + + if List.is_empty args then + tyerror e.range "the combinator `map' takes at least one argument"; + + let f, args = (List.hd args, List.tl args) in + let nargs = List.map (tt_expr ~check:(`W (w * n)) env) args in + + let ftargs, ftbody = tt_expr_ env f in + + let ftype = + match ftbody.type_ with + | `W k -> k + | _ -> tyerror f.range "the mapped function should return a word" in + + let ftargs = + Option.get_exn + ftargs + (mk_tyerror f.range "this expression must be higher-order") in + + let targs = List.map snd ftargs in + + if targs <> List.make (List.length args) (`W w) then begin + tyerror e.range + "the mapped function must take exactly %d arguments of type @%d" + (List.length targs) w + end; + + let node = EMap ((`W w, `W n), (ftargs, ftbody), nargs) + and type_ = `W (n * ftype) in + (None, { node; type_; }) + + | PEApp ({ data = (f, Some ws) } as fn, args) -> + let sig_ = + Option.get_exn + (get_sig_of_name (Lc.unloc f)) + (mk_tyerror f.range "unknown symbol: %s" (Lc.unloc f)) + in + tt_fname_app env e.range sig_ (Lc.mk fn.range ws) args + +(* -------------------------------------------------------------------- *) +and tt_fname_app + (env : env) + (range : range) + (sig_ : sig_) + (ws : pword list loced) + (args : pexpr option loced list) += + let ws = + tt_type_parameters + env ws.range sig_.s_name ~expected:sig_.s_ntyparams + (Some ws.data) + in + + let targs = sig_.s_argsty ws in + + if List.length args <> List.length targs then begin + tyerror range + "invalid number of arguments for `%s': expected %d, get %d" + sig_.s_name (List.length targs) (List.length args) + end; + + let bds, args = List.fold_left_map (fun bds (a, ety) -> + match a.data with + | None -> + let x = Ident.create "_" in + let a = { node = EVar x; type_ = (ety :> atype); } in + ((x, ety) :: bds, a) + | Some a -> + (bds, tt_expr env ~check:(ety :> atype) a) + ) [] (List.combine args targs) + in + + let bds = if List.is_empty bds then None else Some (List.rev bds) in + + let node = sig_.s_mk ws args in + let type_ = (sig_.s_retty ws :> atype) in + + (bds, { node; type_; }) + +(* -------------------------------------------------------------------- *) +and tt_expr (env : env) ?(check : atype option) (p : pexpr) : aexpr = + let (args, {node = n_; type_ = t;}) = tt_expr_ env p in + if not (Option.is_none args) then + tyerror p.range "high-order functions not allowed here"; + check |> Option.may (fun dst -> + if not (ty_compatible ~src:t ~dst) then begin + tyerror p.range + "this expression has type %a but is expected to have type %a" + pp_atype t pp_atype dst + end); + { node = n_; type_ = Option.default t check; } + +(* -------------------------------------------------------------------- *) +and tt_arg (env : env) ((x, { data = `W ty }) : parg) : env * aarg = + let env, idx = Env.push env (Lc.unloc x) (None, `W ty) in + (env, (idx, `W ty)) + +(* -------------------------------------------------------------------- *) +and tt_args (env : env) (args : pargs) : env * aargs = + List.fold_left_map tt_arg env args + +(* -------------------------------------------------------------------- *) +let tt_def (env : env) (p : pdef) : symbol * adef = + let env, args = tt_args env p.args in + let rty = tt_pword env p.rty in + let bod = tt_expr env ~check:(rty :> atype) p.body in + (p.name, { name = p.name; arguments = args; body = bod; rettype = rty; }) + +(* -------------------------------------------------------------------- *) +let tt_program (env : env) (p : pprogram) : (symbol * adef) list = + List.map (tt_def env) p diff --git a/libs/lospecs/word.ml b/libs/lospecs/word.ml new file mode 100644 index 0000000000..70601c824d --- /dev/null +++ b/libs/lospecs/word.ml @@ -0,0 +1,193 @@ +(* -------------------------------------------------------------------- *) +module type S = sig + type t + + val nbits : int + + val zero : t + val one : t + + val neg : t -> t + val add : t -> t -> t + val sub : t -> t -> t + val mul : t -> t -> t + val div : t -> t -> t + + val lognot : t -> t + val logand : t -> t -> t + val logor : t -> t -> t + val logxor : t -> t -> t + + val shiftl : t -> int -> t + val shiftr : t -> int -> t + + val abs : t -> t + + val of_int : int -> t + val to_int : t -> int + + val mod_ : t -> t -> t +end + +(* -------------------------------------------------------------------- *) +module type Size = sig + val nbits : int +end + +(* -------------------------------------------------------------------- *) +module SWord(I : Size) : S = struct + type t = int + + let () = assert (I.nbits < Sys.int_size) + + let nbits = I.nbits + + let of_int (x : int) : t = + x lsl (Sys.int_size - nbits) + + let to_int (x : t) : int = + x asr (Sys.int_size - nbits) + + let mask : int = + (1 lsl nbits) - 1 + + let zero : t = + of_int 0 + + let one : t = + of_int 1 + + let add (x : t) (y : t) = + x + y + + let sub (x : t) (y : t) = + x - y + + let neg (x : t) : t = + -x + + let mul (x : t) (y : t) : t = + (to_int x) * y + + let div (x : t) (y : t) : t = + of_int (x / y) + + let logand (x : t) (y : t) : t = + x land y + + let logor (x : t) (y : t) : t = + x lor y + + let logxor (x : t) (y : t) : t = + (x lxor y) land (of_int mask) + + let lognot (x : t) : t = + logxor x (of_int (-1)) + + let shiftl (x : t) (y : int) : t = + x lsl y + + let shiftr (x : t) (y : t) : t = + (x asr y) land (of_int mask) + + let abs (x : t) : t = + abs x + + (* Careful with size *) + let urem (x : t) (y : t) : t = + assert (Sys.int_size - nbits >= 1); + let x = x lsr 1 in + let y = y lsr 1 in + (x mod y) lsl 1 + + let mod_ (x: t) (y: t) : t = + if y = zero then x else + let u = urem (abs x) (abs y) in + if u = zero + then u + else if (x >= zero) && (y >= zero) + then u + else if (x < zero) && (y >= zero) + then (-u + y) + else if (x >= zero) && (y < zero) + then (u + y) + else -u + +end + +(* -------------------------------------------------------------------- *) +module UWord(I : Size) : S = struct + type t = int + + let () = assert (I.nbits < Sys.int_size) + + let nbits = I.nbits + + let mask : int = + (1 lsl nbits) - 1 + + let of_int (x : int) : t = + x land mask + + let to_int (x : t) : int = + x + + let zero : t = + of_int 0 + + let one : t = + of_int 1 + + let add (x : t) (y : t) = + of_int (x + y) + + let sub (x : t) (y : t) = + of_int (x - y) + + let neg (x : t) : t = + of_int (-x) + + let mul (x : t) (y : t) = + of_int (x * y) + + let div (x : t) (y : t) : t = + of_int (x / y) + + let logand (x : t) (y : t) : t = + x land y + + let logor (x : t) (y : t) : t = + x lor y + + let logxor (x : t) (y : t) = + x lxor y + + let lognot (x : t) : t = + x lxor mask + + let shiftl (x : t) (y : int) = + of_int (x lsl y) + + let shiftr (x : t) (y : int) = + x lsr y + + let abs (x : t) : t = + x + + let mod_ (x: t) (y : t) : t = + if y = 0 then x else x mod y +end + +(* -------------------------------------------------------------------- *) +let sword ~(size : int) : (module S) = + (module SWord(struct let nbits = size end)) + +(* -------------------------------------------------------------------- *) +let uword ~(size : int) : (module S) = + (module UWord(struct let nbits = size end)) + +(* -------------------------------------------------------------------- *) +let word ~(sign : [`U | `S]) ~(size : int) : (module S) = + match sign with + | `U -> uword ~size + | `S -> sword ~size diff --git a/libs/lospecs/word.mli b/libs/lospecs/word.mli new file mode 100644 index 0000000000..6871239ed9 --- /dev/null +++ b/libs/lospecs/word.mli @@ -0,0 +1,37 @@ +(* -------------------------------------------------------------------- *) +module type S = sig + type t + + val nbits : int + + val zero : t + val one : t + + val neg : t -> t + val add : t -> t -> t + val sub : t -> t -> t + val mul : t -> t -> t + val div : t -> t -> t + + val lognot : t -> t + val logand : t -> t -> t + val logor : t -> t -> t + val logxor : t -> t -> t + + val shiftl : t -> int -> t + val shiftr : t -> int -> t + + val abs : t -> t + + val of_int : int -> t + val to_int : t -> int + + val mod_ : t -> t -> t +end + +(* -------------------------------------------------------------------- *) +val sword : size:int -> (module S) +val uword : size:int -> (module S) + +(* -------------------------------------------------------------------- *) +val word : sign:[`U | `S] -> size:int -> (module S) diff --git a/src/dune b/src/dune index 75cb7e8abc..151bf4edfe 100644 --- a/src/dune +++ b/src/dune @@ -15,7 +15,7 @@ (public_name easycrypt.ecLib) (foreign_stubs (language c) (names eunix)) (modules :standard \ ec) - (libraries batteries camlp-streams dune-build-info dune-site inifiles markdown markdown.html pcre2 tyxml why3 yojson zarith) + (libraries batteries camlp-streams dune-build-info dune-site inifiles lospecs markdown markdown.html pcre2 tyxml why3 yojson zarith) ) (executable diff --git a/src/ec.ml b/src/ec.ml index 48da43b802..0175b65170 100644 --- a/src/ec.ml +++ b/src/ec.ml @@ -415,6 +415,7 @@ let main () = gccompact : int option; docgen : bool; outdirp : string option; + specs : spec_options; } end in @@ -471,7 +472,8 @@ let main () = ; eco = false ; gccompact = None ; docgen = false - ; outdirp = None } + ; outdirp = None + ; specs = cliopts.clio_specs; } end @@ -500,7 +502,8 @@ let main () = ; eco = cmpopts.cmpo_noeco ; gccompact = cmpopts.cmpo_compact ; docgen = false - ; outdirp = None } + ; outdirp = None + ; specs = cmpopts.cmpo_specs; } end @@ -536,6 +539,10 @@ let main () = lazy (T.from_channel ~name (open_in name)) in + let nospec = { + files = []; + } in + { prvopts = prvoff ; input = Some name ; terminal = terminal @@ -543,7 +550,8 @@ let main () = ; eco = true ; gccompact = None ; docgen = true - ; outdirp = docopts.doco_outdirp } + ; outdirp = docopts.doco_outdirp + ; specs = nospec; } end in @@ -650,6 +658,7 @@ let main () = EcCommands.cm_provers = state.prvopts.prvo_provers; EcCommands.cm_profile = state.prvopts.prvo_profile; EcCommands.cm_iterate = state.prvopts.prvo_iterate; + EcCommands.cm_specs = state.specs.files; } in let checkproof = not state.docgen in diff --git a/src/ecBigInt.ml b/src/ecBigInt.ml index a9a8b5a845..8788ce3035 100644 --- a/src/ecBigInt.ml +++ b/src/ecBigInt.ml @@ -74,6 +74,12 @@ module ZImpl : EcBigIntCore.TheInterface = struct let to_why3 (x : zint) = Why3.BigInt.of_string (to_string x) + + let to_zt (x: zint) : Z.t = + x + + let of_zt (z: Z.t) : zint = + z end (* -------------------------------------------------------------------- *) @@ -150,6 +156,12 @@ module BigNumImpl : EcBigIntCore.TheInterface = struct let to_why3 (x : zint) = Why3.BigInt.of_string (to_string x) + + let to_zt (x: zint) : Z.t = + x |> to_string |> Z.of_string + + let of_zt (z: Z.t) : zint = + z |> Z.to_string |> of_string end (* -------------------------------------------------------------------- *) diff --git a/src/ecBigIntCore.ml b/src/ecBigIntCore.ml index 39d9391478..07ee40d242 100644 --- a/src/ecBigIntCore.ml +++ b/src/ecBigIntCore.ml @@ -64,4 +64,6 @@ module type TheInterface = sig val pp_print : Format.formatter -> zint -> unit val to_why3 : zint -> Why3.BigInt.t + val to_zt: zint -> Z.t + val of_zt: Z.t -> zint end diff --git a/src/ecCircuits.ml b/src/ecCircuits.ml new file mode 100644 index 0000000000..1e3976c2c6 --- /dev/null +++ b/src/ecCircuits.ml @@ -0,0 +1,1193 @@ +(* -------------------------------------------------------------------- *) +open EcUtils +open EcBigInt +open EcPath +open EcEnv +open EcAst +open EcCoreFol +open EcIdent +open LDecl +open EcLowCircuits + +(* -------------------------------------------------------------------- *) +module Map = Batteries.Map +module Hashtbl = Batteries.Hashtbl +module Set = Batteries.Set +module Option = Batteries.Option + +(* -------------------------------------------------------------------- *) +module C_ = struct + include Lospecs.Aig + include Lospecs.Circuit + include Lospecs.Circuit_spec +end + +module HL = struct + include Lospecs.Hlaig +end + +(* -------------------------------------------------------------------- *) +let debug : bool = EcLowCircuits.debug + +(* -------------------------------------------------------------------- *) +let circ_red (hyps: hyps) = let base_red = EcReduction.full_red in + {base_red with delta_p = (fun pth -> + if (EcEnv.Circuit.reverse_operator (LDecl.toenv hyps) pth |> List.is_empty) then + base_red.delta_p pth + else + `No) +} + +module AInvFHash = struct + let combine = Why3.Hashcons.combine + + type t = form + + let known_hashes : (int, int) Map.t ref = ref Map.empty + + let clean_known : unit -> unit = + fun () -> known_hashes := Map.empty + + let bruijn_idents : (int, ident) Map.t ref = ref Map.empty + + let clean_bruijn_idents : unit -> unit = + fun () -> bruijn_idents := Map.empty + + let form_storage : (int, form) Map.t ref = ref Map.empty + + let clean_form_storage : unit -> unit = + fun () -> form_storage := Map.empty + + let nuke_state_from_orbit : unit -> unit = + fun () -> + clean_known (); + clean_bruijn_idents (); + clean_form_storage () + + let ident_of_debruijn_level (i: int) : ident = + match Map.find_opt i !bruijn_idents with + | Some id -> id + | None -> let id = create (string_of_int i) in + bruijn_idents := Map.add i id !bruijn_idents; + id + + type state = { + level: int; + subst: EcSubst.subst; + } + + + let add_to_state (id: ident) (ty: ty) (st: state) = + let new_id = ident_of_debruijn_level st.level in + let level = st.level + 1 in + let subst = EcSubst.add_flocal st.subst id (f_local new_id ty) in + { level; subst }, new_id + + + let to_debruijn (f: form) : form = + let rec doit (st: state) (f: form) = + match f.f_node with + | Fquant (qnt, bnds, f) -> + let st, bnds = + List.fold_left_map (fun st (orig_id, gty) -> + match gty with + | GTty ty -> + let st, new_id = add_to_state orig_id ty st in + st, (new_id, gty) + | _ -> + st, (orig_id, gty) + ) st bnds + in f_quant qnt bnds (doit st (EcSubst.subst_form st.subst f)) + | Fif (cond, tb, fb) -> + let doit = doit st in + f_if (doit cond) (doit tb) (doit fb) + | Fmatch (_, _, _) -> assert false + | Flet (lp, value, body) -> + begin match lp with + | LSymbol (orig_id, ty) -> + let nval = doit st value in + let st, new_id = add_to_state orig_id ty st in + let nbody = doit st (EcSubst.subst_form st.subst body) in + f_let (LSymbol (new_id, ty)) nval nbody + | LTuple bnds -> + let nval = doit st value in + let st, new_ids = List.fold_left_map (fun st (id, ty) -> add_to_state id ty st) st bnds in + let nbody = doit st (EcSubst.subst_form st.subst body) in + let nbinds = List.combine new_ids (List.snd bnds) in + f_let (LTuple nbinds) nval nbody + | LRecord (_, _) -> assert false + end + | Fapp (op, args) -> + let nargs = List.map (doit st) args in + let nop = doit st op in + f_app nop nargs f.f_ty + | Ftuple comps -> + f_tuple (List.map (doit st) comps) + | Fproj (tp, i) -> + f_proj (doit st tp) i f.f_ty + | FhoareF { hf_m; hf_pr; hf_f; hf_po } -> + let npre = doit st hf_pr in + let npo = doit st hf_po in + let m = hf_m in + f_hoareF {inv=npre;m} hf_f {inv=npo;m} + | FhoareS { hs_m=(m, me); hs_pr; hs_s; hs_po } -> + let npre = doit st hs_pr in + let npo = doit st hs_po in + f_hoareS me {inv=npre;m} hs_s {inv=npo;m} + | FbdHoareF _ -> assert false + | FbdHoareS _ -> assert false + | FeHoareF _ -> assert false + | FeHoareS _ -> assert false + | FequivF { ef_ml; ef_mr; ef_pr; ef_fl; ef_fr; ef_po } -> + let npre = doit st ef_pr in + let npo = doit st ef_po in + f_equivF {inv=npre;ml=ef_ml;mr=ef_mr} ef_fl ef_fr {inv=npo;ml=ef_ml;mr=ef_mr} + | FequivS { es_ml=(ml, mel); es_mr=(mr, mer); es_pr; es_sl; es_sr; es_po } -> + let npre = doit st es_pr in + let npo = doit st es_po in + f_equivS mel mer {inv=npre;ml;mr} es_sl es_sr {inv=npo;ml;mr} + | FeagerF _ -> assert false + | Fpr _ -> assert false + | Fint _ + | Flocal _ + | Fpvar (_, _) + | Fglob (_, _) + | Fop (_, _) -> f + in + doit {level = 0; subst = EcSubst.empty} f + + + + let hash_form (f: form) = + match Map.find_opt f.f_tag !known_hashes with + | Some hash -> hash + | None -> let fnorm = to_debruijn f in + form_storage := Map.add f.f_tag fnorm !form_storage; + known_hashes := Map.add f.f_tag fnorm.f_tag !known_hashes; + fnorm.f_tag +end + +(* -------------------------------------------------------------------- *) +type width = int +exception CircError of string Lazy.t + +let rec ctype_of_ty (env: env) (ty: ty) : ctype = + match ty.ty_node with + | Ttuple tys -> CTuple (List.map (ctype_of_ty env) tys) + | Tconstr (pth, []) when pth = EcCoreLib.CI_Bool.p_bool -> CBool + | _ -> begin + match EcEnv.Circuit.lookup_array_and_bitstring env ty with + | Some ({size=(_, Some size_arr)}, {size=(_, Some size_bs)}) -> CArray {width=size_bs; count=size_arr} + | None -> + begin match EcEnv.Circuit.lookup_bitstring_size env ty with + | Some sz -> CBitstring sz + | _ -> + Format.eprintf "Missing binding for type %a@." + EcPrinting.(pp_type (PPEnv.ofenv env)) ty; + raise (CircError (lazy "Failed to convert EC type to Circuit type")) + end + | Some ({size = (_, None)}, _) -> + raise (CircError (lazy ("No concrete binding for array type " ^ (Format.asprintf "%a" EcPrinting.(pp_type PPEnv.(ofenv env)) ty)))) + | Some (_, {size = (_, None)}) -> + raise (CircError (lazy ("No concrete binding for bitstring type " ^ (Format.asprintf "%a" EcPrinting.(pp_type PPEnv.(ofenv env)) ty)))) + end + + +let width_of_type (env: env) (t: ty) : int = + let cty = ctype_of_ty env t in + EcLowCircuits.size_of_ctype cty + +(* Requires concrete bindings for both types *) +let destr_array_type (env: env) (t: ty) : (int * ty) option = + match EcEnv.Circuit.lookup_array_and_bitstring env t with + | Some ({size = (_, Some size)}, {type_; size = (_, Some _)}) -> Some (size, EcTypes.tconstr type_ []) + | _ -> None + +(* FIXME: Fix an order for array size parameters, this one goes against the rest *) +let shape_of_array_type (env: env) (t: ty) : (int * int) = + match ctype_of_ty env t with + | CArray {width=w; count=n} -> (n, w) + | _ -> raise (CircError (lazy "shape_of_array_type on non array type")) + +let input_of_type ~name (env: env) (t: ty) : circuit = + let ct = ctype_of_ty env t in + input_of_ctype ~name ct + +(* Should correspond to QF_ABV *) +module BVOps = struct + let temp_symbol = "temp_circ_input" + + let is_of_int (env: env) (p: path) : bool = + match EcEnv.Circuit.reverse_bitstring_operator env p with + | Some (_, `OfInt) -> true + | _ -> false + + let op_is_parametric_bvop (env: env) (op: path) : bool = + match EcEnv.Circuit.lookup_bvoperator_path env op with + | Some { kind = `ASliceGet _ } + | Some { kind = `ASliceSet _ } + | Some { kind = `Extract _ } + | Some { kind = `Insert _ } + | Some { kind = `Map _ } + | Some { kind = `Get _ } + | Some { kind = `AInit _ } + | Some { kind = `Init _ } -> true + | _ -> false + + let circuit_of_parametric_bvop (env : env) (op: [`Path of path | `BvBind of EcDecl.crb_bvoperator]) (args: arg list) : circuit = + let op = match op with + | `BvBind op -> op + | `Path p -> begin match EcEnv.Circuit.lookup_bvoperator_path env p with + | Some op -> op + | None -> raise (CircError (lazy ("No binding matching operator at path " ^ (EcPath.tostring p)) )) + end + in + circuit_of_parametric_bvop op args + + let op_is_bvop (env: env) (op : path) : bool = + match EcEnv.Circuit.lookup_bvoperator_path env op with + | Some { kind = `Add _ } | Some { kind = `Sub _ } + | Some { kind = `Mul _ } | Some { kind = `Div _ } + | Some { kind = `Rem _ } | Some { kind = `Shl _ } + | Some { kind = `Shr _ } | Some { kind = `Rol _ } + | Some { kind = `Shrs _ } | Some { kind = `Shls _ } + | Some { kind = `Ror _ } | Some { kind = `And _ } + | Some { kind = `Or _ } | Some { kind = `Xor _ } + | Some { kind = `Not _ } | Some { kind = `Lt _ } + | Some { kind = `Le _ } | Some { kind = `Extend _ } + | Some { kind = `Truncate _ } | Some { kind = `Concat _ } + | Some { kind = `A2B _ } | Some { kind = `B2A _ } + | Some { kind = `Opp _ } -> true + | Some { kind = + `ASliceGet _ + | `ASliceSet _ + | `Extract _ + | `Insert _ + | `Map _ + | `AInit _ + | `Get _ + | `Init _ } + | None -> false + + let circuit_of_bvop (env: env) (op: [`Path of path | `BvBind of EcDecl.crb_bvoperator]) : circuit = + let op = match op with + | `BvBind op -> op + | `Path p -> begin match EcEnv.Circuit.lookup_bvoperator_path env p with + | Some op -> op + | None -> raise (CircError (lazy ("No binding matching operator at path " ^ (EcPath.tostring p)))) + end + in + circuit_of_bvop op +end +open BVOps + +module BitstringOps = struct + type binding = crb_bitstring_operator + + let op_is_bsop (env: env) (op: path) : bool = + match EcEnv.Circuit.reverse_bitstring_operator env op with + | Some (_, `OfInt) -> true + | _ -> false + + let circuit_of_bsop (env: env) (op: [`Path of path | `BSBinding of binding]) (args: arg list) : circuit = + let bnd = match op with + | `BSBinding bnd -> bnd + | `Path p -> begin match EcEnv.Circuit.reverse_bitstring_operator env p with + | Some bnd -> bnd + | None -> raise (CircError (lazy ("No binding matching operator at path " ^ (EcPath.tostring p)))) + end + in + (* assert false => should be guarded by a previous call to op_is_bsop *) + match bnd with + | bs, `From -> assert false (* doesn't translate to circuit *) + | {size = (_, Some size)}, `OfInt -> begin match args with + | [ `Constant i ] -> + circuit_of_zint ~size i + | args -> raise (CircError (lazy (Format.asprintf "Bad arguments for bitstring of_int: expected (int) got (%a)" EcPrinting.(pp_list ", " pp_arg) args))) + end + | {size = (_, None)}, `OfInt -> + raise (CircError (lazy "No concrete binding for type of of_int@.")) (* FIXME: error messages *) + | bs, `To -> assert false (* doesn't translate to circuit *) + | bs, `ToSInt -> assert false (* doesn't translate to circuit *) + | bs, `ToUInt -> assert false (* doesn't translate to circuit *) +end +open BitstringOps + +module ArrayOps = struct + type binding = crb_array_operator + + + let op_is_arrayop (env: env) (op: path) : bool = + match EcEnv.Circuit.reverse_array_operator env op with + | Some (_, `Get) -> true + | Some (_, `Set) -> true + | Some (_, `OfList) -> true + | _ -> false + + let circuit_of_arrayop (env: env) (op: [`Path of path | `ABinding of binding]) (args: arg list) : circuit = + let op = match op with + | `ABinding bnd -> bnd + | `Path p -> begin match EcEnv.Circuit.reverse_array_operator env p with + | Some bnd -> bnd + | None -> raise (CircError (lazy ("No binding matching operator at path " ^ (EcPath.tostring p))) ) + end + in + (* assert false => should be guarded by a call to op_is_arrayop *) + match op with + | (arr, `ToList) -> assert false (* We do not translate this to circuit *) + | (arr, `Get) -> begin match args with + | [ `Circuit (({type_ = CArray _}, inps) as arr); `Constant i] -> + array_get arr (BI.to_int i) + | args -> + let err = lazy (Format.asprintf "Bad inputs to arr get: Expected (arr, idx) got (%a)" (EcPrinting.pp_list "," pp_arg) args) in + raise (CircError err) + end + (* FIXME: Check argument order *) + | ({size = (_, Some size)}, `OfList) -> begin match args with + | [ `Circuit dfl; `List cs ] -> array_oflist cs dfl size + | args -> + let err = lazy (Format.asprintf "Bad inputs to arr of_list: Expected (default, list) got (%a)" (EcPrinting.pp_list "," pp_arg) args) in + raise (CircError err) + end + | ({size = (_, None)}, `OfList) -> raise (CircError (lazy "Array of list with non-concrete size")) + | (_arr, `Set) -> begin match args with + | [ `Circuit (({type_ = CArray _}, _) as arr); + `Constant i; + `Circuit (({type_ = CBitstring _}, _) as bs) ] -> + array_set arr (BI.to_int i) bs + | args -> + let err = lazy (Format.asprintf "Bad inputs to arr set: Expected (arr, idx, new_val) got (%a)" (EcPrinting.pp_list "," pp_arg) args) in + raise (CircError err) + end +end +open ArrayOps + +(* Functions for dealing with uninitialized inputs *) +let circuit_uninit (env:env) (t: ty) : circuit = + circuit_uninit (ctype_of_ty env t) + +module CircuitSpec = struct + let circuit_from_spec env (c : [`Path of path | `Bind of EcDecl.crb_circuit ] ) : circuit = + let c = match c with + | `Path p -> begin match EcEnv.Circuit.reverse_circuit env p with + | Some c -> c + | None -> raise (CircError (lazy ("No spec binding for operator at path " ^ EcPath.(tostring p)))) + end + | `Bind c -> c + in + let _, name = (EcPath.toqsymbol c.operator) in + let op = EcEnv.Op.by_path c.operator env in + + let unroll_fty (ty: ty) : ty list * ty = + let rec doit (acc: ty list) (ty: ty) : ty list * ty = + try + let a, b = EcTypes.tfrom_tfun2 ty in + (doit (a::acc) b) + with + | EcTypes.TyDestrError "fun" -> List.rev acc, ty + in doit [] ty + in + + let arg_tys, ret_ty = unroll_fty op.op_ty in + let arg_tys = List.map (ctype_of_ty env) arg_tys in + let ret_ty = ctype_of_ty env ret_ty in + circuit_from_spec ~name (arg_tys, ret_ty) c.circuit + + let op_has_spec env pth = + Option.is_some @@ EcEnv.Circuit.reverse_circuit env pth +end +open CircuitSpec + +let op_is_base (env: env) (p: path) : bool = + op_is_bvop env p || + op_has_spec env p + +let circuit_of_baseop (env: env) (p: path) : circuit = + if op_is_bvop env p then + circuit_of_bvop env (`Path p) + else if op_has_spec env p then + circuit_from_spec env (`Path p) + else + assert false (* Should be guarded by call to op_is_base *) + +let op_is_parametric_base (env: env) (p: path) = + op_is_parametric_bvop env p || + op_is_arrayop env p || + op_is_bsop env p + +let circuit_of_parametric_baseop (env: env) (p: path) (args: arg list) : circuit = + if op_is_parametric_bvop env p then + circuit_of_parametric_bvop env (`Path p) args + else if op_is_arrayop env p then + circuit_of_arrayop env (`Path p) args + else if op_is_bsop env p then + circuit_of_bsop env (`Path p) args + else + assert false (* Should be guarded by call to op_is_parametric_base *) + +let circuit_of_op (env: env) (p: path) : circuit = + let op = try + EcEnv.Circuit.reverse_operator env p |> List.hd + with Failure _ -> + raise (CircError (lazy "Failed reverse operator")) + in + match op with + | `Bitstring (bs, op) -> assert false (* Should be guarded by a call to op_is_base *) + | `Array _ -> assert false (* Should be guarded by a call to op_is_parametric_base *) + | `BvOperator bvbnd -> circuit_of_bvop env (`BvBind bvbnd) + | `Circuit c -> circuit_from_spec env (`Bind c) + +let circuit_of_op_with_args (env: env) (p: path) (args: arg list) : circuit = + let op = try + EcEnv.Circuit.reverse_operator env p |> List.hd + with Failure _ -> + raise (CircError (lazy "Failed reverse operator")) + in + match op with + | `Bitstring bsbnd -> circuit_of_bsop env (`BSBinding bsbnd) args + | `Array abnd -> circuit_of_arrayop env (`ABinding abnd) args + | `BvOperator bvbnd -> circuit_of_parametric_bvop env (`BvBind bvbnd) args + | `Circuit c -> assert false (* FIXME PR: Do we want to have parametric operators coming from the spec? *) +(* ------------------------------ *) + + +(* FIXME: why are all these openings required? *) + +(* FIXME: move this to module? *) +let type_is_registered (env: env) (t: ty) : bool = + (Option.is_some (EcEnv.Circuit.lookup_array_and_bitstring env t)) || + (Option.is_some (EcEnv.Circuit.lookup_bitstring env t)) + +(* FIXME: Check if we need to reduce twice here *) +let int_of_form (hyps: hyps) (f: form) : zint = + let env = toenv hyps in + let redmode = circ_red hyps in + let f = + EcCallbyValue.norm_cbv redmode hyps f + in + match f.f_node with + | Fint i -> i + | _ -> begin + try destr_int @@ EcCallbyValue.norm_cbv EcReduction.full_red hyps f + with + DestrError "int" + | DestrError "destr_int" -> + let err = lazy (Format.asprintf "Failed to reduce form | %a | to integer" + (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f) in + raise (CircError err) + end + +let rec form_list_of_form ?(ppe: EcPrinting.PPEnv.t option) (f: form) : form list = + match destr_op_app f with + | (pc, _), [h; {f_node = Fop(p, _)}] when + pc = EcCoreLib.CI_List.p_cons && + p = EcCoreLib.CI_List.p_empty -> + [h] + | (pc, _), [h; t] when + pc = EcCoreLib.CI_List.p_cons -> + h::(form_list_of_form t) + | _ -> + if debug then Option.may (fun ppenv -> Format.eprintf "Failed to destructure claimed list: %a@." (EcPrinting.pp_form ppenv) f) ppe; + raise (CircError (lazy "Failed to destruct list")) + +let form_is_iter (f: form) : bool = + match f.f_node with + | Fapp ({f_node = Fop (p, _)}, _) when + p = EcCoreLib.CI_Int.p_iter || + p = EcCoreLib.CI_Int.p_fold || + p = EcCoreLib.CI_Int.p_iteri -> true + | _ -> false + +(* Expands iter, fold and iteri (for integer arguments) *) +let expand_iter_form (hyps: hyps) (f: form) : form = + let redmode = circ_red hyps in + let env = toenv hyps in + let ppenv = EcPrinting.PPEnv.ofenv env in + let (@!!) f fs = + EcTypesafeFol.fapply_safe ~redmode hyps f fs + in + + let res = match f.f_node with + | Fapp ({f_node = Fop (p, _)}, [rep; fn; base]) when p = EcCoreLib.CI_Int.p_iteri -> + let rep = int_of_form hyps rep in + let is = List.init (BI.to_int rep) BI.of_int in + if debug then Format.eprintf "Done generating functions!@."; + let f = List.fold_left (fun f i -> + if debug then Format.eprintf "Expanding iter... Step #%d@.Form: %a@." (BI.to_int i) + (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv (toenv hyps))) f + ; + fn @!! [f_int i; f] + ) base is in + f + | Fapp ({f_node = Fop (p, _)}, [rep; fn; base]) when p = EcCoreLib.CI_Int.p_iter -> + let rep = int_of_form hyps rep in + let is = List.init (BI.to_int rep) BI.of_int in + let f = List.fold_left (fun f i -> fn @!! [f]) base is in + f + | Fapp ({f_node = Fop (p, _)}, [fn; base; rep]) when p = EcCoreLib.CI_Int.p_fold -> + let rep = int_of_form hyps rep in + let is = List.init (BI.to_int rep) BI.of_int in + let f = List.fold_left (fun f i -> fn @!! [f]) base is in + f + | _ -> raise (CircError (lazy (Format.asprintf "Failed to destructure form for iter expansion %a" EcPrinting.(pp_form ppenv) f))) + in + if debug then Format.eprintf "Expanded iter form: @.%a@." EcPrinting.(pp_form ppenv) res; + res + +let circuit_of_form + ?(st : state = empty_state) (* Program variable values *) + (hyps : hyps) + (f_ : EcAst.form) + : hyps * circuit = + + (* Form level cache, local to each high-level call *) + (* Forms are only cached the second time they are seen *) + (* Only cache function application for now *) + let cache : (int, circuit) Map.t ref = ref Map.empty in +(* let seen_forms : int Set.t ref = ref Set.empty in *) + let fhash = AInvFHash.hash_form in + let op_cache : circuit Mp.t ref = ref Mp.empty in + let redmode = circ_red hyps in + let env = toenv hyps in + let ppe = EcPrinting.PPEnv.ofenv env in + let fapply_safe f fs = + let res = EcTypesafeFol.fapply_safe ~redmode hyps f fs in + res + in + let int_of_form (f: form) : zint = + int_of_form hyps f + in + + (* Supposed to be called on an apply *) + let propagate_integer_arguments (op: form) (args: form list) : form = + let op = + let pth, _ = destr_op op in + match (EcEnv.Op.by_path pth env).op_kind with + | OB_oper (Some (OP_Plain f)) -> + f + | _ -> + if debug then Format.eprintf "Failed to get body for op: %a (args: %a)\n" + (EcPrinting.pp_form ppe) op + (EcPrinting.(pp_list "," (pp_form ppe))) args; + raise (CircError (lazy "Failed to get body for op in propagate integer arg")) + in + let res = fapply_safe op args in + res + in + let rec arg_of_form (st: state) (f: form) : arg = + match f.f_ty with + (* FIXME: check this *) + | t when t.ty_node = EcTypes.tint.ty_node -> arg_of_zint (int_of_form f) + | t when type_is_registered env t -> + let f = doit st f in + arg_of_circuit f + | {ty_node = Tfun(i_t, c_t)} when + i_t.ty_node = EcTypes.tint.ty_node && + type_is_registered env c_t -> + arg_of_init (fun i -> +(* + let tm = Unix.gettimeofday () in + Format.eprintf "Generating lane %d of init@." i; + *) + let f = (fapply_safe f [f_int (BI.of_int i)]) in + (* + let tm2 = Unix.gettimeofday () in + Format.eprintf "Done applying form took: %f@." (tm2 -. tm); + *) + let res = doit st f in + (* + let tm3 = Unix.gettimeofday () in + Format.eprintf "Done generating lane took %f (%f total)@." (tm3 -. tm2) (tm3 -. tm); +*) + res + ) + | {ty_node = Tconstr(p, [t])} when + p = EcCoreLib.CI_List.p_list && + type_is_registered env t -> + let cs = List.map (fun f -> + doit st f) + (try + (form_list_of_form ~ppe f) + with + CircError _ -> + raise (CircError + (lazy (Format.asprintf "Failed to destructure %a as list when attempting to convert it to an argument" + EcPrinting.(pp_form ppe) f)))) + in + arg_of_circuits cs + | _ -> Format.eprintf "Failed to convert form to arg: %a@." EcPrinting.(pp_form ppe) f; + raise (CircError (lazy "Failed to convert arg to form")) + + (* State does not get backward propagated so it is not returned *) + and doit (st: state) (f_: form) : circuit = + match f_.f_node with + | Fint z -> raise (CircError (lazy "Translation encountered unexpected integer value")) + + (* Assumes no quantifier bindings/new inputs within if *) + | Fif (c_f, t_f, f_f) -> + let t = doit st t_f in + let f = doit st f_f in + let c = doit st c_f in + circuit_ite ~c ~t ~f + + | Flocal idn -> + state_get st idn + + | Fop (pth, _) -> + begin + if pth = EcCoreLib.CI_Witness.p_witness then + (if debug then Format.eprintf "Assigning witness to var of type %a@." + EcPrinting.(pp_type ppe) f_.f_ty; + circuit_uninit env (f_.f_ty)) + else + match Mp.find_opt pth !op_cache with + | Some op -> + op + | None -> + if op_is_base env pth then + let circ = try + circuit_of_op env pth + with + | CircError le -> Format.eprintf "(%s ->)" (EcPath.tostring pth); raise (CircError le) + in + op_cache := Mp.add pth circ !op_cache; + circ + else + let circ = match (EcEnv.Op.by_path pth env).op_kind with + | OB_oper (Some (OP_Plain f)) -> +(* if debug then Format.eprintf "[BDEP] Opening definition of function at path %s" (EcPath.tostring pth); *) + doit st f + | _ -> + begin match EcFol.op_kind (destr_op f_ |> fst) with + | Some `True -> + (circuit_true :> circuit) + | Some `False -> + (circuit_false :> circuit) + | _ -> + let err = lazy (Format.sprintf "Unsupported op kind%s@." (EcPath.tostring pth)) in + raise (CircError err) + end + in + op_cache := Mp.add pth circ !op_cache; + circ + end + | Fapp (f, fs) -> begin try + begin match Map.find_opt (fhash f_) !cache with + | Some circ -> +(* + Format.eprintf "Cache hit for form: %a@." + EcPrinting.(pp_form ppe) f_; +*) + circ + | None -> let circ = + (* TODO: find a way to properly propagate int arguments. Done? *) + begin match f with + | {f_node = Fop (pth, _)} when op_is_parametric_base env pth -> + let args = List.map (arg_of_form st) fs in + circuit_of_op_with_args env pth args + + (* For dealing with iter cases: *) + | {f_node = Fop _} when form_is_iter f_ -> + trans_iter st hyps f fs + | {f_node = Fop (p, _)} when not (List.for_all (fun f -> f.f_ty.ty_node <> EcTypes.tint.ty_node) fs) -> +(* if debug then Format.eprintf "Attempting to propagate interger arguments for op with path %s@." (EcPath.tostring p); *) + doit st (propagate_integer_arguments f fs) + | {f_node = Fop _} -> + (* Assuming correct types coming from EC *) + (* FIXME: Add some extra info about errors when something here throws *) + begin match EcFol.op_kind (destr_op f |> fst), fs with + | Some `Eq, [f1; f2] -> + let c1 = doit st f1 in + let c2 = doit st f2 in + (circuit_eq c1 c2 :> circuit) + | Some `Not, [f] -> + let c = doit st f in + circuit_not c + (* FIXME: Should this be here on inside the module? *) + | Some `True, [] -> + (circuit_true :> circuit) + | Some `False, [] -> + (circuit_false :> circuit) + | Some `Imp, [f1; f2] -> + let c1 = doit st f1 in + let c2 = doit st f2 in + (circuit_or (circuit_not c1) c2 :> circuit) + | Some (`And _), [f1; f2] -> + let c1 = doit st f1 in + let c2 = doit st f2 in + (circuit_and c1 c2 :> circuit) + | Some (`Or _), [f1; f2] -> + let c1 = doit st f1 in + let c2 = doit st f2 in + (circuit_or c1 c2 :> circuit) + | Some `Iff, [f1; f2] -> + let c1 = doit st f1 in + let c2 = doit st f2 in + (circuit_or (circuit_and c1 c2) (circuit_and (circuit_not c1) (circuit_not c2)) :> circuit) +(* | Some `Not, [f] -> doit st hyps (f_not f) *) + | _ -> (* recurse down into definition *) + let f_c = doit st f in + let fcs = List.map (doit st) fs in + circuit_compose f_c fcs + end + | _ -> (* recurse down into definition *) + let f_c = doit st f in + let fcs = List.map (doit st) fs in + circuit_compose f_c fcs + end + in +(* in (if Set.mem (fhash f_) !seen_forms *) +(* then *) + cache := Map.add (fhash f_) circ !cache; +(* else seen_forms := Set.add (fhash f_) !seen_forms); *) + circ + end + with CircError le -> + let err = lazy (Format.asprintf "Call %a\n%s" EcPrinting.(pp_form ppe) f (Lazy.force le)) in + raise (CircError err) + end + + | Fquant (qnt, binds, f) -> + let binds = List.map (fun (idn, t) -> (idn, gty_as_ty t |> ctype_of_ty env)) binds in (* FIXME *) + begin match qnt with + | Lforall + | Llambda -> circ_lambda_oneshot st binds (fun st -> doit st f) (* FIXME: look at this interaction *) + | Lexists -> raise (CircError (lazy "Universal/Existential quantification not supported")) + (* TODO: figure out how to handle quantifiers *) + end + + | Fproj (f, i) -> + let ftp = doit st f in + (circuit_tuple_proj ftp i :> circuit) + + | Fmatch (f, fs, ty) -> raise (CircError (lazy "Match not supported")) + + | Flet (LSymbol (id, t), v, f) -> + let vc = doit st v in + let st = update_state st id vc in + doit st f + + | Flet (LTuple vs, v, f) -> + let vc = doit st v in + let comps = circuits_of_circuit_tuple vc in + let st = List.fold_left2 (fun st (id, t) vc -> + update_state st id vc) + st + vs + comps + in doit st f + + | Fpvar (pv, mem) -> + let v = match pv with + | PVloc v -> v + (* FIXME: Should globals be supported? *) + | _ -> raise (CircError (lazy "Global vars not supported")) + in + let v = match state_get_pv_opt st mem v with + | Some v -> v + | None -> + if debug then Format.eprintf "Assigning unassigned program variable %a of type %a@." EcPrinting.(pp_pv ppe) pv EcPrinting.(pp_type ppe) f_.f_ty; + circuit_uninit env f_.f_ty (* Allow uninitialized program variables *) + in + v + + | Fglob (id, mem) -> raise (CircError (lazy "glob not supported")) + + | Ftuple comps -> + let comps = + List.map (fun comp -> doit st comp) comps + in +(* assert (List.for_all (circuit_is_free) (comps :> circuit list)); *) + (circuit_tuple_of_circuits comps :> circuit) + + | _ -> raise (CircError (lazy "Unsupported form kind in translation")) + + + and trans_iter (st: state) (hyps: hyps) (f: form) (fs: form list) = + (* FIXME: move auxiliary function out of the definitions *) + let redmode = circ_red hyps in + let env = toenv hyps in + let ppenv = EcPrinting.PPEnv.ofenv env in + let fapply_safe f fs = + let res = EcTypesafeFol.fapply_safe ~redmode hyps f fs in + res + in + match f, fs with + | ({f_node = Fop (p, _)}, [rep; fn; base]) when p = EcCoreLib.CI_Int.p_iteri -> + let rep = int_of_form rep in + let fs = List.init (BI.to_int rep) (fun i -> + fapply_safe fn [f_int (BI.of_int i)] + ) in + List.fold_lefti (fun f i fn -> + if debug then Format.eprintf "Translating iteri... Step #%d@." i; + let fn = doit st fn in + circuit_compose fn [f] + ) (doit st base) fs + (* FIXME PR: this is currently being implemented directly on circuits, do we want this case as well? *) + | ({f_node = Fop (p, _)}, [rep; fn; base]) when p = EcCoreLib.CI_Int.p_iter -> assert false + | ({f_node = Fop (p, _)}, [rep; fn; base]) when p = EcCoreLib.CI_Int.p_fold -> assert false + | _ -> raise (CircError (lazy (Format.asprintf "Failed to destr form %a to translate iter" EcPrinting.(pp_form ppenv) f))) + in +(* + let t0 = Unix.gettimeofday () in + let () = if debug then Format.eprintf "Translating form %a@\n" (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv (LDecl.toenv hyps))) f_ in +*) + + let f_c = doit st f_ in + +(* + let () = if debug then Format.eprintf "Took %.2f s to translate form : %a@." (Unix.gettimeofday () -. t0) + (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv (LDecl.toenv hyps))) f_ in +*) + + hyps, f_c + +let circuit_simplify_equality ?(do_time = true) ~(st: state) ~(hyps: hyps) ~(pres: circuit list) (f1: form) (f2: form) : bool = + let tm = ref (Unix.gettimeofday ()) in + let env = toenv hyps in + let time (env: env) (t: float ref) (msg: string) : unit = + let new_t = Unix.gettimeofday () in +(* EcEnv.notify ~immediate:true env `Info "[W] %s, took %f s@." msg (new_t -. !t); *) + Format.eprintf "[W] %s, took %f s@." msg (new_t -. !t); + t := new_t + in + + if debug then Format.eprintf "Filletting circuit...@."; + let c1 = circuit_of_form ~st hyps f1 |> snd |> state_close_circuit st in + if do_time then time env tm "Left side circuit generation done"; + let c2 = circuit_of_form ~st hyps f2 |> snd |> state_close_circuit st in + if do_time then time env tm "Right side circuit generation done"; + + let pres = List.map (state_close_circuit st) pres in (* Assumes pres come open *) + assert (Option.is_none @@ circuit_has_uninitialized c1); + assert (Option.is_none @@ circuit_has_uninitialized c2); + let posts = circuit_eqs c1 c2 in + if do_time then time env tm "Done with postcondition circuit generation"; + + if debug then Format.eprintf "Number of checks before batching: %d@." (List.length posts); + let posts = batch_checks ~mode:`BySub posts in + if debug then Format.eprintf "Number of checks after batching: %d@." (List.length posts); + if do_time then time env tm "Done with lane compression"; + if fillet_tauts pres posts then + begin + if do_time then time env tm "Done with equivalence checking (structural equality + SMT)"; + true + end + else + begin + if do_time then time env tm "Failed equivalence check"; + false + end + +let circuit_of_path (hyps: hyps) (p: path) : hyps * circuit = + let f = EcEnv.Op.by_path p (toenv hyps) in + let f = match f.op_kind with + | OB_oper (Some (OP_Plain f)) -> f + | _ -> raise (CircError (lazy "Invalid operator type")) + in + circuit_of_form hyps f + +let vars_of_memtype ?st (env: env) (mt : memtype) = + let Lmt_concrete lmt = mt in + List.filter_map (function + | { ov_name = Some name; ov_type = ty } -> + Some { v_name = name; v_type = ty; } + | _ -> None + ) (Option.get lmt).lmt_decl + + +let process_instr ?me (hyps: hyps) (mem: memory) ~(st: state) (inst: instr) : hyps * state = + let env = toenv hyps in + let env = match me with + | Some me -> EcEnv.Memory.push_active_ss me env + | None -> env + in +(* if debug then Format.eprintf "[W] Processing : %a@." (EcPrinting.pp_instr (EcPrinting.PPEnv.ofenv env)) inst; *) + (* let start = Unix.gettimeofday () in *) + try + match inst.i_node with + | Sasgn (LvVar (PVloc v, _ty), e) -> +(* + if debug then Format.eprintf "Assigning form %a to var %s@\n" + (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv (LDecl.toenv hyps))) (form_of_expr mem e) v; +*) + let hyps, c = ((ss_inv_of_expr mem e).inv |> circuit_of_form ~st hyps) in + let st = update_state_pv st mem v c in + hyps, st + (* if debug then Format.eprintf "[W] Took %f seconds@." (Unix.gettimeofday() -. start); *) + | Sasgn (LvTuple (vs), {e_node = Etuple es; _}) when List.compare_lengths vs es = 0 -> + let st = List.fold_left (fun (hyps, st) (v, e) -> + let hyps, c = ((ss_inv_of_expr mem e).inv |> circuit_of_form ~st hyps) in + let st = update_state_pv st mem v c in + hyps, st + ) (hyps, st) + (List.combine + (List.map (function + | (PVloc v, _ty) -> v + | _ -> raise (CircError (lazy "Failed to parse tuple assignment"))) vs) + es) in + st + | Sasgn (LvTuple (vs), e) -> + let hyps, tp = ((ss_inv_of_expr mem e).inv |> circuit_of_form ~st hyps) in + let comps = circuits_of_circuit_tuple tp in + let st = List.fold_left2 (fun st (pv, _ty) c -> + let v = match pv with + | PVloc v -> v + | _ -> raise (CircError (lazy "Global variables not supported")) + in + update_state_pv st mem v c + ) st vs (comps :> circuit list) + in + hyps, st + | _ -> + let err = lazy (Format.asprintf "Instruction not supported: %a@." + (EcPrinting.pp_instr (EcPrinting.PPEnv.ofenv env)) inst) in + raise (CircError err) + with + | CircError le -> + let err = lazy ( + Format.asprintf "BDep failed on instr: %a@.CircError:@.%s@.BACKTRACE: %s@.@." + (EcPrinting.pp_instr (EcPrinting.PPEnv.ofenv env)) inst + (Lazy.force le) + (Printexc.get_backtrace ())) in + raise (CircError err) + | e -> + (* FIXME: Bad handling *) + let err = lazy ( + Format.asprintf "BDep failed on instr: %a@.Exception thrown: %s@.BACKTRACE: %s@.@." + (EcPrinting.pp_instr (EcPrinting.PPEnv.ofenv env)) inst + (Printexc.to_string e) + (Printexc.get_backtrace ())) in + raise (CircError err) + +(* FIXME: check if memory is the right one in calls to state *) +let instrs_equiv + (hyps : hyps ) + ((mem, mt) : memenv ) + ?(keep : EcPV.PV.t option ) + ?(st : state = empty_state ) + (s1 : instr list ) + (s2 : instr list ) : bool += + let env = LDecl.toenv hyps in + + let rd, rglobs = EcPV.PV.elements (EcPV.is_read env (s1 @ s2)) in + let wr, wglobs = EcPV.PV.elements (EcPV.is_write env (s1 @ s2)) in + + if not (List.is_empty rglobs && List.is_empty wglobs) then + raise (CircError (lazy "the statements should not read/write globs")); + + if not (List.for_all (EcTypes.is_loc |- fst) (rd @ wr)) then + raise (CircError (lazy "the statements should not read/write global variables")); + + let inputs = List.map (fun (pv, ty) -> { v_name = EcTypes.get_loc pv; v_type = ty; }) (rd @ wr) in + let inputs = List.map (fun {v_name; v_type} -> (create v_name, ctype_of_ty env v_type)) inputs in + let st = open_circ_lambda st inputs in + + let hyps, st1 = List.fold_left (fun (hyps, st) -> process_instr hyps mem ~st) (hyps, st) s1 in + let hyps, st2 = List.fold_left (fun (hyps, st) -> process_instr hyps mem ~st) (hyps, st) s2 in + + let st1 = close_circ_lambda st1 in + let st2 = close_circ_lambda st2 in + (* FIXME: what was the intended behaviour for keep? *) + match keep with + | Some pv -> + let vs = EcPV.PV.elements pv |> fst in + let vs = List.map (function + | (PVloc v, ty) -> (v, ty) + | _ -> raise (CircError (lazy "global variables not supported")) + ) vs + in List.for_all (fun (var, ty) -> + let circ1 = state_get_pv_opt st1 mem var in + let circ2 = state_get_pv_opt st2 mem var in + match circ1, circ2 with + | None, None -> true + | None, Some circ1 + | Some circ1, None -> false (* Variable only defined on one of the blocks (and not in the prelude) *) + | Some circ1, Some circ2 -> circ_equiv circ1 circ2 + ) vs + | None -> state_get_all_memory st mem |> List.for_all (fun (var, _) -> + let circ1 = state_get_pv st1 mem var in + let circ2 = state_get_pv st2 mem var in + circ_equiv circ1 circ2 + ) + +(* FIXME: remove variable list from the arguments *) +(* FIXME: change memory -> memenv *) +let state_of_prog ?me (hyps: hyps) (mem: memory) ?(st: state = empty_state) (proc: instr list) (invs: variable list) : hyps * state = + let env = LDecl.toenv hyps in + let invs = List.map (fun {v_name; v_type} -> ((mem, v_name), ctype_of_ty env v_type)) invs in + let st = open_circ_lambda_pv st invs in + + let hyps, st = + List.fold_left (fun (hyps, st) -> process_instr ?me hyps mem ~st) (hyps, st) proc + in + hyps, close_circ_lambda st + +(* FIXME: refactor this function *) +let rec circ_simplify_form_bitstring_equality + ?(st: state = empty_state) + ?(pres: circuit list = []) + (hyps: hyps) + (f: form) + : form = + let env = toenv hyps in + + let rec check (f : form) = + match EcFol.sform_of_form f with + | SFeq (f1, f2) + when (Option.is_some @@ EcEnv.Circuit.lookup_bitstring env f1.f_ty) + || (Option.is_some @@ EcEnv.Circuit.lookup_array env f1.f_ty) + -> + f_bool (circuit_simplify_equality ~st ~hyps ~pres f1 f2) + | _ -> f_map (fun ty -> ty) check f + in check f + + +(* Mli stuff needed: *) +let compute ~(sign: bool) (c: circuit) (args: zint list) : zint = + match compute ~sign c (List.map (fun z -> arg_of_zint z) args) with + | Some z -> z + | None -> raise (CircError (lazy "Failed to reduce circuit to constant in compute")) + +let circ_equiv ?(pcond: circuit option) c1 c2 = + circ_equiv ?pcond c1 c2 + +let circ_sat = circ_sat +let circ_taut = circ_taut + +let circuit_permute (bsz: int) (perm: int -> int) (c: circuit) : circuit = + let c = match c with + | ({ type_= CBitstring _; reg = r}, inps) as c -> c + | _ -> assert false (* FIXME PR: currently only implemented for bitstring, do we want to expand this ? *) + in + (permute bsz perm c :> circuit) + +let circuit_mapreduce ?(perm : (int -> int) option) (c: circuit) (w_in: width) (w_out: width) : circuit list = + let c = match c, perm with + | ({type_ = CBitstring _}, inps) as c, None -> c + | ({type_ = CBitstring _}, inps) as c, Some perm -> permute w_out perm c + | _ -> assert false (* FIXME PR: currently only implemented for bitstring, do we want to expand this ? *) + in + (decompose w_in w_out c :> circuit list) + +let circuit_to_string ((circ, inps): circuit) : string = Format.asprintf "(%a => %a)" EcPrinting.(pp_list ", " pp_cinp) inps pp_circ circ +let circuit_ueq = (fun c1 c2 -> (circuit_eq c1 c2 :> circuit)) +let circuit_aggregate = + circuit_aggregate +let circuit_has_uninitialized = circuit_has_uninitialized + +let circuit_to_file = circuit_to_file + +let circuit_aggregate_inps = + circuit_aggregate_inputs + +let circuit_slice (c: circuit) (size: int) (offset: int) = + circuit_slice ~size c offset + +(* FIXME: this should use ids instead of strings *) +let circuit_align_inputs = + align_inputs + +let circuit_flatten ((circ, inps) as c: circuit) = + convert_type (CBitstring (size_of_ctype circ.type_)) c + +let state_get = state_get_pv +let state_get_opt = state_get_pv_opt +let state_get_all = fun st -> state_get_all_pv st |> List.snd + +(* (cbitstring_of_circuit ~strict:false c :> circuit) *) +let circuit_state_of_memenv ~(st: state) (env:env) ((m, mt): memenv) : state = + match mt with + | (Lmt_concrete Some {lmt_decl=decls}) -> + let bnds = List.map (fun {ov_name; ov_type} -> + match ov_name with + | Some v -> + begin try + Some ((m, v), ctype_of_ty env ov_type) + with CircError err -> + raise (CircError (lazy ( + (Format.asprintf "Failed for decl for var %s@." v) ^ Lazy.force err + ))) + end + | None -> None + ) decls in + open_circ_lambda_pv st (List.filter_map identity bnds) + | Lmt_concrete None -> st + + +(* Generally called without the optional argument, here just to see if we need it, + maybe remove later? FIXME *) +let circuit_state_of_hyps ?(strict = false) ?(use_mem = false) ?(st = empty_state) hyps : state = + let env = toenv hyps in + let ppe = EcPrinting.PPEnv.ofenv env in + let st = List.fold_left (fun st (id, lk) -> + if debug then Format.eprintf "Processing hyp: %s@." (id.id_symb); + match lk with +(* FIXME: Reasoning here is that we do not directly process program variables in the hyps + They are either given a value by assignment in the program or if they are used + before that they are implicitly initialized to BAD +*) + + | EcBaseLogic.LD_mem mt when use_mem -> circuit_state_of_memenv ~st env (id, mt) + + (* Initialized variable. + Check if body is convertible to circuit, if not just process it as uninitialized. + TODO: Maybe do a first pass on this, check convertibility and remove duplicates? *) + | EcBaseLogic.LD_var (t, Some f) -> + if debug then Format.eprintf "Assigning %a to %a@." EcPrinting.(pp_form ppe) f EcIdent.pp_ident id; + begin try + update_state st id (circuit_of_form ~st hyps f |> snd) + with CircError _ -> + try + open_circ_lambda st [(id, ctype_of_ty env t)] + with (CircError _) as e -> + if strict then raise e else st + end + + (* Uninitialized variable. + Treat as input *) + | EcBaseLogic.LD_var (t, None) -> + begin try + open_circ_lambda st [(id, ctype_of_ty env t)] + with (CircError _) as e -> + if strict then raise e else st end + + (* For things of the form a_ = a{&hr}, we assume the local variable takes precedence *) + | EcBaseLogic.LD_hyp f -> + if debug then Format.eprintf "Form hyp: %a@.Simplified: %a@." + EcPrinting.(pp_form ppe) f + EcPrinting.(pp_form ppe) (EcCallbyValue.norm_cbv (circ_red hyps) hyps f) + ; + begin match (EcCallbyValue.norm_cbv (circ_red hyps) hyps f) with + | {f_node=Fapp ({f_node = Fop (p, _); _}, [{f_node = Fpvar (PVloc pv, m); _}; fv])} + | {f_node=Fapp ({f_node = Fop (p, _); _}, [fv; {f_node = Fpvar (PVloc pv, m); _}])} when EcFol.op_kind p = Some `Eq -> + begin try + update_state_pv st m pv (circuit_of_form ~st hyps fv |> snd) + with CircError _ -> + st + end + | _ -> st + end + + + (* Some formula which we know to hold. Ignore for now? + TODO: FIXME: What to do with this in general? Maybe process it separately in another function + | EcBaseLogic.LD_hyp f_hyp -> + begin try + ignore (circuit_of_form ~st hyps f_hyp); + (f_imp f_hyp goal), [] + with e -> + if debug then Format.eprintf "Failed to convert hyp %a with error:@.%s@." + EcPrinting.(pp_form (PPEnv.ofenv (toenv hyps))) f_hyp (Printexc.to_string e); + (goal), [] + end +*) + | _ -> st + ) st (List.rev (tohyps hyps).h_local) + in + st + +let clear_translation_caches () = + EcLowCircuits.reset_backend_state (); + AInvFHash.nuke_state_from_orbit () diff --git a/src/ecCircuits.mli b/src/ecCircuits.mli new file mode 100644 index 0000000000..2240384997 --- /dev/null +++ b/src/ecCircuits.mli @@ -0,0 +1,70 @@ +(* -------------------------------------------------------------------- *) +open EcIdent +open EcSymbols +open EcAst +open EcEnv +open LDecl +open EcLowCircuits + +(* -------------------------------------------------------------------- *) +module Map = Batteries.Map + +(* -------------------------------------------------------------------- *) +exception CircError of string Lazy.t + +(* -------------------------------------------------------------------- *) +(* Utilities (figure out better name) *) +val circ_red : hyps -> EcReduction.reduction_info +val width_of_type : env -> ty -> int +val circuit_to_string : circuit -> string +val ctype_of_ty : env -> ty -> ctype + +(* State utilities *) +val state_get : state -> memory -> symbol -> circuit +val state_get_opt : state -> memory -> symbol -> circuit option +val state_get_all : state -> circuit list + +(* Create circuits *) +val input_of_type : name:[`Str of string | `Idn of ident | `Bad] -> env -> ty -> circuit + +(* Transform circuits *) +val circuit_ueq : circuit -> circuit -> circuit +val circuit_aggregate : circuit list -> circuit +val circuit_aggregate_inps : circuit -> circuit +val circuit_flatten : circuit -> circuit +val circuit_permute : int -> (int -> int) -> circuit -> circuit +val circuit_mapreduce : ?perm:(int -> int) -> circuit -> int -> int -> circuit list + +(* Use circuits *) +val compute : sign:bool -> circuit -> BI.zint list -> BI.zint +val circ_equiv : ?pcond:circuit -> circuit -> circuit -> bool +val circ_sat : circuit -> bool +val circ_taut : circuit -> bool + +(* Generate circuits *) +(* Form processors *) +val circuit_of_form : ?st:state -> hyps -> form -> hyps * circuit +val circuit_simplify_equality : ?do_time:bool -> st:state -> hyps:hyps -> pres:circuit list -> form -> form -> bool +val circ_simplify_form_bitstring_equality : + ?st:state -> + ?pres:circuit list -> hyps -> form -> form + +(* Proc processors *) +val state_of_prog : ?me:memenv -> hyps -> memory -> ?st:state -> instr list -> variable list -> hyps * state +val instrs_equiv : hyps -> memenv -> ?keep:EcPV.PV.t -> ?st:state -> instr list -> instr list -> bool +val process_instr : ?me:memenv -> hyps -> memory -> st:state -> instr -> hyps * state +(* val pstate_of_memtype : ?pstate:pstate -> env -> memtype -> pstate * cinput list *) + +val circuit_state_of_memenv : st:state -> env -> memenv -> state +val circuit_state_of_hyps : ?strict:bool -> ?use_mem:bool -> ?st:state -> hyps -> state + +(* Check for uninitialized inputs *) +val circuit_has_uninitialized : circuit -> int option + +val circuit_slice : circuit -> int -> int -> circuit +val circuit_align_inputs : circuit -> (int * int) option list -> circuit + +val circuit_to_file : name:string -> circuit -> symbol + +(* Imperative state clearing *) +val clear_translation_caches : unit -> unit diff --git a/src/ecCommands.ml b/src/ecCommands.ml index 9f647c52b4..2c2f284f4a 100644 --- a/src/ecCommands.ml +++ b/src/ecCommands.ml @@ -749,6 +749,14 @@ and process_dump scope (source, tc) = scope +(* -------------------------------------------------------------------- *) +and process_crbind (scope : EcScope.scope) (binding : pcrbinding) = + match binding.binding with + | CRB_Bitstring bs -> EcScope.Circuit.add_bitstring scope binding.locality bs + | CRB_Array ba -> EcScope.Circuit.add_array scope binding.locality ba + | CRB_BvOperator op -> EcScope.Circuit.add_bvoperator scope binding.locality op + | CRB_Circuit cr -> EcScope.Circuit.add_circuits scope binding.locality cr + (* -------------------------------------------------------------------- *) and process ?(src : string option) (ld : Loader.loader) (scope : EcScope.scope) g = let loc = g.pl_loc in @@ -793,6 +801,7 @@ and process ?(src : string option) (ld : Loader.loader) (scope : EcScope.scope) | Greduction red -> `Fct (fun scope -> process_reduction scope red) | Ghint hint -> `Fct (fun scope -> process_hint scope hint) | GdumpWhy3 file -> `Fct (fun scope -> process_dump_why3 scope file) + | Gcrbinding bind -> `Fct (fun scope -> process_crbind scope bind) with | `Fct f -> Some (f scope) | `State f -> f scope; None @@ -827,6 +836,7 @@ type checkmode = { cm_provers : string list option; cm_profile : bool; cm_iterate : bool; + cm_specs : string list; } let initial ~checkmode ~boot ~checkproof = @@ -852,6 +862,7 @@ let initial ~checkmode ~boot ~checkproof = scope [tactics; prelude] in let scope = EcScope.Prover.set_default scope poptions in + let scope = EcScope.Circuit.register_spec_files scope checkmode.cm_specs in let scope = if checkproof then begin if checkall then diff --git a/src/ecCommands.mli b/src/ecCommands.mli index f61a313f34..3805e89858 100644 --- a/src/ecCommands.mli +++ b/src/ecCommands.mli @@ -22,6 +22,7 @@ type checkmode = { cm_provers : string list option; cm_profile : bool; cm_iterate : bool; + cm_specs : string list; } val initial : checkmode:checkmode -> boot:bool -> checkproof:bool -> EcScope.scope diff --git a/src/ecCoreFol.ml b/src/ecCoreFol.ml index 03dd1f64ec..3ebe5a9712 100644 --- a/src/ecCoreFol.ml +++ b/src/ecCoreFol.ml @@ -182,6 +182,34 @@ let f_true = f_op EcCoreLib.CI_Bool.p_true [] tbool let f_false = f_op EcCoreLib.CI_Bool.p_false [] tbool let f_bool = fun b -> if b then f_true else f_false +(* -------------------------------------------------------------------- *) +(* TODO: check types here *) +let ty_ftlist1 ty = toarrow (List.make 1 ty) (tlist ty) +let ty_ftlist2 ty = toarrow ([ty; (tlist ty)]) (tlist ty) +let ty_flist1 ty = toarrow (List.make 1 (tlist ty)) (tlist ty) +let ty_flist2 ty = toarrow (List.make 2 (tlist ty)) (tlist ty) +let ty_fllist ty = toarrow (List.make 1 (tlist @@ tlist ty)) (tlist ty) +let ty_lmap ty1 ty2 = toarrow ([toarrow [ty1] ty2; tlist ty1]) (tlist ty2) +let ty_chunk ty = toarrow [tint; tlist ty] (tlist @@ tlist ty) +let ty_all ty = toarrow [(toarrow [ty] tbool); tlist ty] tbool + +let fop_empty ty = f_op EcCoreLib.CI_List.p_empty [ty] (tlist ty) +let fop_cons ty = f_op EcCoreLib.CI_List.p_cons [ty] (ty_ftlist2 ty) +let fop_append ty = f_op EcCoreLib.CI_List.p_append [ty] (ty_flist2 ty) +let fop_flatten ty = f_op EcCoreLib.CI_List.p_flatten [ty] (ty_fllist ty) +let fop_lmap ty1 ty2 = f_op EcCoreLib.CI_List.p_map [ty2; ty1] (ty_lmap ty1 ty2) +let fop_chunk ty = f_op EcCoreLib.CI_List.p_chunk [ty] (ty_chunk ty) +let fop_all ty = f_op EcCoreLib.CI_List.p_all [ty] (ty_all ty) + +let f_append a b ty = f_app (fop_append ty) [a; b] (tlist ty) +let f_cons a b ty = f_app (fop_cons ty) [a; b] (tlist ty) +let f_flatten a ty = f_app (fop_flatten ty) [a] (tlist ty) +let f_lmap f a ty1 ty2 = f_app (fop_lmap ty1 ty2) [f;a] (tlist ty2) +let f_chunk a (n: int) ty2 = + let ty = tfrom_tlist a.f_ty in + f_app (fop_chunk ty) [mk_form (Fint (BI.of_int n)) tint; a] (tlist @@ tlist ty) +let f_all f a ty = f_app (fop_all ty) [f; a] tbool + (* -------------------------------------------------------------------- *) let f_tuple args = match args with @@ -785,6 +813,8 @@ let is_op_not p = EcPath.p_equal EcCoreLib.CI_Bool.p_not p let is_op_imp p = EcPath.p_equal EcCoreLib.CI_Bool.p_imp p let is_op_iff p = EcPath.p_equal EcCoreLib.CI_Bool.p_iff p let is_op_eq p = EcPath.p_equal EcCoreLib.CI_Bool.p_eq p +let is_op_cons p = EcPath.p_equal EcCoreLib.CI_List.p_cons p +let is_op_witness p = EcPath.p_equal EcCoreLib.CI_Witness.p_witness p (* -------------------------------------------------------------------- *) let destr_op = function @@ -866,6 +896,22 @@ let destr_nots form = | Some form -> aux (not b) form in aux true form +let destr_cons form = + match destr_app form with + | {f_node = Fop (p, _)}, [h;t] when is_op_cons p -> (h, t) + | _ -> destr_error "cons" + +let destr_list form = + let rec aux form = + match try Some (destr_cons form) with DestrError "cons" -> None with + | Some (h, t) -> h::(aux t) + | None -> [] + in + try + let h, t = destr_cons form in + h::(aux t) + with DestrError "cons" -> raise (DestrError "list") + (* -------------------------------------------------------------------- *) let is_from_destr dt f = try ignore (dt f); true with DestrError _ -> false @@ -900,6 +946,8 @@ let is_bdHoareF f = is_from_destr destr_bdHoareF f let is_pr f = is_from_destr destr_pr f let is_eq_or_iff f = (is_eq f) || (is_iff f) +let is_witness f = is_from_destr (fun f -> destr_op f |> fst |> is_op_witness) f + (* -------------------------------------------------------------------- *) let split_args f = match f_node f with @@ -939,7 +987,9 @@ let rec form_of_expr_r ?m (e : expr) = | Evar pv -> begin match m with - | None -> failwith "expecting memory" + | None -> + Printexc.(get_callstack 100 |> print_raw_backtrace stderr); + failwith "expecting memory" | Some m -> (f_pvar pv e.e_ty m).inv end diff --git a/src/ecCoreFol.mli b/src/ecCoreFol.mli index 0977a33a50..8e13a547db 100644 --- a/src/ecCoreFol.mli +++ b/src/ecCoreFol.mli @@ -133,6 +133,32 @@ val f_eagerF : ts_inv -> stmt -> xpath -> xpath -> stmt -> ts_inv -> form val f_pr_r : pr -> form val f_pr : memory -> xpath -> form -> ss_inv -> form +(* FIXME: Check this V *) + +val ty_ftlist1 : ty -> ty +val ty_ftlist2 : ty -> ty +val ty_flist1 : ty -> ty +val ty_flist2 : ty -> ty +val ty_lmap : ty -> ty -> ty +val ty_chunk : ty -> ty +val ty_all : ty -> ty + +val fop_empty : ty -> form +val fop_cons : ty -> form +val fop_append : ty -> form +val fop_flatten : ty -> form +val fop_lmap : ty -> ty -> form +val fop_chunk : ty -> form +val fop_all : ty -> form + +val f_append : form -> form -> ty -> form +val f_cons : form -> form -> ty -> form +val f_flatten : form -> ty -> form +val f_lmap : form -> form -> ty -> ty -> form +val f_chunk : form -> int -> ty -> form +val f_all : form -> form -> ty -> form + + (* soft-constructors - unit *) val f_tt : form @@ -272,6 +298,10 @@ val destr_int : form -> zint val destr_glob : form -> EcIdent.t * memory val destr_pvar : form -> EcTypes.prog_var * memory +val destr_cons : form -> form * form +val destr_list : form -> form list +val is_witness : form -> bool + (* -------------------------------------------------------------------- *) val is_true : form -> bool val is_false : form -> bool diff --git a/src/ecCoreGoal.ml b/src/ecCoreGoal.ml index 74ff095f5b..6c5a3024d3 100644 --- a/src/ecCoreGoal.ml +++ b/src/ecCoreGoal.ml @@ -157,6 +157,7 @@ and validation = | VRewrite of (handle * rwproofterm) (* rewrite *) | VApply of proofterm (* modus ponens *) | VShuffle of ident list (* goal shuffling *) +| VBdep (* map-reduce *) (* external (hl/phl/prhl/...) proof-node *) | VExtern : 'a * handle list -> validation diff --git a/src/ecCoreGoal.mli b/src/ecCoreGoal.mli index f574b49bf3..d045b8f935 100644 --- a/src/ecCoreGoal.mli +++ b/src/ecCoreGoal.mli @@ -155,6 +155,7 @@ type validation = | VRewrite of (handle * rwproofterm) (* rewrite *) | VApply of proofterm (* modus ponens *) | VShuffle of ident list (* goal shuffling *) +| VBdep (* map-reduce *) (* external (hl/phl/prhl/...) proof-node *) | VExtern : 'a * handle list -> validation diff --git a/src/ecCoreLib.ml b/src/ecCoreLib.ml index 6e884a4e8c..e758718b53 100644 --- a/src/ecCoreLib.ml +++ b/src/ecCoreLib.ml @@ -48,6 +48,31 @@ module CI_Bool = struct let p_eq = _Pervasive "=" end +(* -------------------------------------------------------------------- *) +module CI_List = struct + let i_List = "List" + let p_List = EcPath.pqname p_top i_List + let _List = fun x -> EcPath.pqname p_List x + let p_list = _List "list" + + let p_empty = _List "[]" + let p_cons = _List "::" + let p_head = _List "head" + let p_behead = _List "behead" + let p_tail = p_behead + let p_append = _List "++" + let p_flatten = EcPath.pqname p_List "flatten" + let p_map = _List "map" + let p_mapi = _List "mapi" + let p_chunk = EcPath.pqname (EcPath.pqname (EcPath.pqname p_top "BitEncoding") "BitChunking") "chunk" + let p_all = _List "all" + let p_nth = _List "nth" + let p_size = _List "size" + let p_mkseq = _List "mkseq" + let p_mem = _List "mem" + let p_iota = EcPath.extend p_top ["List"; "Iota"; "iota_"] +end + (* -------------------------------------------------------------------- *) module CI_Option = struct let i_Option = "Logic" @@ -83,6 +108,8 @@ module CI_Int = struct let p_int_edivz = _IntDiv "edivz" let p_int_max = _IntDiv "max" let p_iteri = EcPath.extend p_top ["Int"; "IterOp"; "iteri"] + let p_iter = EcPath.extend p_top ["Int"; "IterOp"; "iter"] + let p_fold = EcPath.extend p_top ["Int"; "fold"] end (* -------------------------------------------------------------------- *) diff --git a/src/ecCoreLib.mli b/src/ecCoreLib.mli index 49ff1a9405..79f8d07936 100644 --- a/src/ecCoreLib.mli +++ b/src/ecCoreLib.mli @@ -49,6 +49,31 @@ module CI_Option : sig val p_oget : path end + +(*-------------------------------------------------------------------- *) +module CI_List : sig + val i_List : symbol + val p_List : path + val p_list : path + + val p_empty : path + val p_cons : path + val p_head : path + val p_behead : path + val p_tail : path + val p_append : path + val p_flatten : path + val p_map : path + val p_mapi : path + val p_chunk : path + val p_all : path + val p_size : path + val p_nth : path + val p_mkseq : path + val p_mem : path + val p_iota : path + end + (*-------------------------------------------------------------------- *) module CI_Bool : sig val i_Bool : symbol @@ -84,6 +109,8 @@ module CI_Int : sig val p_int_pow : path val p_int_edivz : path val p_iteri : path + val p_iter : path + val p_fold : path end (* -------------------------------------------------------------------- *) diff --git a/src/ecDecl.ml b/src/ecDecl.ml index 5636641acc..a9d8f8fc66 100644 --- a/src/ecDecl.ml +++ b/src/ecDecl.ml @@ -16,9 +16,10 @@ type ty_params = ty_param list type ty_pctor = [ `Int of int | `Named of ty_params ] type tydecl = { - tyd_params : ty_params; - tyd_type : ty_body; - tyd_loca : locality; + tyd_params : ty_params; + tyd_type : ty_body; + tyd_loca : locality; + tyd_clinline : bool; } and ty_body = [ @@ -65,7 +66,10 @@ let abs_tydecl ?(tc = Sp.empty) ?(params = `Int 0) lc = (EcUid.NameGen.bulk ~fmt n) in - { tyd_params = params; tyd_type = `Abstract tc; tyd_loca = lc; } + { tyd_params = params + ; tyd_type = `Abstract tc + ; tyd_loca = lc + ; tyd_clinline = false } (* -------------------------------------------------------------------- *) let ty_instanciate (params : ty_params) (args : ty list) (ty : ty) = @@ -348,3 +352,77 @@ let field_equal f1 f2 = ring_equal f1.f_ring f2.f_ring && EcPath.p_equal f1.f_inv f2.f_inv && EcUtils.oall2 EcPath.p_equal f1.f_div f2.f_div + +(* -------------------------------------------------------------------- *) +type binding_size = form * (int option) + +type crb_bitstring = + { type_ : EcPath.path + ; from_ : EcPath.path + ; to_ : EcPath.path + ; ofint : EcPath.path + ; touint : EcPath.path + ; tosint : EcPath.path + ; size : binding_size + ; theory : EcPath.path } + +type crb_array = + { type_ : EcPath.path + ; get : EcPath.path + ; set : EcPath.path + ; tolist : EcPath.path + ; oflist : EcPath.path + ; size : binding_size + ; theory : EcPath.path } + +type bv_opkind = [ + | `Add of binding_size (* size *) + | `Sub of binding_size (* size *) + | `Mul of binding_size (* size *) + | `Div of binding_size * bool (* size + sign *) + | `Rem of binding_size * bool (* size + sign *) + | `Shl of binding_size (* size *) + | `Shr of binding_size * bool (* size + sign *) + | `Shls of binding_size * binding_size (* size *) + | `Shrs of binding_size * binding_size * bool (* size + sign *) + | `Rol of binding_size (* size *) + | `Rol of binding_size (* size *) + | `Ror of binding_size (* size *) + | `And of binding_size (* size *) + | `Or of binding_size (* size *) + | `Xor of binding_size (* size *) + | `Not of binding_size (* size *) + | `Opp of binding_size (* size *) + | `Lt of binding_size * bool (* size + sign *) + | `Le of binding_size * bool (* size + sign *) + | `Extend of binding_size * binding_size * bool (* size in + size out + sign *) + | `Truncate of binding_size * binding_size (* size in + size out *) + | `Extract of binding_size * binding_size (* size in + size out *) + | `Insert of binding_size * binding_size (* size in + size out *) + | `Concat of binding_size * binding_size * binding_size (* size in1 + size in2 *) + | `Init of binding_size (* size_out *) + | `Get of binding_size (* size_in *) + | `AInit of binding_size * binding_size (* arr_len + size_out *) + | `Map of binding_size * binding_size * binding_size (* size_in + size_out + arr_size *) + | `A2B of (binding_size * binding_size) * binding_size (* (arr_len, elem_sz), out_size *) + | `B2A of binding_size * (binding_size * binding_size) (* size in, (arr_len, elem_sz) *) + | `ASliceGet of (binding_size * binding_size) * binding_size (* arr_len + el_sz + sz_out *) + | `ASliceSet of (binding_size * binding_size) * binding_size (* arr_len + el_sz + sz_in *) +] + +type crb_bvoperator = + { kind : bv_opkind + ; types : EcPath.path list + ; operator : EcPath.path + ; theory : EcPath.path } + +type crb_circuit = +{ name : string +; circuit : Lospecs.Ast.adef +; operator : EcPath.path } + +type crbinding = +| CRB_Bitstring of crb_bitstring +| CRB_Array of crb_array +| CRB_BvOperator of crb_bvoperator +| CRB_Circuit of crb_circuit diff --git a/src/ecDecl.mli b/src/ecDecl.mli index 7864a0e0de..953d0ef00f 100644 --- a/src/ecDecl.mli +++ b/src/ecDecl.mli @@ -12,9 +12,10 @@ type ty_params = ty_param list type ty_pctor = [ `Int of int | `Named of ty_params ] type tydecl = { - tyd_params : ty_params; - tyd_type : ty_body; - tyd_loca : locality; + tyd_params : ty_params; + tyd_type : ty_body; + tyd_loca : locality; + tyd_clinline : bool; } and ty_body = [ @@ -198,3 +199,77 @@ type field = { f_div : EcPath.path option; } val field_equal : field -> field -> bool + +(* -------------------------------------------------------------------- *) +type binding_size = form * (int option) + +type crb_bitstring = + { type_ : EcPath.path + ; from_ : EcPath.path + ; to_ : EcPath.path + ; ofint : EcPath.path + ; touint : EcPath.path + ; tosint : EcPath.path + ; size : binding_size + ; theory : EcPath.path } + +type crb_array = + { type_ : EcPath.path + ; get : EcPath.path + ; set : EcPath.path + ; tolist : EcPath.path + ; oflist : EcPath.path + ; size : binding_size + ; theory : EcPath.path } + +type bv_opkind = [ + | `Add of binding_size (* size *) + | `Sub of binding_size (* size *) + | `Mul of binding_size (* size *) + | `Div of binding_size * bool (* size + sign *) + | `Rem of binding_size * bool (* size + sign *) + | `Shl of binding_size (* size *) + | `Shr of binding_size * bool (* size + sign *) + | `Shls of binding_size * binding_size (* size *) + | `Shrs of binding_size * binding_size * bool (* size + sign *) + | `Rol of binding_size (* size *) + | `Rol of binding_size (* size *) + | `Ror of binding_size (* size *) + | `And of binding_size (* size *) + | `Or of binding_size (* size *) + | `Xor of binding_size (* size *) + | `Not of binding_size (* size *) + | `Opp of binding_size (* size *) + | `Lt of binding_size * bool (* size + sign *) + | `Le of binding_size * bool (* size + sign *) + | `Extend of binding_size * binding_size * bool (* size in + size out + sign *) + | `Truncate of binding_size * binding_size (* size in + size out *) + | `Extract of binding_size * binding_size (* size in + size out *) + | `Insert of binding_size * binding_size (* size in + size out *) + | `Concat of binding_size * binding_size * binding_size (* size in1 + size in2 *) + | `Init of binding_size (* size_out *) + | `Get of binding_size (* size_in *) + | `AInit of binding_size * binding_size (* arr_len + size_out *) + | `Map of binding_size * binding_size * binding_size (* size_in + size_out + arr_size *) + | `A2B of (binding_size * binding_size) * binding_size (* (arr_len, elem_sz), out_size *) + | `B2A of binding_size * (binding_size * binding_size) (* size in, (arr_len, elem_sz) *) + | `ASliceGet of (binding_size * binding_size) * binding_size (* arr_len + el_sz + sz_out *) + | `ASliceSet of (binding_size * binding_size) * binding_size (* arr_len + el_sz + sz_in *) +] + +type crb_bvoperator = + { kind : bv_opkind + ; types : EcPath.path list + ; operator : EcPath.path + ; theory : EcPath.path } + +type crb_circuit = +{ name : string +; circuit : Lospecs.Ast.adef +; operator : EcPath.path } + +type crbinding = +| CRB_Bitstring of crb_bitstring +| CRB_Array of crb_array +| CRB_BvOperator of crb_bvoperator +| CRB_Circuit of crb_circuit diff --git a/src/ecEnv.ml b/src/ecEnv.ml index a4a5c8a7ca..b13d7623e6 100644 --- a/src/ecEnv.ml +++ b/src/ecEnv.ml @@ -27,7 +27,6 @@ type 'a suspension = { sp_params : int * (EcIdent.t * module_type) list; } - (* -------------------------------------------------------------------- *) let check_not_suspended (params, obj) = if not (List.for_all (fun x -> x = None) params) then @@ -172,6 +171,35 @@ type actmem = [ | `TS of EcMemory.memory * EcMemory.memory ] +(* -------------------------------------------------------------------- *) +type crb_tyrev_binding = [ + | `Bitstring of crb_bitstring + | `Array of crb_array +] + +type crb_bitstring_operator = crb_bitstring * [`From | `To | `OfInt | `ToUInt | `ToSInt ] + +type crb_array_operator = crb_array * [`Get | `Set | `ToList | `OfList] + +type crb_oprev_binding = [ + | `Bitstring of crb_bitstring_operator + | `Array of crb_array_operator + | `BvOperator of crb_bvoperator + | `Circuit of crb_circuit +] + +type crb_tyrev_map = crb_tyrev_binding list Mp.t +type crb_oprev_map = crb_oprev_binding list Mp.t + +type crbindings = { + bitstrings : crb_bitstring Mp.t; + arrays : crb_array Mp.t; + bvoperators : crb_bvoperator Mp.t; + circuits : crb_circuit Mp.t; + opreverse : crb_oprev_map; + tyreverse : crb_tyrev_map; +} + (* -------------------------------------------------------------------- *) type preenv = { env_top : EcPath.path option; @@ -193,6 +221,7 @@ type preenv = { env_modlcs : Sid.t; (* declared modules *) env_item : theory_item list; (* in reverse order *) env_norm : env_norm ref; + env_crbds : crbindings; (* Map theory paths to their env before just before theory was closed. *) (* The environment should be incuded for all theories, including *) (* abstract ones. The purpose of this map is to simplify the code *) @@ -303,6 +332,14 @@ let empty gstate = let icomps = MMsym.add name (IPPath path) MMsym.empty in { (empty_mc None) with mc_components = icomps } in + let empty_crbindings : crbindings = + { bitstrings = Mp.empty + ; arrays = Mp.empty + ; bvoperators = Mp.empty + ; circuits = Mp.empty + ; opreverse = Mp.empty + ; tyreverse = Mp.empty } in + { env_top = None; env_gstate = gstate; env_scope = { ec_path = path; ec_scope = `Theory; }; @@ -321,7 +358,8 @@ let empty gstate = env_albase = Mp.empty; env_modlcs = Sid.empty; env_item = []; - env_norm = ref empty_norm_cache; + env_norm = ref empty_norm_cache; + env_crbds = empty_crbindings; env_thenvs = Mp.empty; } (* -------------------------------------------------------------------- *) @@ -1116,10 +1154,12 @@ module MC = struct | Th_alias _ -> (* FIXME:ALIAS *) (mc, None) - - | Th_export _ | Th_addrw _ | Th_instance _ - | Th_auto _ | Th_reduction _ -> - (mc, None) + | Th_export _ + | Th_addrw _ + | Th_instance _ + | Th_auto _ + | Th_reduction _ + | Th_crbinding _ -> (mc, None) in let (mc, submcs) = @@ -2844,774 +2884,987 @@ module Algebra = struct end (* -------------------------------------------------------------------- *) -module Theory = struct - type t = ctheory - type mode = [`All | thmode] +let initial gstate = empty gstate - type compiled = env Mp.t +(* -------------------------------------------------------------------- *) +type ebinding = [ + | `Variable of EcTypes.ty + | `Function of function_ + | `Module of module_expr + | `ModType of module_sig +] - type compiled_theory = { - name : symbol; - ctheory : t; - compiled : compiled; - } +(* FIXME section : Global ? *) +let bind1 ((x, eb) : symbol * ebinding) (env : env) = + match eb with + | `Variable ty -> Var .bind_pvglob x ty env + | `Function f -> Fun .bind x f env + | `Module m -> Mod .bind x {tme_expr = m; tme_loca = `Global} env + | `ModType i -> ModTy .bind x {tms_sig = i; tms_loca = `Global} env - (* ------------------------------------------------------------------ *) - let enter name env = - enter `Theory name env +let bindall (items : (symbol * ebinding) list) (env : env) = + List.fold_left ((^~) bind1) env items - (* ------------------------------------------------------------------ *) - let by_path_opt ?(mode = `All)(p : EcPath.path) (env : env) = - let obj = - match MC.by_path (fun mc -> mc.mc_theories) (IPPath p) env, mode with - | (Some (_, {cth_mode = `Concrete })) as obj, (`All | `Concrete) -> obj - | (Some (_, {cth_mode = `Abstract })) as obj, (`All | `Abstract) -> obj - | _, _ -> None +(* -------------------------------------------------------------------- *) +module LDecl = struct + type error = + | InvalidKind of EcIdent.t * [`Variable | `Hypothesis] + | CannotClear of EcIdent.t * EcIdent.t + | NameClash of [`Ident of EcIdent.t | `Symbol of symbol] + | LookupError of [`Ident of EcIdent.t | `Symbol of symbol] - in omap check_not_suspended obj + exception LdeclError of error - let by_path ?mode (p : EcPath.path) (env : env) = - match by_path_opt ?mode p env with - | None -> lookup_error (`Path p) - | Some obj -> obj + let pp_error fmt (exn : error) = + match exn with + | LookupError (`Symbol s) -> + Format.fprintf fmt "unknown symbol %s" s - let add (p : EcPath.path) (env : env) = - let obj = by_path p env in - MC.import_theory p obj env + | NameClash (`Symbol s) -> + Format.fprintf fmt + "an hypothesis or variable named `%s` already exists" s - let lookup ?(mode = `Concrete) qname (env : env) = - match MC.lookup_theory qname env, mode with - | (_, { cth_mode = `Concrete }) as obj, (`All | `Concrete) -> obj - | (_, { cth_mode = `Abstract }) as obj, (`All | `Abstract) -> obj - | _ -> lookup_error (`QSymbol qname) + | InvalidKind (x, `Variable) -> + Format.fprintf fmt "`%s` is not a variable" (EcIdent.name x) - let lookup_opt ?mode name env = - try_lf (fun () -> lookup ?mode name env) + | InvalidKind (x, `Hypothesis) -> + Format.fprintf fmt "`%s` is not an hypothesis" (EcIdent.name x) - let lookup_path ?mode name env = - fst (lookup ?mode name env) + | CannotClear (id1,id2) -> + Format.fprintf fmt "cannot clear %s as it is used in %s" + (EcIdent.name id1) (EcIdent.name id2) - (* ------------------------------------------------------------------ *) - let env_of_theory (p : EcPath.path) (env : env) = - if EcPath.isprefix ~prefix:p ~path:env.env_scope.ec_path then - env - else - Option.get (Mp.find_opt p env.env_thenvs) + | LookupError (`Ident id) -> + Format.fprintf fmt "unknown identifier `%s`, please report" + (EcIdent.tostring_internal id) - (* ------------------------------------------------------------------ *) - let rebind_alias (name : symbol) (path : path) (env : env) = - let th = by_path path env in - let src = EcPath.pqname (root env) name in - let env = MC.import_theory ~name path th env in - let env = MC.import_mc ~name (IPPath path) env in - let env = { env with env_albase = Mp.add path src env.env_albase } in - env + | NameClash (`Ident id) -> + Format.fprintf fmt "name clash for `%s`, please report" + (EcIdent.tostring_internal id) - (* ------------------------------------------------------------------ *) - let alias ?(import = true) (name : symbol) (path : path) (env : env) = - let env = if import then rebind_alias name path env else env in - { env with env_item = mkitem ~import (Th_alias (name, path)) :: env.env_item } + let _ = EcPException.register (fun fmt exn -> + match exn with + | LdeclError e -> pp_error fmt e + | _ -> raise exn) - (* ------------------------------------------------------------------ *) - let aliases (env : env) = - env.env_albase + let error e = raise (LdeclError e) (* ------------------------------------------------------------------ *) - let rec bind_instance_th path inst cth = - List.fold_left (bind_instance_th_item path) inst cth - - and bind_instance_th_item path inst item = - if not item.ti_import then inst else + let ld_subst s ld = + match ld with + | LD_var (ty, body) -> + LD_var (ty_subst s ty, body |> omap (Fsubst.f_subst s)) - let xpath x = EcPath.pqname path x in + | LD_mem mt -> + LD_mem (EcMemory.mt_subst (ty_subst s) mt) - match item.ti_item with - | Th_instance (ty, k, _) -> - TypeClass.bind_instance ty k inst + | LD_modty mty -> + LD_modty (Fsubst.mty_mr_subst s mty) - | Th_theory (x, cth) when cth.cth_mode = `Concrete -> - bind_instance_th (xpath x) inst cth.cth_items + | LD_hyp f -> + LD_hyp (Fsubst.f_subst s f) - | Th_type (x, tyd) -> begin - match tyd.tyd_type with - | `Abstract tc -> - let myty = - let typ = List.map (fst_map EcIdent.fresh) tyd.tyd_params in - (typ, EcTypes.tconstr (xpath x) (List.map (tvar |- fst) typ)) - in - Sp.fold - (fun p inst -> TypeClass.bind_instance myty (`General p) inst) - tc inst + | LD_abs_st _ -> (* FIXME *) + assert false - | _ -> inst - end + (* ------------------------------------------------------------------ *) + let ld_fv = function + | LD_var (ty, None) -> + ty.ty_fv + | LD_var (ty,Some f) -> + EcIdent.fv_union ty.ty_fv f.f_fv + | LD_mem mt -> + EcMemory.mt_fv mt + | LD_hyp f -> + f.f_fv + | LD_modty p -> + gty_fv (GTmodty p) + | LD_abs_st us -> + let add fv (x,_) = match x with + | PVglob x -> EcPath.x_fv fv x + | PVloc _ -> fv in - | _ -> inst + let fv = Mid.empty in + let fv = List.fold_left add fv us.aus_reads in + let fv = List.fold_left add fv us.aus_writes in + List.fold_left EcPath.x_fv fv us.aus_calls (* ------------------------------------------------------------------ *) - let rec bind_base_th tx path base cth = - List.fold_left (bind_base_th_item tx path) base cth + let by_name s hyps = + match List.ofind ((=) s |- EcIdent.name |- fst) hyps.h_local with + | None -> error (LookupError (`Symbol s)) + | Some h -> h - and bind_base_th_item tx path base item = - if not item.ti_import then base else + let by_id id hyps = + match List.ofind (EcIdent.id_equal id |- fst) hyps.h_local with + | None -> error (LookupError (`Ident id)) + | Some x -> snd x - let xpath x = EcPath.pqname path x in + (* ------------------------------------------------------------------ *) + let as_hyp = function + | (id, LD_hyp f) -> (id, f) + | (id, _) -> error (InvalidKind (id, `Hypothesis)) - match item.ti_item with - | Th_theory (x, cth) -> begin - match cth.cth_mode with - | `Concrete -> - bind_base_th tx (xpath x) base cth.cth_items - | `Abstract -> base - end - | _ -> odfl base (tx path base item.ti_item) + let as_var = function + | (id, LD_var (ty, _)) -> (id, ty) + | (id, _) -> error (InvalidKind (id, `Variable)) (* ------------------------------------------------------------------ *) - let bind_tc_th = - let for1 path base = function - | Th_typeclass (x, tc) -> - tc.tc_prt |> omap (fun prt -> - let src = EcPath.pqname path x in - TC.Graph.add ~src ~dst:prt base) - | _ -> None + let hyp_by_name s hyps = as_hyp (by_name s hyps) + let var_by_name s hyps = as_var (by_name s hyps) - in bind_base_th for1 + (* ------------------------------------------------------------------ *) + let hyp_by_id x hyps = as_hyp (x, by_id x hyps) + let var_by_id x hyps = as_var (x, by_id x hyps) (* ------------------------------------------------------------------ *) - let bind_br_th = - let for1 path base = function - | Th_baserw (x,_) -> - let ip = IPPath (EcPath.pqname path x) in - assert (not (Mip.mem ip base)); - Some (Mip.add ip Sp.empty base) + let has_gen dcast s hyps = + try ignore (dcast (by_name s hyps)); true + with LdeclError (InvalidKind _ | LookupError _) -> false - | Th_addrw (b, r, _) -> - let change = function - | None -> assert false - | Some s -> Some (List.fold_left (fun s r -> Sp.add r s) s r) + let hyp_exists s hyps = has_gen as_hyp s hyps + let var_exists s hyps = has_gen as_var s hyps - in Some (Mip.change change (IPPath b) base) + (* ------------------------------------------------------------------ *) + let has_id x hyps = + try ignore (by_id x hyps); true + with LdeclError (LookupError _) -> false - | _ -> None + let has_inld s = function + | LD_mem mt -> is_bound s mt + | _ -> false - in bind_base_th for1 + let has_name ?(dep = false) s hyps = + let test (id, k) = + EcIdent.name id = s || (dep && has_inld s k) + in List.exists test hyps.h_local (* ------------------------------------------------------------------ *) - let bind_at_th = - let for1 _path db = function - | Th_auto {level; base; axioms; _} -> - Some (Auto.updatedb ?base ~level axioms db) - | _ -> None - - in bind_base_th for1 + let can_unfold id hyps = + try match by_id id hyps with LD_var (_, Some _) -> true | _ -> false + with LdeclError _ -> false - (* ------------------------------------------------------------------ *) - let bind_nt_th = - let for1 path base = function - | Th_operator (x, ({ op_kind = OB_nott _ } as op)) -> - Some (Op.update_ntbase path (x, op) base) - | _ -> None + let unfold id hyps = + try + match by_id id hyps with + | LD_var (_, Some f) -> f + | _ -> raise NotReducible + with LdeclError _ -> raise NotReducible - in bind_base_th for1 + (* ------------------------------------------------------------------ *) + let check_name_clash id hyps = + if has_id id hyps + then error (NameClash (`Ident id)) + else + let s = EcIdent.name id in + if s <> "_" && has_name ~dep:false s hyps then + error (NameClash (`Symbol s)) + + let add_local id ld hyps = + check_name_clash id hyps; + { hyps with h_local = (id, ld) :: hyps.h_local } (* ------------------------------------------------------------------ *) - let bind_rd_th = - let for1 _path db = function - | Th_reduction rules -> - let rules = List.map (fun (x, _, y) -> (x, y)) rules in - Some (Reduction.add_rules rules db) - | _ -> None + let fresh_id hyps s = + let s = + if s = "_" || not (has_name ~dep:true s hyps) + then s + else + let rec aux n = + let s = Printf.sprintf "%s%d" s n in + if has_name ~dep:true s hyps then aux (n+1) else s + in aux 0 - in bind_base_th for1 + in EcIdent.create s + + let fresh_ids hyps names = + let do1 hyps s = + let id = fresh_id hyps s in + (add_local id (LD_var (tbool, None)) hyps, id) + in List.map_fold do1 hyps names (* ------------------------------------------------------------------ *) - let add_restr_th = - let for1 path env = function - | Th_module me -> Some (Mod.add_restr_to_declared path me env) - | _ -> None - in bind_base_th for1 + type hyps = { + le_init : env; + le_env : env; + le_hyps : EcBaseLogic.hyps; + } + + let tohyps lenv = lenv.le_hyps + let toenv lenv = lenv.le_env + let baseenv lenv = lenv.le_init + + let add_local_env x k env = + match k with + | LD_var (ty, _) -> Var.bind_local x ty env + | LD_mem mt -> Memory.push (x, mt) env + | LD_modty i -> Mod.bind_local x i env + | LD_hyp _ -> env + | LD_abs_st us -> AbsStmt.bind x us env (* ------------------------------------------------------------------ *) - let bind - ?(import = true) - (cth : compiled_theory) - (env : env) - = - let { cth_items = items; cth_mode = mode; } = cth.ctheory in - let env = MC.bind_theory cth.name cth.ctheory env in - let env = { - env with - env_item = mkitem ~import (Th_theory (cth.name, cth.ctheory)) :: env.env_item } + let add_local x k h = + let le_hyps = add_local x k (tohyps h) in + let le_env = add_local_env x k h.le_env in + { h with le_hyps; le_env; } + + (* ------------------------------------------------------------------ *) + let init env ?(locals = []) tparams = + let buildenv env = + List.fold_right + (fun (x, k) env -> add_local_env x k env) + locals env in - let env = - match import, mode with - | _, `Concrete -> - let thname = EcPath.pqname (root env) cth.name in - let env_tci = bind_instance_th thname env.env_tci items in - let env_tc = bind_tc_th thname env.env_tc items in - let env_rwbase = bind_br_th thname env.env_rwbase items in - let env_atbase = bind_at_th thname env.env_atbase items in - let env_ntbase = bind_nt_th thname env.env_ntbase items in - let env_redbase = bind_rd_th thname env.env_redbase items in - let env = - { env with - env_tci ; env_tc ; env_rwbase; - env_atbase; env_ntbase; env_redbase; } - in - add_restr_th thname env items + { le_init = env; + le_env = buildenv env; + le_hyps = { h_tvar = tparams; h_local = locals; }; } - | _, _ -> - env + (* ------------------------------------------------------------------ *) + let clear ?(leniant = false) ids hyps = + let rec filter ids hyps = + match hyps with [] -> [] | ((id, lk) as bd) :: hyps -> + + let ids, bd = + if EcIdent.Sid.mem id ids then (ids, None) else + + let fv = ld_fv lk in + + if Mid.set_disjoint ids fv then + (ids, Some bd) + else + if leniant then + (Mid.set_diff ids fv, Some bd) + else + let inter = Mid.set_inter ids fv in + error (CannotClear (Sid.choose inter, id)) + in List.ocons bd (filter ids hyps) in - { env with - env_thenvs = Mp.set_union env.env_thenvs cth.compiled } + let locals = filter ids hyps.le_hyps.h_local in - (* ------------------------------------------------------------------ *) - let rebind name th env = - MC.bind_theory name th env + init hyps.le_init ~locals hyps.le_hyps.h_tvar (* ------------------------------------------------------------------ *) - let import (path : EcPath.path) (env : env) = - let rec import (env : env) path (th : theory) = - let xpath x = EcPath.pqname path x in - let import_th_item (env : env) (item : theory_item) = - if not item.ti_import then env else + let hyp_convert x check hyps = + let module E = struct exception NoOp end in - match item.ti_item with - | Th_type (x, ty) -> - MC.import_tydecl (xpath x) ty env + let init locals = init hyps.le_init ~locals hyps.le_hyps.h_tvar in - | Th_operator (x, op) -> - MC.import_operator (xpath x) op env + let rec doit locals = + match locals with + | (y, LD_hyp fp) :: locals when EcIdent.id_equal x y -> begin + let fp' = check (lazy (init locals)) fp in + if fp == fp' then raise E.NoOp else (x, LD_hyp fp') :: locals + end - | Th_axiom (x, ax) -> - MC.import_axiom (xpath x) ax env + | [] -> error (LookupError (`Ident x)) + | ld :: locals -> ld :: (doit locals) - | Th_modtype (x, mty) -> - MC.import_modty (xpath x) mty env + in (try Some (doit hyps.le_hyps.h_local) with E.NoOp -> None) |> omap init - | Th_module ({ tme_expr = me; tme_loca = lc; }) -> - let env = MC.import_mod (IPPath (xpath me.me_name)) (me, Some lc) env in - let env = MC.import_mc (IPPath (xpath me.me_name)) env in - env + (* ------------------------------------------------------------------ *) + let local_hyps x hyps = + let rec doit locals = + match locals with + | (y, _) :: locals -> + if EcIdent.id_equal x y then locals else doit locals + | [] -> + error (LookupError (`Ident x)) in - | Th_export (p, _) -> - import env p (by_path ~mode:`Concrete p env).cth_items + let locals = doit hyps.le_hyps.h_local in + init hyps.le_init ~locals hyps.le_hyps.h_tvar - | Th_theory (x, ({cth_mode = `Concrete} as th)) -> - let env = MC.import_theory (xpath x) th env in - let env = MC.import_mc (IPPath (xpath x)) env in - env + (* ------------------------------------------------------------------ *) + let by_name s hyps = by_name s (tohyps hyps) + let by_id x hyps = by_id x (tohyps hyps) - | Th_theory (x, ({cth_mode = `Abstract} as th)) -> - MC.import_theory (xpath x) th env + let has_name s hyps = has_name ~dep:false s (tohyps hyps) + let has_id x hyps = has_id x (tohyps hyps) - | Th_typeclass (x, tc) -> - MC.import_typeclass (xpath x) tc env + let hyp_by_name s hyps = hyp_by_name s (tohyps hyps) + let hyp_exists s hyps = hyp_exists s (tohyps hyps) + let hyp_by_id x hyps = snd (hyp_by_id x (tohyps hyps)) - | Th_baserw (x, _) -> - MC.import_rwbase (xpath x) env + let var_by_name s hyps = var_by_name s (tohyps hyps) + let var_exists s hyps = var_exists s (tohyps hyps) + let var_by_id x hyps = snd (var_by_id x (tohyps hyps)) - | Th_alias (name, path) -> - rebind_alias name path env + let can_unfold x hyps = can_unfold x (tohyps hyps) + let unfold x hyps = unfold x (tohyps hyps) - | Th_addrw _ | Th_instance _ | Th_auto _ | Th_reduction _ -> - env + let fresh_id hyps s = fresh_id (tohyps hyps) s + let fresh_ids hyps s = snd (fresh_ids (tohyps hyps) s) - in - List.fold_left import_th_item env th + (* ------------------------------------------------------------------ *) + let push_active_ss m lenv = + { lenv with le_env = Memory.push_active_ss m lenv.le_env } - in - import env path (by_path ~mode:`Concrete path env).cth_items + let push_active_ts ml mr lenv = + { lenv with le_env = Memory.push_active_ts ml mr lenv.le_env } - (* ------------------------------------------------------------------ *) - let export (path : EcPath.path) lc (env : env) = - let env = import path env in - { env with env_item = mkitem ~import:true (Th_export (path, lc)) :: env.env_item } + let push_all l lenv = + { lenv with le_env = Memory.push_all l lenv.le_env } - (* ------------------------------------------------------------------ *) - let rec filter clears root cleared items = - snd_map (List.pmap identity) - (List.map_fold (filter1 clears root) cleared items) + let hoareF mem xp lenv = + let env1, env2 = Fun.hoareF mem xp lenv.le_env in + { lenv with le_env = env1}, {lenv with le_env = env2 } - and filter_th clears root cleared items = - let mempty = List.exists (EcPath.p_equal root) clears in - let cleared, items = filter clears root cleared items in + let equivF ml mr xp1 xp2 lenv = + let env1, env2 = Fun.equivF ml mr xp1 xp2 lenv.le_env in + { lenv with le_env = env1}, {lenv with le_env = env2 } - if mempty && List.is_empty items - then (Sp.add root cleared, None) - else (cleared, Some items) + let inv_memenv ml mr lenv = + { lenv with le_env = Fun.inv_memenv ml mr lenv.le_env } - and filter1 clears root = - let inclear p = List.exists (EcPath.p_equal p) clears in - let thclear = inclear root in + let inv_memenv1 m lenv = + { lenv with le_env = Fun.inv_memenv1 m lenv.le_env } +end - fun cleared item -> - let cleared, item_r = - match item.ti_item with - | Th_theory (x, cth) -> - let cleared, items = - let xpath = EcPath.pqname root x in - filter_th clears xpath cleared cth.cth_items in - let item = items |> omap (fun items -> - let cth = { cth with cth_items = items } in - Th_theory (x, cth)) in - (cleared, item) +(* -------------------------------------------------------------------- *) +module Circuit = struct + let push_tyreverse (reverse : crb_tyrev_map) (p : path) (v : crb_tyrev_binding) = + Mp.change + (fun vs -> Some (v :: Option.value ~default:[] vs)) + p reverse + + let push_all_tyreverse (reverse : crb_tyrev_map) (pvs : (path * crb_tyrev_binding) list) = + List.fold_left (fun rv (p, v) -> push_tyreverse rv p v) reverse pvs + + let push_opreverse (reverse : crb_oprev_map) (p : path) (v : crb_oprev_binding) = + Mp.change + (fun vs -> Some (v :: Option.value ~default:[] vs)) + p reverse + + let push_all_opreverse (reverse : crb_oprev_map) (pvs : (path * crb_oprev_binding) list) = + List.fold_left (fun rv (p, v) -> push_opreverse rv p v) reverse pvs + + let rebind_bitstring_ (bs : crb_bitstring) (bindings : crbindings) = + { bindings with + bitstrings = Mp.add bs.type_ bs bindings.bitstrings; + tyreverse = push_tyreverse bindings.tyreverse bs.type_ (`Bitstring bs); + opreverse = + push_all_opreverse + bindings.opreverse + [ (bs.from_, `Bitstring (bs, `From )) + ; (bs.to_ , `Bitstring (bs, `To )) + ; (bs.touint, `Bitstring (bs, `ToUInt)) + ; (bs.tosint, `Bitstring (bs, `ToSInt)) + ; (bs.ofint, `Bitstring (bs, `OfInt)) ]; } + + let rebind_bitstring (bs : crb_bitstring) (env : env) : env = + { env with env_crbds = rebind_bitstring_ bs env.env_crbds } + + let bind_bitstring ?(import = true) (lc : is_local) (bs : crb_bitstring) (env : env) = + let env = if import then rebind_bitstring bs env else env in + { env with env_item = + mkitem ~import (Th_crbinding (CRB_Bitstring bs, lc)) :: env.env_item; } + + let rebind_array_ (ba : crb_array) (bindings : crbindings) = + { bindings with + arrays = Mp.add ba.type_ ba bindings.arrays; + tyreverse = push_tyreverse bindings.tyreverse ba.type_ (`Array ba); + opreverse = + push_all_opreverse + bindings.opreverse + [ (ba.set , `Array (ba, `Set)) + ; (ba.get , `Array (ba, `Get)) + ; (ba.tolist, `Array (ba, `ToList)) + ; (ba.oflist, `Array (ba, `OfList)) ]} + + let rebind_array (ba : crb_array) (env : env) : env = + { env with env_crbds = rebind_array_ ba env.env_crbds } + + let bind_array ?(import = true) (lc : is_local) (ba : crb_array) (env : env) = + let env = if import then rebind_array ba env else env in + { env with env_item = + mkitem ~import (Th_crbinding (CRB_Array ba, lc)) :: env.env_item; } - | _ -> let item_r = match item.ti_item with + let rebind_bvoperator_ (op : crb_bvoperator) (bindings : crbindings) = + { bindings with + bvoperators = Mp.add op.operator op bindings.bvoperators; + opreverse = push_opreverse bindings.opreverse op.operator (`BvOperator op); } - | Th_axiom (_, { ax_kind = `Lemma }) when thclear -> - None + let rebind_bvoperator (op : crb_bvoperator) (env : env) = + { env with env_crbds = rebind_bvoperator_ op env.env_crbds } - | Th_axiom (x, ({ ax_kind = `Axiom (tags, false) } as ax)) when thclear -> - Some (Th_axiom (x, { ax with ax_kind = `Axiom (tags, true) })) + let bind_bvoperator ?(import = true) (lc : is_local) (op : crb_bvoperator) (env : env) = + let env = if import then rebind_bvoperator op env else env in + { env with env_item = + mkitem ~import (Th_crbinding (CRB_BvOperator op, lc)) :: env.env_item; } + + let rebind_circuit_ (cr : crb_circuit) (bindings : crbindings) = + { bindings with + circuits = Mp.add cr.operator cr bindings.circuits; + opreverse = push_opreverse bindings.opreverse cr.operator (`Circuit cr); } + + let rebind_circuit (cr : crb_circuit) (env : env) = + { env with env_crbds = rebind_circuit_ cr env.env_crbds } + + let bind_circuit ?(import = true) (lc : is_local) (cr : crb_circuit) (env : env) = + let env = if import then rebind_circuit cr env else env in + { env with env_item = + mkitem ~import (Th_crbinding (CRB_Circuit cr, lc)) :: env.env_item; } + + let bind_crbinding ?import (lc : is_local) (crb : crbinding) (env : env) = + match crb with + | CRB_Bitstring bs -> bind_bitstring ?import lc bs env + | CRB_Array ba -> bind_array ?import lc ba env + | CRB_BvOperator op -> bind_bvoperator ?import lc op env + | CRB_Circuit cr -> bind_circuit ?import lc cr env + + let rec lookup_bitstring_path (env : env) (k : path) : crb_bitstring option = +(* Format.eprintf "Looking up bitstring binding for type with path %s@." (EcPath.tostring k); *) + let k, _ = Ty.lookup (EcPath.toqsymbol k) (env) in + match Mp.find_opt k env.env_crbds.bitstrings with + | Some _ as bs -> bs + | None -> try lookup_bitstring env (Ty.unfold k [] env) + with LookupFailure _ -> None + + and lookup_bitstring (env : env) (ty : ty) : crb_bitstring option = + match ty.ty_node with + | Tconstr (p, []) -> lookup_bitstring_path env p + | _ -> None + + let lookup_bitstring_size_path (env : env) (pth : path) : int option = + Option.bind (Option.map (fun (c : crb_bitstring) -> c.size) (lookup_bitstring_path env pth)) snd + + let lookup_circuit_path (env : env) (v : path) : Lospecs.Ast.adef option = + Mp.find_opt v env.env_crbds.circuits + |> Option.map (fun cr -> cr.circuit) + + let lookup_bitstring_size (env : env) (ty : ty) : int option = + Option.bind (Option.map (fun (c : crb_bitstring) -> c.size) (lookup_bitstring env ty)) snd + + let rec lookup_array_path (env : env) (pth : path) : crb_array option = + let k, _ = Ty.lookup (EcPath.toqsymbol pth) (env) in + match Mp.find_opt k env.env_crbds.arrays with + | Some arr -> Some arr + | None -> try + lookup_array env (Ty.unfold pth [] env) + with LookupFailure e -> None + + and lookup_array (env : env) (ty : ty) : crb_array option = + match ty.ty_node with + | Tconstr (p, [w]) -> lookup_array_path env p + | _ -> None - | Th_addrw (p, ps, lc) -> - let ps = List.filter ((not) |- inclear |- oget |- EcPath.prefix) ps in - if List.is_empty ps then None else Some (Th_addrw (p, ps,lc)) + let rec lookup_array_and_bitstring (env: env) (ty: ty) : (crb_array * crb_bitstring) option = + match ty.ty_node with + | Tconstr (p, [w]) -> +(* Format.eprintf "Unfolding parametric type with path %s@." (EcPath.tostring p); *) + let arr = lookup_array_path env p in + let bs = lookup_bitstring env w in + begin match arr, bs with + | Some arr, Some bs -> Some (arr, bs) + | _ -> None + end + | Tconstr (p, []) -> +(* Format.eprintf "Unfolding non parametric type with path %s@." (EcPath.tostring p); *) + (try + lookup_array_and_bitstring env (Ty.unfold p [] env) + with LookupFailure _ -> None) + | _ -> None - | Th_auto ({ axioms } as auto_rl) -> - let axioms = List.filter (fun (p, _) -> - let p = oget (EcPath.prefix p) in - not (inclear p) - ) axioms in - if List.is_empty axioms then None else Some (Th_auto {auto_rl with axioms}) + let lookup_array_size (env : env) (ty : ty) : int option = + Option.bind (Option.map (fun c -> c.size) (lookup_array env ty)) snd - | (Th_export (p, _)) as item -> - if Sp.mem p cleared then None else Some item + let lookup_bvoperator_path (env : env) (v : path) : crb_bvoperator option = + Mp.find_opt v env.env_crbds.bvoperators - | _ as item -> Some item + let lookup_bvoperator (env : env) (o : qsymbol) : crb_bvoperator option = + let p, _o = Op.lookup o env in + lookup_bvoperator_path env p - in (cleared, item_r) + let lookup_circuit (env : env) (o : qsymbol) : Lospecs.Ast.adef option = + let p, _o = Op.lookup o env in + lookup_circuit_path env p - in (cleared, omap (fun item_r -> { item with ti_item = item_r; }) item_r) + let reverse_type (env : env) (p : path) : crb_tyrev_binding list = + Mp.find_def [] p env.env_crbds.tyreverse - (* ------------------------------------------------------------------ *) - type clear_mode = [`Full | `ClearOnly | `No] + let reverse_operator (env: env) (p : path) : crb_oprev_binding list = + Mp.find_def [] p env.env_crbds.opreverse - let close - ?(clears : path list = []) - ?(pempty : clear_mode = `No) - (loca : is_local) - (mode : thmode) - (env : env) + let reverse_and_filter_operator + ~(filter : crb_oprev_binding -> 'a option) (env : env) (p : path) = - let items = List.rev env.env_item in - let items = - if List.is_empty clears - then (if List.is_empty items then None else Some items) - else snd (filter_th clears (root env) Sp.empty items) in + List.find_map_opt filter (reverse_operator env p) - let items = - match items, pempty with - | None, (`No | `ClearOnly) -> Some [] - | _, _ -> items - in + let reverse_bitstring_operator = + reverse_and_filter_operator + ~filter:(function `Bitstring x -> Some x | _ -> None) - items |> omap (fun items -> - let ctheory = - { cth_items = items - ; cth_source = None - ; cth_loca = loca - ; cth_mode = mode - } in + let reverse_array_operator = + reverse_and_filter_operator + ~filter:(function `Array x -> Some x | _ -> None) - let root = env.env_scope.ec_path in - let name = EcPath.basename root in + let reverse_bvoperator = + reverse_and_filter_operator + ~filter:(function `BvOperator x -> Some x | _ -> None) - let compiled = - Mp.filter - (fun path _ -> EcPath.isprefix ~prefix:root ~path) - env.env_thenvs in - let compiled = Mp.add env.env_scope.ec_path env compiled in + let reverse_circuit = + reverse_and_filter_operator + ~filter:(function `Circuit x -> Some x | _ -> None) - { name; ctheory; compiled; } - ) + let get_specification_by_name (env : env) ~(filename : string) (name : symbol) : Lospecs.Ast.adef option = + let specs = Lospecs.Circuit_spec.load_from_file ~filename in + List.Exceptionless.assoc name specs +end + +(* -------------------------------------------------------------------- *) +module Theory = struct + type t = ctheory + type mode = [`All | thmode] + + type compiled = env Mp.t + + type compiled_theory = { + name : symbol; + ctheory : t; + compiled : compiled; + } (* ------------------------------------------------------------------ *) - let require (compiled : compiled_theory) (env : env) = - let cth = compiled.ctheory in - let rootnm = EcCoreLib.p_top in - let thpath = EcPath.pqname rootnm compiled.name in + let enter name env = + enter `Theory name env - let env = - match cth.cth_mode with - | `Concrete -> - let (_, thmc), submcs = - MC.mc_of_theory_r rootnm (compiled.name, cth) - in MC.bind_submc env rootnm ((compiled.name, thmc), submcs) + (* ------------------------------------------------------------------ *) + let by_path_opt ?(mode = `All)(p : EcPath.path) (env : env) = + let obj = + match MC.by_path (fun mc -> mc.mc_theories) (IPPath p) env, mode with + | (Some (_, {cth_mode = `Concrete })) as obj, (`All | `Concrete) -> obj + | (Some (_, {cth_mode = `Abstract })) as obj, (`All | `Abstract) -> obj + | _, _ -> None - | `Abstract -> env - in + in omap check_not_suspended obj - let topmc = Mip.find (IPPath rootnm) env.env_comps in - let topmc = MC._up_theory false topmc compiled.name (IPPath thpath, cth) in - let topmc = MC._up_mc false topmc (IPPath thpath) in + let by_path ?mode (p : EcPath.path) (env : env) = + match by_path_opt ?mode p env with + | None -> lookup_error (`Path p) + | Some obj -> obj - let current = env.env_current in - let current = MC._up_theory true current compiled.name (IPPath thpath, cth) in - let current = MC._up_mc true current (IPPath thpath) in + let add (p : EcPath.path) (env : env) = + let obj = by_path p env in + MC.import_theory p obj env - let comps = env.env_comps in - let comps = Mip.add (IPPath rootnm) topmc comps in + let lookup ?(mode = `Concrete) qname (env : env) = + match MC.lookup_theory qname env, mode with + | (_, { cth_mode = `Concrete }) as obj, (`All | `Concrete) -> obj + | (_, { cth_mode = `Abstract }) as obj, (`All | `Abstract) -> obj + | _ -> lookup_error (`QSymbol qname) - let env = { env with env_current = current; env_comps = comps; } in + let lookup_opt ?mode name env = + try_lf (fun () -> lookup ?mode name env) - match cth.cth_mode with - | `Abstract -> - { env with - env_thenvs = Mp.set_union env.env_thenvs compiled.compiled; } + let lookup_path ?mode name env = + fst (lookup ?mode name env) - | `Concrete -> - { env with - env_tci = bind_instance_th thpath env.env_tci cth.cth_items; - env_tc = bind_tc_th thpath env.env_tc cth.cth_items; - env_rwbase = bind_br_th thpath env.env_rwbase cth.cth_items; - env_atbase = bind_at_th thpath env.env_atbase cth.cth_items; - env_ntbase = bind_nt_th thpath env.env_ntbase cth.cth_items; - env_redbase = bind_rd_th thpath env.env_redbase cth.cth_items; - env_thenvs = Mp.set_union env.env_thenvs compiled.compiled; } -end + (* ------------------------------------------------------------------ *) + let env_of_theory (p : EcPath.path) (env : env) = + if EcPath.isprefix ~prefix:p ~path:env.env_scope.ec_path then + env + else + Option.get (Mp.find_opt p env.env_thenvs) -(* -------------------------------------------------------------------- *) -let initial gstate = empty gstate +(* ------------------------------------------------------------------ *) + let rebind_alias (name : symbol) (path : path) (env : env) = + let th = by_path path env in + let src = EcPath.pqname (root env) name in + let env = MC.import_theory ~name path th env in + let env = MC.import_mc ~name (IPPath path) env in + let env = { env with env_albase = Mp.add path src env.env_albase } in + env -(* -------------------------------------------------------------------- *) -type ebinding = [ - | `Variable of EcTypes.ty - | `Function of function_ - | `Module of module_expr - | `ModType of module_sig -] + (* ------------------------------------------------------------------ *) + let alias ?(import = true) (name : symbol) (path : path) (env : env) = + let env = if import then rebind_alias name path env else env in + { env with env_item = mkitem ~import (Th_alias (name, path)) :: env.env_item } -(* FIXME section : Global ? *) -let bind1 ((x, eb) : symbol * ebinding) (env : env) = - match eb with - | `Variable ty -> Var .bind_pvglob x ty env - | `Function f -> Fun .bind x f env - | `Module m -> Mod .bind x {tme_expr = m; tme_loca = `Global} env - | `ModType i -> ModTy .bind x {tms_sig = i; tms_loca = `Global} env + (* ------------------------------------------------------------------ *) + let aliases (env : env) = + env.env_albase + + (* ------------------------------------------------------------------ *) + let rec bind_instance_th path inst cth = + List.fold_left (bind_instance_th_item path) inst cth + + and bind_instance_th_item path inst item = + if not item.ti_import then inst else -let bindall (items : (symbol * ebinding) list) (env : env) = - List.fold_left ((^~) bind1) env items + let xpath x = EcPath.pqname path x in -(* -------------------------------------------------------------------- *) -module LDecl = struct - type error = - | InvalidKind of EcIdent.t * [`Variable | `Hypothesis] - | CannotClear of EcIdent.t * EcIdent.t - | NameClash of [`Ident of EcIdent.t | `Symbol of symbol] - | LookupError of [`Ident of EcIdent.t | `Symbol of symbol] + match item.ti_item with + | Th_instance (ty, k, _) -> + TypeClass.bind_instance ty k inst - exception LdeclError of error + | Th_theory (x, cth) when cth.cth_mode = `Concrete -> + bind_instance_th (xpath x) inst cth.cth_items - let pp_error fmt (exn : error) = - match exn with - | LookupError (`Symbol s) -> - Format.fprintf fmt "unknown symbol %s" s + | Th_type (x, tyd) -> begin + match tyd.tyd_type with + | `Abstract tc -> + let myty = + let typ = List.map (fst_map EcIdent.fresh) tyd.tyd_params in + (typ, EcTypes.tconstr (xpath x) (List.map (tvar |- fst) typ)) + in + Sp.fold + (fun p inst -> TypeClass.bind_instance myty (`General p) inst) + tc inst - | NameClash (`Symbol s) -> - Format.fprintf fmt - "an hypothesis or variable named `%s` already exists" s + | _ -> inst + end - | InvalidKind (x, `Variable) -> - Format.fprintf fmt "`%s` is not a variable" (EcIdent.name x) + | _ -> inst - | InvalidKind (x, `Hypothesis) -> - Format.fprintf fmt "`%s` is not an hypothesis" (EcIdent.name x) + (* ------------------------------------------------------------------ *) + let rec bind_base_th tx path base cth = + List.fold_left (bind_base_th_item tx path) base cth - | CannotClear (id1,id2) -> - Format.fprintf fmt "cannot clear %s as it is used in %s" - (EcIdent.name id1) (EcIdent.name id2) + and bind_base_th_item tx path base item = + if not item.ti_import then base else - | LookupError (`Ident id) -> - Format.fprintf fmt "unknown identifier `%s`, please report" - (EcIdent.tostring_internal id) + let xpath x = EcPath.pqname path x in - | NameClash (`Ident id) -> - Format.fprintf fmt "name clash for `%s`, please report" - (EcIdent.tostring_internal id) + match item.ti_item with + | Th_theory (x, cth) -> begin + match cth.cth_mode with + | `Concrete -> + bind_base_th tx (xpath x) base cth.cth_items + | `Abstract -> base + end + | _ -> odfl base (tx path base item.ti_item) - let _ = EcPException.register (fun fmt exn -> - match exn with - | LdeclError e -> pp_error fmt e - | _ -> raise exn) + (* ------------------------------------------------------------------ *) + let bind_tc_th = + let for1 path base = function + | Th_typeclass (x, tc) -> + tc.tc_prt |> omap (fun prt -> + let src = EcPath.pqname path x in + TC.Graph.add ~src ~dst:prt base) + | _ -> None - let error e = raise (LdeclError e) + in bind_base_th for1 (* ------------------------------------------------------------------ *) - let ld_subst s ld = - match ld with - | LD_var (ty, body) -> - LD_var (ty_subst s ty, body |> omap (Fsubst.f_subst s)) + let bind_br_th = + let for1 path base = function + | Th_baserw (x,_) -> + let ip = IPPath (EcPath.pqname path x) in + assert (not (Mip.mem ip base)); + Some (Mip.add ip Sp.empty base) - | LD_mem mt -> - LD_mem (EcMemory.mt_subst (ty_subst s) mt) + | Th_addrw (b, r, _) -> + let change = function + | None -> assert false + | Some s -> Some (List.fold_left (fun s r -> Sp.add r s) s r) - | LD_modty mty -> - LD_modty (Fsubst.mty_mr_subst s mty) + in Some (Mip.change change (IPPath b) base) - | LD_hyp f -> - LD_hyp (Fsubst.f_subst s f) + | _ -> None - | LD_abs_st _ -> (* FIXME *) - assert false + in bind_base_th for1 (* ------------------------------------------------------------------ *) - let ld_fv = function - | LD_var (ty, None) -> - ty.ty_fv - | LD_var (ty,Some f) -> - EcIdent.fv_union ty.ty_fv f.f_fv - | LD_mem mt -> - EcMemory.mt_fv mt - | LD_hyp f -> - f.f_fv - | LD_modty p -> - gty_fv (GTmodty p) - | LD_abs_st us -> - let add fv (x,_) = match x with - | PVglob x -> EcPath.x_fv fv x - | PVloc _ -> fv in + let bind_at_th = + let for1 _path db = function + | Th_auto {level; base; axioms; _} -> + Some (Auto.updatedb ?base ~level axioms db) + | _ -> None - let fv = Mid.empty in - let fv = List.fold_left add fv us.aus_reads in - let fv = List.fold_left add fv us.aus_writes in - List.fold_left EcPath.x_fv fv us.aus_calls + in bind_base_th for1 (* ------------------------------------------------------------------ *) - let by_name s hyps = - match List.ofind ((=) s |- EcIdent.name |- fst) hyps.h_local with - | None -> error (LookupError (`Symbol s)) - | Some h -> h + let bind_nt_th = + let for1 path base = function + | Th_operator (x, ({ op_kind = OB_nott _ } as op)) -> + Some (Op.update_ntbase path (x, op) base) + | _ -> None - let by_id id hyps = - match List.ofind (EcIdent.id_equal id |- fst) hyps.h_local with - | None -> error (LookupError (`Ident id)) - | Some x -> snd x + in bind_base_th for1 (* ------------------------------------------------------------------ *) - let as_hyp = function - | (id, LD_hyp f) -> (id, f) - | (id, _) -> error (InvalidKind (id, `Hypothesis)) + let bind_rd_th = + let for1 _path db = function + | Th_reduction rules -> + let rules = List.map (fun (x, _, y) -> (x, y)) rules in + Some (Reduction.add_rules rules db) + | _ -> None - let as_var = function - | (id, LD_var (ty, _)) -> (id, ty) - | (id, _) -> error (InvalidKind (id, `Variable)) + in bind_base_th for1 (* ------------------------------------------------------------------ *) - let hyp_by_name s hyps = as_hyp (by_name s hyps) - let var_by_name s hyps = as_var (by_name s hyps) + let add_restr_th = + let for1 path env = function + | Th_module me -> Some (Mod.add_restr_to_declared path me env) + | _ -> None + in bind_base_th for1 (* ------------------------------------------------------------------ *) - let hyp_by_id x hyps = as_hyp (x, by_id x hyps) - let var_by_id x hyps = as_var (x, by_id x hyps) + let bind_cr_th = + let for1 (_ : path) (bindings : crbindings) = function + | Th_crbinding (CRB_Bitstring bs, _) -> + Some (Circuit.rebind_bitstring_ bs bindings) - (* ------------------------------------------------------------------ *) - let has_gen dcast s hyps = - try ignore (dcast (by_name s hyps)); true - with LdeclError (InvalidKind _ | LookupError _) -> false + | Th_crbinding (CRB_Array ba, _) -> + Some (Circuit.rebind_array_ ba bindings) - let hyp_exists s hyps = has_gen as_hyp s hyps - let var_exists s hyps = has_gen as_var s hyps + | Th_crbinding (CRB_BvOperator op, _) -> + Some (Circuit.rebind_bvoperator_ op bindings) + + | Th_crbinding (CRB_Circuit cr, _) -> + Some (Circuit.rebind_circuit_ cr bindings) + + | _ -> None + in bind_base_th for1 (* ------------------------------------------------------------------ *) - let has_id x hyps = - try ignore (by_id x hyps); true - with LdeclError (LookupError _) -> false + let bind + ?(import = true) + (cth : compiled_theory) + (env : env) + = + let { cth_items = items; cth_mode = mode; } = cth.ctheory in + let env = MC.bind_theory cth.name cth.ctheory env in + let env = { + env with + env_item = mkitem ~import (Th_theory (cth.name, cth.ctheory)) :: env.env_item } + in - let has_inld s = function - | LD_mem mt -> is_bound s mt - | _ -> false + let env = + match import, mode with + | _, `Concrete -> + let thname = EcPath.pqname (root env) cth.name in + let env_tci = bind_instance_th thname env.env_tci items in + let env_tc = bind_tc_th thname env.env_tc items in + let env_rwbase = bind_br_th thname env.env_rwbase items in + let env_atbase = bind_at_th thname env.env_atbase items in + let env_ntbase = bind_nt_th thname env.env_ntbase items in + let env_redbase = bind_rd_th thname env.env_redbase items in + let env_crbds = bind_cr_th thname env.env_crbds items in + let env = + { env with + env_tci ; env_tc ; env_rwbase; + env_atbase; env_ntbase; env_redbase; + env_crbds ; } + in + add_restr_th thname env items - let has_name ?(dep = false) s hyps = - let test (id, k) = - EcIdent.name id = s || (dep && has_inld s k) - in List.exists test hyps.h_local + | _, _ -> + env + in - (* ------------------------------------------------------------------ *) - let can_unfold id hyps = - try match by_id id hyps with LD_var (_, Some _) -> true | _ -> false - with LdeclError _ -> false + { env with + env_thenvs = Mp.set_union env.env_thenvs cth.compiled } - let unfold id hyps = - try - match by_id id hyps with - | LD_var (_, Some f) -> f - | _ -> raise NotReducible - with LdeclError _ -> raise NotReducible + (* ------------------------------------------------------------------ *) + let rebind name th env = + MC.bind_theory name th env (* ------------------------------------------------------------------ *) - let check_name_clash id hyps = - if has_id id hyps - then error (NameClash (`Ident id)) - else - let s = EcIdent.name id in - if s <> "_" && has_name ~dep:false s hyps then - error (NameClash (`Symbol s)) + let import (path : EcPath.path) (env : env) = + let rec import (env : env) path (th : theory) = + let xpath x = EcPath.pqname path x in + let import_th_item (env : env) (item : theory_item) = + if not item.ti_import then env else - let add_local id ld hyps = - check_name_clash id hyps; - { hyps with h_local = (id, ld) :: hyps.h_local } + match item.ti_item with + | Th_type (x, ty) -> + MC.import_tydecl (xpath x) ty env + + | Th_operator (x, op) -> + MC.import_operator (xpath x) op env + + | Th_axiom (x, ax) -> + MC.import_axiom (xpath x) ax env + + | Th_modtype (x, mty) -> + MC.import_modty (xpath x) mty env + + | Th_module ({ tme_expr = me; tme_loca = lc; }) -> + let env = MC.import_mod (IPPath (xpath me.me_name)) (me, Some lc) env in + let env = MC.import_mc (IPPath (xpath me.me_name)) env in + env - (* ------------------------------------------------------------------ *) - let fresh_id hyps s = - let s = - if s = "_" || not (has_name ~dep:true s hyps) - then s - else - let rec aux n = - let s = Printf.sprintf "%s%d" s n in - if has_name ~dep:true s hyps then aux (n+1) else s - in aux 0 + | Th_export (p, _) -> + import env p (by_path ~mode:`Concrete p env).cth_items - in EcIdent.create s + | Th_theory (x, ({cth_mode = `Concrete} as th)) -> + let env = MC.import_theory (xpath x) th env in + let env = MC.import_mc (IPPath (xpath x)) env in + env - let fresh_ids hyps names = - let do1 hyps s = - let id = fresh_id hyps s in - (add_local id (LD_var (tbool, None)) hyps, id) - in List.map_fold do1 hyps names + | Th_theory (x, ({cth_mode = `Abstract} as th)) -> + MC.import_theory (xpath x) th env - (* ------------------------------------------------------------------ *) - type hyps = { - le_init : env; - le_env : env; - le_hyps : EcBaseLogic.hyps; - } + | Th_typeclass (x, tc) -> + MC.import_typeclass (xpath x) tc env - let tohyps lenv = lenv.le_hyps - let toenv lenv = lenv.le_env - let baseenv lenv = lenv.le_init + | Th_baserw (x, _) -> + MC.import_rwbase (xpath x) env - let add_local_env x k env = - match k with - | LD_var (ty, _) -> Var.bind_local x ty env - | LD_mem mt -> Memory.push (x, mt) env - | LD_modty i -> Mod.bind_local x i env - | LD_hyp _ -> env - | LD_abs_st us -> AbsStmt.bind x us env + | Th_alias (name, path) -> + rebind_alias name path env - (* ------------------------------------------------------------------ *) - let add_local x k h = - let le_hyps = add_local x k (tohyps h) in - let le_env = add_local_env x k h.le_env in - { h with le_hyps; le_env; } + | Th_addrw _ + | Th_instance _ + | Th_auto _ + | Th_reduction _ + | Th_crbinding _ -> env + in + List.fold_left import_th_item env th - (* ------------------------------------------------------------------ *) - let init env ?(locals = []) tparams = - let buildenv env = - List.fold_right - (fun (x, k) env -> add_local_env x k env) - locals env in + import env path (by_path ~mode:`Concrete path env).cth_items - { le_init = env; - le_env = buildenv env; - le_hyps = { h_tvar = tparams; h_local = locals; }; } + (* ------------------------------------------------------------------ *) + let export (path : EcPath.path) lc (env : env) = + let env = import path env in + { env with env_item = mkitem ~import:true (Th_export (path, lc)) :: env.env_item } (* ------------------------------------------------------------------ *) - let clear ?(leniant = false) ids hyps = - let rec filter ids hyps = - match hyps with [] -> [] | ((id, lk) as bd) :: hyps -> + let rec filter clears root cleared items = + snd_map (List.pmap identity) + (List.map_fold (filter1 clears root) cleared items) - let ids, bd = - if EcIdent.Sid.mem id ids then (ids, None) else + and filter_th clears root cleared items = + let mempty = List.exists (EcPath.p_equal root) clears in + let cleared, items = filter clears root cleared items in - let fv = ld_fv lk in + if mempty && List.is_empty items + then (Sp.add root cleared, None) + else (cleared, Some items) - if Mid.set_disjoint ids fv then - (ids, Some bd) - else - if leniant then - (Mid.set_diff ids fv, Some bd) - else - let inter = Mid.set_inter ids fv in - error (CannotClear (Sid.choose inter, id)) - in List.ocons bd (filter ids hyps) - in + and filter1 clears root = + let inclear p = List.exists (EcPath.p_equal p) clears in + let thclear = inclear root in - let locals = filter ids hyps.le_hyps.h_local in + fun cleared item -> + let cleared, item_r = + match item.ti_item with + | Th_theory (x, cth) -> + let cleared, items = + let xpath = EcPath.pqname root x in + filter_th clears xpath cleared cth.cth_items in + let item = items |> omap (fun items -> + let cth = { cth with cth_items = items } in + Th_theory (x, cth)) in + (cleared, item) - init hyps.le_init ~locals hyps.le_hyps.h_tvar + | _ -> let item_r = match item.ti_item with - (* ------------------------------------------------------------------ *) - let hyp_convert x check hyps = - let module E = struct exception NoOp end in + | Th_axiom (_, { ax_kind = `Lemma }) when thclear -> + None - let init locals = init hyps.le_init ~locals hyps.le_hyps.h_tvar in + | Th_axiom (x, ({ ax_kind = `Axiom (tags, false) } as ax)) when thclear -> + Some (Th_axiom (x, { ax with ax_kind = `Axiom (tags, true) })) - let rec doit locals = - match locals with - | (y, LD_hyp fp) :: locals when EcIdent.id_equal x y -> begin - let fp' = check (lazy (init locals)) fp in - if fp == fp' then raise E.NoOp else (x, LD_hyp fp') :: locals - end + | Th_addrw (p, ps, lc) -> + let ps = List.filter ((not) |- inclear |- oget |- EcPath.prefix) ps in + if List.is_empty ps then None else Some (Th_addrw (p, ps,lc)) - | [] -> error (LookupError (`Ident x)) - | ld :: locals -> ld :: (doit locals) + | Th_auto ({ axioms } as auto_rl) -> + let axioms = List.filter (fun (p, _) -> + let p = oget (EcPath.prefix p) in + not (inclear p) + ) axioms in + if List.is_empty axioms then None else Some (Th_auto {auto_rl with axioms}) - in (try Some (doit hyps.le_hyps.h_local) with E.NoOp -> None) |> omap init + | (Th_export (p, _)) as item -> + if Sp.mem p cleared then None else Some item - (* ------------------------------------------------------------------ *) - let local_hyps x hyps = - let rec doit locals = - match locals with - | (y, _) :: locals -> - if EcIdent.id_equal x y then locals else doit locals - | [] -> - error (LookupError (`Ident x)) in + | _ as item -> Some item - let locals = doit hyps.le_hyps.h_local in - init hyps.le_init ~locals hyps.le_hyps.h_tvar + in (cleared, item_r) + + in (cleared, omap (fun item_r -> { item with ti_item = item_r; }) item_r) (* ------------------------------------------------------------------ *) - let by_name s hyps = by_name s (tohyps hyps) - let by_id x hyps = by_id x (tohyps hyps) + type clear_mode = [`Full | `ClearOnly | `No] - let has_name s hyps = has_name ~dep:false s (tohyps hyps) - let has_id x hyps = has_id x (tohyps hyps) + let close + ?(clears : path list = []) + ?(pempty : clear_mode = `No) + (loca : is_local) + (mode : thmode) + (env : env) + = + let items = List.rev env.env_item in + let items = + if List.is_empty clears + then (if List.is_empty items then None else Some items) + else snd (filter_th clears (root env) Sp.empty items) in - let hyp_by_name s hyps = hyp_by_name s (tohyps hyps) - let hyp_exists s hyps = hyp_exists s (tohyps hyps) - let hyp_by_id x hyps = snd (hyp_by_id x (tohyps hyps)) + let items = + match items, pempty with + | None, (`No | `ClearOnly) -> Some [] + | _, _ -> items + in - let var_by_name s hyps = var_by_name s (tohyps hyps) - let var_exists s hyps = var_exists s (tohyps hyps) - let var_by_id x hyps = snd (var_by_id x (tohyps hyps)) + items |> omap (fun items -> + let ctheory = + { cth_items = items + ; cth_source = None + ; cth_loca = loca + ; cth_mode = mode + } in - let can_unfold x hyps = can_unfold x (tohyps hyps) - let unfold x hyps = unfold x (tohyps hyps) + let root = env.env_scope.ec_path in + let name = EcPath.basename root in - let fresh_id hyps s = fresh_id (tohyps hyps) s - let fresh_ids hyps s = snd (fresh_ids (tohyps hyps) s) + let compiled = + Mp.filter + (fun path _ -> EcPath.isprefix ~prefix:root ~path) + env.env_thenvs in + let compiled = Mp.add env.env_scope.ec_path env compiled in + + { name; ctheory; compiled; } + ) (* ------------------------------------------------------------------ *) - let push_active_ss m lenv = - { lenv with le_env = Memory.push_active_ss m lenv.le_env } + let require (compiled : compiled_theory) (env : env) = + let cth = compiled.ctheory in + let rootnm = EcCoreLib.p_top in + let thpath = EcPath.pqname rootnm compiled.name in - let push_active_ts ml mr lenv = - { lenv with le_env = Memory.push_active_ts ml mr lenv.le_env } + let env = + match cth.cth_mode with + | `Concrete -> + let (_, thmc), submcs = + MC.mc_of_theory_r rootnm (compiled.name, cth) + in MC.bind_submc env rootnm ((compiled.name, thmc), submcs) - let push_all l lenv = - { lenv with le_env = Memory.push_all l lenv.le_env } + | `Abstract -> env + in - let hoareF mem xp lenv = - let env1, env2 = Fun.hoareF mem xp lenv.le_env in - { lenv with le_env = env1}, {lenv with le_env = env2 } + let topmc = Mip.find (IPPath rootnm) env.env_comps in + let topmc = MC._up_theory false topmc compiled.name (IPPath thpath, cth) in + let topmc = MC._up_mc false topmc (IPPath thpath) in - let equivF ml mr xp1 xp2 lenv = - let env1, env2 = Fun.equivF ml mr xp1 xp2 lenv.le_env in - { lenv with le_env = env1}, {lenv with le_env = env2 } + let current = env.env_current in + let current = MC._up_theory true current compiled.name (IPPath thpath, cth) in + let current = MC._up_mc true current (IPPath thpath) in - let inv_memenv ml mr lenv = - { lenv with le_env = Fun.inv_memenv ml mr lenv.le_env } + let comps = env.env_comps in + let comps = Mip.add (IPPath rootnm) topmc comps in - let inv_memenv1 m lenv = - { lenv with le_env = Fun.inv_memenv1 m lenv.le_env } -end + let env = { env with env_current = current; env_comps = comps; } in + + match cth.cth_mode with + | `Abstract -> + { env with + env_thenvs = Mp.set_union env.env_thenvs compiled.compiled; } + | `Concrete -> + { env with + env_tci = bind_instance_th thpath env.env_tci cth.cth_items; + env_tc = bind_tc_th thpath env.env_tc cth.cth_items; + env_rwbase = bind_br_th thpath env.env_rwbase cth.cth_items; + env_atbase = bind_at_th thpath env.env_atbase cth.cth_items; + env_ntbase = bind_nt_th thpath env.env_ntbase cth.cth_items; + env_redbase = bind_rd_th thpath env.env_redbase cth.cth_items; + env_crbds = bind_cr_th thpath env.env_crbds cth.cth_items; + env_thenvs = Mp.set_union env.env_thenvs compiled.compiled; } +end let pp_debug_form = ref (fun _env _f -> assert false) diff --git a/src/ecEnv.mli b/src/ecEnv.mli index 5a1d5bf602..55ab3b2b0c 100644 --- a/src/ecEnv.mli +++ b/src/ecEnv.mli @@ -16,8 +16,26 @@ type 'a suspension = { sp_params : int * (EcIdent.t * module_type) list; } +(* -------------------------------------------------------------------- *) +type crb_tyrev_binding = [ + | `Bitstring of crb_bitstring + | `Array of crb_array +] + +type crb_bitstring_operator = crb_bitstring * [`From | `To | `OfInt | `ToUInt | `ToSInt ] + +type crb_array_operator = crb_array * [`Get | `Set | `ToList | `OfList] + +type crb_oprev_binding = [ + | `Bitstring of crb_bitstring_operator + | `Array of crb_array_operator + | `BvOperator of crb_bvoperator + | `Circuit of crb_circuit +] + (* -------------------------------------------------------------------- *) type env + type scope = [ | `Theory | `Module of EcPath.mpath @@ -515,4 +533,40 @@ module LDecl : sig val inv_memenv1 : memory -> hyps -> hyps end +(* -------------------------------------------------------------------- *) +module Circuit : sig + val bind_bitstring : ?import:bool -> is_local -> crb_bitstring -> env -> env + val bind_array : ?import:bool -> is_local -> crb_array -> env -> env + val bind_bvoperator : ?import:bool -> is_local -> crb_bvoperator -> env -> env + val bind_circuit : ?import:bool -> is_local -> crb_circuit -> env -> env + val bind_crbinding : ?import:bool -> is_local -> crbinding -> env -> env + + val lookup_bitstring : env -> ty -> crb_bitstring option + val lookup_bitstring_path : env -> path -> crb_bitstring option + val lookup_bitstring_size : env -> ty -> int option + val lookup_bitstring_size_path : env -> path -> int option + + val lookup_bvoperator_path : env -> path -> crb_bvoperator option + val lookup_bvoperator : env -> qsymbol -> crb_bvoperator option + + val lookup_array : env -> ty -> crb_array option + val lookup_array_path : env -> path -> crb_array option + val lookup_array_size : env -> ty -> int option + + val lookup_array_and_bitstring : env -> ty -> (crb_array * crb_bitstring) option + + val lookup_circuit : env -> qsymbol -> Lospecs.Ast.adef option + val lookup_circuit_path : env -> path -> Lospecs.Ast.adef option + + val reverse_type : env -> path -> crb_tyrev_binding list + val reverse_operator : env -> path -> crb_oprev_binding list + + val reverse_bitstring_operator : env -> path -> crb_bitstring_operator option + val reverse_array_operator : env -> path -> crb_array_operator option + val reverse_bvoperator : env -> path -> crb_bvoperator option + val reverse_circuit : env -> path -> crb_circuit option + + val get_specification_by_name : env -> filename:string -> symbol -> Lospecs.Ast.adef option +end + val pp_debug_form : (env -> form -> unit) ref diff --git a/src/ecHiInductive.ml b/src/ecHiInductive.ml index a51dede082..981fca12d1 100644 --- a/src/ecHiInductive.ml +++ b/src/ecHiInductive.ml @@ -83,9 +83,10 @@ let trans_datatype (env : EcEnv.env) (name : ptydname) (dt : pdatatype) = let tpath = EcPath.pqname (EcEnv.root env) (unloc name) in let env0 = let myself = { - tyd_params = EcUnify.UniEnv.tparams ue; - tyd_type = `Abstract EcPath.Sp.empty; - tyd_loca = lc; + tyd_params = EcUnify.UniEnv.tparams ue; + tyd_type = `Abstract EcPath.Sp.empty; + tyd_loca = lc; + tyd_clinline = false; } in EcEnv.Ty.bind (unloc name) myself env in diff --git a/src/ecHiTacticals.ml b/src/ecHiTacticals.ml index 9eb6521e35..36b9a89d63 100644 --- a/src/ecHiTacticals.ml +++ b/src/ecHiTacticals.ml @@ -53,6 +53,11 @@ and process1_or (ttenv : ttenv) (t1 : ptactic) (t2 : ptactic) (tc : tcenv1) = and process1_try (ttenv : ttenv) (t : ptactic_core) (tc : tcenv1) = FApi.t_try (process1_core ttenv t) tc +(* -------------------------------------------------------------------- *) +and process1_extens (ttenv : ttenv) ((t, v) : ptactic_core * psymbol option) (tc : tcenv1) = + let v = Option.map unloc v in + EcPhlBDep.t_extens v (process1_core ttenv t) tc + (* -------------------------------------------------------------------- *) and process1_admit (_ : ttenv) (tc : tcenv1) = EcLowGoal.t_admit tc @@ -231,7 +236,13 @@ and process1_phl (_ : ttenv) (t : phltactic located) (tc : tcenv1) = | Plossless -> EcPhlHiAuto.t_lossless | Prepl_stmt infos -> EcPhlTrans.process_equiv_trans infos | Pprocrewrite (s, p, f) -> EcPhlRewrite.process_rewrite s p f - | Pchangestmt (s, p, c) -> EcPhlRewrite.process_change_stmt s p c + | Pchangestmt (s, b, p, c) -> EcPhlRewrite.process_change_stmt s b p c + | Pbdep bdinfo -> EcPhlBDep.process_bdep bdinfo + | Pbdepeval bdeinfo -> EcPhlBDep.process_bdep_eval bdeinfo + | Pbdepeq bdeinfo -> EcPhlBDep.process_bdepeq bdeinfo + | Pcircuit `Solve -> EcPhlBDep.t_bdep_solve + | Pcircuit `Simplify -> EcPhlBDep.t_bdep_simplify + | Pcirc (f, v) -> EcPhlBDep.process_bdep_form f v | Prwprgm infos -> EcPhlRwPrgm.process_rw_prgm infos in @@ -318,6 +329,7 @@ and process_core (ttenv : ttenv) ({ pl_loc = loc } as t : ptactic_core) (tc : tc | Psolve t -> `One (process1_solve ttenv t) | Pdo ((b, n), t) -> `One (process1_do ttenv (b, n) t) | Ptry t -> `One (process1_try ttenv t) + | Pextens (t, v) -> `One (process1_extens ttenv (t, v)) | Por (t1, t2) -> `One (process1_or ttenv t1 t2) | Pseq ts -> `One (process1_seq ttenv ts) | Pcase es -> `One (process1_case ttenv es) diff --git a/src/ecLexer.mll b/src/ecLexer.mll index 19536eaae7..025700d1d6 100644 --- a/src/ecLexer.mll +++ b/src/ecLexer.mll @@ -68,6 +68,7 @@ "last" , LAST ; (* KW: tactical *) "do" , DO ; (* KW: tactical *) "expect" , EXPECT ; (* KW: tactical *) + "extens" , EXTENS ; (* KW: tactical *) (* Lambda tactics *) "beta" , BETA ; (* KW: tactic *) @@ -168,7 +169,14 @@ "splitwhile" , SPLITWHILE ; (* KW: tactic *) "kill" , KILL ; (* KW: tactic *) "eager" , EAGER ; (* KW: tactic *) - + + "array" , ARRAY ; (* KW: global *) + "aig" , AIG ; (* KW: global *) + "bdep" , BDEP ; (* KW: global *) + "bdepeq" , BDEPEQ ; (* KW: global *) + "bind" , BIND ; (* KW: global *) + "circuit" , CIRCUIT ; (* KW: global *) + "bitstring" , BITSTRING ; (* KW: global *) "axiom" , AXIOM ; (* KW: global *) "axiomatized" , AXIOMATIZED; (* KW: global *) "lemma" , LEMMA ; (* KW: global *) diff --git a/src/ecLowCircuits.ml b/src/ecLowCircuits.ml new file mode 100644 index 0000000000..9440ebcebb --- /dev/null +++ b/src/ecLowCircuits.ml @@ -0,0 +1,2034 @@ +open EcBigInt +open EcUtils +open EcSymbols +open EcDecl +open EcIdent +open EcMemory + +(* -------------------------------------------------------------------- *) +module C = struct + include Lospecs.Aig + include Lospecs.Circuit + include Lospecs.Circuit_spec +end + +module HL = struct + include Lospecs.Hlaig + include Lospecs.Hlaig.Deps +end + +module Map = Batteries.Map +module Hashtbl = Batteries.Hashtbl +module Set = Batteries.Set +module Option = Batteries.Option + +exception CircError of string + +let debug : bool = true + +(* Backend implementing minimal functions needed for the translation *) +(* Minimal expected functionality is QF_ABV *) +(* Input are: some identifier + some bit *) +module type CBackend = sig + type node (* Corresponds to a single output node *) + type reg + (* Id + offset, both assumed starting at 0 *) + type inp = int * int + + val pp_node : Format.formatter -> node -> unit + + exception NonConstantCircuit (* FIXME: Rename later *) + + val true_ : node + val false_ : node + + val nodes_eq : node -> node -> bool + + val bad : node + val bad_reg : int -> reg + val has_bad : node -> bool + val have_bad : reg -> int option + + val node_array_of_reg : reg -> node array + val node_list_of_reg : reg -> node list + val reg_of_node_list : node list -> reg + val reg_of_node_array : node array -> reg + val reg_of_node : node -> reg + val node_of_reg : reg -> node + + val input_node : id:int -> int -> node + val input_of_size : ?offset:int -> id:int -> int -> reg + + val reg_of_zint : size:int -> zint -> reg + val bool_array_of_reg : reg -> bool array + val bool_list_of_reg : reg -> bool list + val szint_of_reg : reg -> zint + val uzint_of_reg : reg -> zint + val size_of_reg : reg -> int + + val apply : (inp -> node option) -> node -> node + val applys : (inp -> node option) -> reg -> reg + val circuit_from_spec : Lospecs.Ast.adef -> reg list -> reg + val equiv : ?inps:inp list -> pcond:node -> reg -> reg -> bool + val sat : ?inps:inp list -> node -> bool + val taut : ?inps:inp list -> node -> bool + + val slice : reg -> int -> int -> reg + val subcirc : reg -> (int list) -> reg + val insert : reg -> int -> reg -> reg + val get : reg -> int -> node + val permute : int -> (int -> int) -> reg -> reg + + val node_eq : node -> node -> node + val reg_eq : reg -> reg -> node + val node_ite : node -> node -> node -> node + val reg_ite : node -> reg -> reg -> reg + + val band : node -> node -> node + val bor : node -> node -> node + val bxor : node -> node -> node + val bnot : node -> node + val bxnor : node -> node -> node + val bnand : node -> node -> node + val bnor : node -> node -> node + + (* SMTLib Base Operations *) + (* FIXME: decide if boolean ops are going to be defined + on registers or on nodes *) + val add : reg -> reg -> reg + val sub : reg -> reg -> reg + val opp : reg -> reg + val mul : reg -> reg -> reg + val udiv : reg -> reg -> reg + val sdiv : reg -> reg -> reg + val umod : reg -> reg -> reg (* FIXME: mod or rem here? *) + val smod : reg -> reg -> reg + val lshl : reg -> reg -> reg + val lshr : reg -> reg -> reg + val ashr : reg -> reg -> reg + val rol : reg -> reg -> reg + val ror : reg -> reg -> reg + val land_ : reg -> reg -> reg + val lor_ : reg -> reg -> reg + val lxor_ : reg -> reg -> reg + val lnot_ : reg -> reg + val ult: reg -> reg -> node + val slt : reg -> reg -> node + val ule : reg -> reg -> node + val sle : reg -> reg -> node + val uext : reg -> int -> reg + val sext : reg -> int -> reg + val trunc : reg -> int -> reg + val concat : reg -> reg -> reg + + val flatten : reg list -> reg + + val reg_to_file : input_count:int -> ?inp_name_map:(int -> string) -> name:string -> reg -> symbol + + module Deps : sig + type dep = (int, int Set.t) Map.t + type deps = dep array + type block_deps + + val dep_of_node : node -> dep + val deps_of_reg : reg -> deps + val block_deps_of_deps : int -> deps -> block_deps + val block_deps_of_reg : int -> reg -> block_deps + + val pp_dep : Format.formatter -> dep -> unit + val pp_deps : Format.formatter -> deps -> unit + val pp_block_deps : Format.formatter -> block_deps -> unit + + (* Assumes only one reg -> reg and parallel blocks *) + val is_splittable : int -> int -> deps -> bool + + val are_independent : block_deps -> bool + + val single_dep : deps -> bool + (* Assumes single_dep *) + val dep_range : deps -> int * int + (* Checks if first dep is a subset of second dep *) + val dep_contained : dep -> dep -> bool + (* Checks if two dep sets are equal *) + val deps_equal : dep -> dep -> bool + (* Checks if all the deps are in a given list of inputs *) + val check_inputs : reg -> (int * int) list -> bool + + val forall_inputs : (int -> int -> bool) -> reg -> bool + val rename_inputs : ((int * int) -> (int * int) option) -> reg -> reg + (* TODO: Rename *) + val excise_bit : ?renamings:(int -> int option) -> node -> node * (int, int * int) Map.t + end +end + +module LospecsBack : CBackend = struct + type node = C.node + type reg = C.node array + type inp = int * int + + let pp_node (fmt : Format.formatter) (n: node) = + Format.fprintf fmt "%a" Lospecs.Aig.pp_node n + + exception NonConstantCircuit (* FIXME: Rename later *) + + let true_ = C.true_ + let false_ = C.false_ + let nodes_eq ({id=id1; _}: node) ({id=id2; _}: node) = id1 = id2 + let size_of_reg = Array.length + let bad = C.input (-1, -1) + let bad_reg = fun i -> Array.make i bad + let has_bad : node -> bool = + let cache : (int, bool) Hashtbl.t = Hashtbl.create 0 in + let rec doit (n: node) : bool = + match Hashtbl.find_option cache (Int.abs n.id) with + | Some b -> b + | None -> let b = doit_r n.gate in + Hashtbl.add cache (Int.abs n.id) b; + b + and doit_r (n: C.node_r) : bool = + match n with + | C.Input (-1, -1) -> true + | C.Input _ + | C.False -> false + | C.And (n1, n2) -> (doit n1) || (doit n2) + in + fun b -> doit b + + let have_bad (r: reg) : int option = + Array.find_opt (fun (_, r) -> has_bad r) (Array.mapi (fun i r -> (i,r)) r) |> Option.map fst + + let node_array_of_reg : reg -> node array = fun x -> x + + let node_list_of_reg : reg -> node list = fun x -> Array.to_list x + + let reg_of_node_list : node list -> reg = fun x -> Array.of_list x + + let reg_of_node_array : node array -> reg = fun x -> x + + let reg_of_node : node -> reg = fun x -> [| x |] + (* FIXME: throws array error, error handling TODO *) + let node_of_reg : reg -> node = fun x -> x.(0) + + let reg_of_zint ~(size: int) (v: zint) : reg = + C.of_bigint_all ~size (to_zt v) + + let bool_array_of_reg (r: reg) : bool array = + C.bools_of_reg r + + let bool_list_of_reg (r: reg) = + C.bool_list_of_reg r + + let szint_of_reg (r: reg) : zint = + C.bools_of_reg r |> C.sbigint_of_bools |> of_zt + + let uzint_of_reg (r: reg) : zint = + C.bools_of_reg r |> C.ubigint_of_bools |> of_zt + + let node_eq (n1: node) (n2: node) = C.xnor n1 n2 + let reg_eq (r1: reg) (r2: reg) = + Array.fold (fun acc r -> + C.and_ acc r) + C.true_ + (Array.map2 node_eq r1 r2) + let node_ite (c: node) (t: node) (f: node) = C.mux2 f t c + let reg_ite (c: node) = Array.map2 (node_ite c) + + let equiv ?(inps: inp list option) ~(pcond: node) (r1: reg) (r2: reg) : bool = + let open HL in + let module BWZ = (val makeBWZinterface ()) in + BWZ.circ_equiv ?inps r1 r2 pcond + + let sat ?(inps: inp list option) (n: node) : bool = + let open HL in + let module BWZ = (val makeBWZinterface ()) in + BWZ.circ_sat ?inps n + + let taut ?(inps: inp list option) (n: node) : bool = + let open HL in + let module BWZ = (val makeBWZinterface ()) in + BWZ.circ_taut ?inps n + + let slice (r: reg) (idx: int) (len: int) : reg = + Array.sub r idx len + + let subcirc (r: reg) (outs: int list) : reg = + List.map (fun i -> r.(i)) outs |> Array.of_list + + let insert (r: reg) (idx: int) (r_in: reg) : reg = + let ret = Array.copy r in + Array.blit r_in 0 ret idx (Array.length r_in); + ret + + (* FIXME: Error handling *) + let get (r: reg) (idx: int) = r.(idx) + + let permute (w: int) (perm: int -> int) (r: reg) : reg = + if debug then Format.eprintf "Applying permutation to reg of size %d with block size of %d@." (size_of_reg r) w; + Array.init (size_of_reg r) (fun i -> + let block_idx, bit_idx = perm (i / w), (i mod w) in + if block_idx < 0 then None + else + let idx = block_idx*w + bit_idx in + Some r.(idx) + ) |> Array.filter_map (fun x -> x) + + + (* Node operations *) + let band : node -> node -> node = C.and_ + let bor : node -> node -> node = C.or_ + let bxor : node -> node -> node = C.xor + let bnot : node -> node = C.neg + let bxnor : node -> node -> node = C.xnor + let bnand : node -> node -> node = C.nand + let bnor : node -> node -> node = fun n1 n2 -> C.neg @@ C.or_ n1 n2 + + (* FIXME: maybe convert to BigInt? *) + let input_node ~id i = C.input (id, i) + let input_of_size ?(offset = 0) ~id (i: int) = Array.init i (fun i -> C.input (id, offset + i)) + + let apply (map_: inp -> node option) (n: node) : node= + C.map map_ n + + let applys (map_: inp -> node option) : reg -> reg = + fun r -> Array.map (C.map map_) r + + let circuit_from_spec (def: Lospecs.Ast.adef) (args: reg list) : reg = + C.circuit_of_specification args def + + (* SMTLib Base Ops *) + let add (r1: reg) (r2: reg) : reg = C.add_dropc r1 r2 + let sub (r1: reg) (r2: reg) : reg = C.sub_dropc r1 r2 + let opp (r: reg) : reg = C.opp r + let mul (r1: reg) (r2: reg) : reg = C.umull r1 r2 + let udiv (r1: reg) (r2: reg) : reg = C.udiv r1 r2 + let sdiv (r1: reg) (r2: reg) : reg = C.sdiv r1 r2 + (* FIXME: mod or rem here? *) + let umod (r1: reg) (r2: reg) : reg = C.umod r1 r2 + let smod (r1: reg) (r2: reg) : reg = C.smod r1 r2 + let lshl (r1: reg) (r2: reg) : reg = C.shift ~side:`L ~sign:`L r1 r2 + let lshr (r1: reg) (r2: reg) : reg = C.shift ~side:`R ~sign:`L r1 r2 + let ashr (r1: reg) (r2: reg) : reg = C.shift ~side:`R ~sign:`A r1 r2 + let rol (r1: reg) (r2: reg) : reg = C.rol r1 r2 + let ror (r1: reg) (r2: reg) : reg = C.ror r1 r2 + let land_ (r1: reg) (r2: reg) : reg = C.land_ r1 r2 + let lor_ (r1: reg) (r2: reg) : reg = C.lor_ r1 r2 + let lxor_ (r1: reg) (r2: reg) : reg = C.lxor_ r1 r2 + let lnot_ (r1: reg) : reg = C.lnot_ r1 + let ult (r1: reg) (r2: reg) : node = C.ugt r2 r1 + let slt (r1: reg) (r2: reg) : node = C.sgt r2 r1 + let ule (r1: reg) (r2: reg) : node = C.uge r2 r1 + let sle (r1: reg) (r2: reg) : node = C.sge r2 r1 + let uext (r1: reg) (size: int) : reg = C.uextend ~size r1 + let sext (r1: reg) (size: int) : reg = C.sextend ~size r1 + let trunc (r1: reg) (size: int) : reg = Array.sub r1 0 size + let concat (r1: reg) (r2: reg) : reg = Array.append r1 r2 + let flatten (rs: reg list) : reg = Array.concat rs + + let reg_to_file ~(input_count: int) ?(inp_name_map: (int -> string) option) ~(name: string) (r: reg) : symbol = + C.write_aiger_bin_temp ~input_count ?inp_name_map ~name r + + module Deps = struct + type dep = (int, int Set.t) Map.t + type deps = dep array + type block_deps = (int * dep) array + + let dep_of_node = fun n -> HL.dep n + let deps_of_reg = fun r -> HL.deps r + let block_deps_of_deps (w: int) (d: deps) : block_deps = + assert (Array.length d mod w = 0); + Array.init (Array.length d / w) (fun i -> + let deps = Array.sub d (i*w) w in + let block = Array.fold_left (fun acc m -> + Map.merge (fun idx d1 d2 -> + match d1, d2 with + | None, None -> None + | None, Some d | Some d, None -> Some d + | Some d1, Some d2 -> Some (Set.union d1 d2) + ) acc m) Map.empty deps in + (w, block) + ) + + let block_deps_of_reg (w: int) (r: reg) : block_deps = + let deps = deps_of_reg r in + block_deps_of_deps w deps + + let pp_dep (fmt: Format.formatter) (d: dep) : unit = + Map.iter (fun id bits -> + Format.fprintf fmt "%d: " id; + Set.iter (fun bit -> Format.fprintf fmt "%d " bit) bits; + Format.fprintf fmt "@\n" + ) d + + let pp_deps (fmt: Format.formatter) (d: deps) : unit = + Array.iteri (fun i d -> + Format.fprintf fmt "@[[%d]:@\n%a@]@\n" i + pp_dep d + ) d + + let pp_block_deps (fmt: Format.formatter) (bd: block_deps) : unit = + ignore @@ Array.fold_left (fun idx (w, d) -> + Format.fprintf fmt "@[[%d..%d]:@\n%a@]@\n" idx (idx + w - 1) + pp_dep d; + idx + w + ) 0 bd + + (* Assumes only one reg -> reg and parallel blocks *) + let is_splittable (w_in: int) (w_out: int) (d: deps) : bool = + match Set.cardinal + (Array.fold_left (Set.union) Set.empty + (Array.map (fun dep -> Map.keys dep |> Set.of_enum) d)) + with + | 0 -> true + | 1 -> + let blocks = block_deps_of_deps w_out d in + if debug then Format.eprintf "Checking block width...@."; + Array.fold_left_map (fun idx (width, d) -> + if Map.is_empty d then idx + width, true + else + let _, bits = Map.any d in + idx + width, Set.is_empty bits || + let base = Set.at_rank_exn 0 bits in + if Set.for_all (fun bit -> + let dist = bit - base in + if 0 <= dist && dist < w_in then true else false +(* + (Format.eprintf "Current bit: %d | Current dist: %d | Limit: %d@." bit dist w_in; + Format.eprintf "Base for current block: %d@." base; + false) +*) + ) bits then true else + begin + if debug then Format.eprintf "Bad block: [%d..%d] %a@." idx (idx + width - 1) pp_dep d; false + end + ) 0 blocks |> snd |> Array.for_all (fun x -> x) + | _ -> begin + if debug then Format.eprintf "Failed first check@\n"; + if debug then Format.eprintf "Map keys: "; + if debug then Array.iteri (fun i dep -> + Format.eprintf "Bit %d: " i; + List.iter (Format.eprintf "%d") (Map.keys dep |> List.of_enum); + Format.eprintf "@\n") d; + false + end + + + let are_independent (bd: block_deps) : bool = + let exception BreakOut in + try + ignore @@ Array.fold_left (fun acc (_, d) -> + Map.merge (fun _ d1 d2 -> + match d1, d2 with + | None, None -> None + | Some d, None | None, Some d -> Some d + | Some d1, Some d2 -> + if not (Set.disjoint d1 d2) then raise BreakOut else + Some (Set.union d1 d2) + ) acc d + ) Map.empty bd; + true + with BreakOut -> + false + + + let single_dep (d: deps) : bool = + match Set.cardinal + (Array.fold_left (Set.union) Set.empty + (Array.map (fun dep -> Map.keys dep |> Set.of_enum) d)) + with + | 0 | 1 -> true + | _ -> false + + (* Assumes single_dep, returns range (bot, top) such that valid idxs are bot <= i < top *) + let dep_range (d: deps) : int * int = + assert (single_dep d); + let idxs = + Array.fold_left (fun acc d -> + Set.union (Map.fold Set.union d Set.empty) acc) Set.empty d + in + if debug then Format.eprintf "%a@." pp_deps d; + if debug then Format.eprintf "Dep range for dependencies:@."; + if debug then Set.iter (fun i -> Format.eprintf "%d " i) idxs; + if debug then Format.eprintf "@.Min: %d | Max: %d@." (Set.min_elt idxs) (Set.max_elt idxs); + (Set.min_elt idxs, Set.max_elt idxs + 1) + + (* Checks that all dependencies of r are in the set inps *) + (* Each elements of inps is (id, width) *) + let check_inputs (r: reg) (inps: (int * int) list) : bool = + let ds = deps_of_reg r in + Array.for_all (fun d -> + Map.for_all (fun id b -> + match List.find_opt (fun (id_, _) -> id = id_) inps with + | Some (_, b_) -> Set.for_all (fun b -> 0 <= b && b < b_) b + | None -> false + ) d + ) ds + + let dep_contained (subd: dep) (superd: dep) : bool = + Map.for_all (fun id bits -> + match Map.find_opt id superd with + | None -> false + | Some superbits -> Set.subset bits superbits + ) subd + + let deps_equal (d1: dep) (d2: dep) : bool = + (Map.equal (Set.equal) d1 d2) + + let forall_inputs (check: int -> int -> bool) (r: reg) : bool = + let d = deps_of_reg r in + Array.for_all (fun d -> + Map.for_all (fun id bs -> Set.for_all (check id) bs) d) + d + + let rename_inputs (renamer: (int * int) -> (int * int) option) (r: reg) : reg = + C.maps (fun (id, b) -> + Option.map (fun (id, b) -> input_node ~id b) (renamer (id, b)) + ) r + + let excise_bit ?renamings (n: node) : node * (int, int * int) Map.t = + HL.realign_inputs ?renamings n + end +end + +module type CircuitInterface = sig + type flatcirc + type ctype = + CArray of {width: int; count: int} + | CBitstring of int + | CTuple of ctype list + | CBool + type cinp = { + type_ : ctype; + id: int + } + type circ = { + reg: flatcirc ; + type_: ctype ; + } + type 'a cfun = 'a * (cinp list) + type circuit = circ cfun + + val pp_flatcirc : Format.formatter -> flatcirc -> unit + + module CArgs : sig + type arg = + [ `Circuit of circuit + | `Constant of zint + | `Init of int -> circuit + | `List of circuit list ] + + val arg_of_circuit : circuit -> arg + val arg_of_zint : zint -> arg + val arg_of_circuits : circuit list -> arg + val arg_of_init : (int -> circuit) -> arg + val pp_arg : Format.formatter -> arg -> unit + end + open CArgs + + module TranslationState : sig + type state + + val empty_state : state + + val update_state_pv : state -> memory -> symbol -> circuit -> state + val state_get_pv_opt : state -> memory -> symbol -> circuit option + val state_get_pv : state -> memory -> symbol -> circuit + val state_get_all_memory : state -> memory -> (symbol * circuit) list + val state_get_all_pv : state -> ((memory * symbol) * circuit) list +(* val map_state_pv : (symbol -> circuit -> circuit) -> state -> state *) + + val update_state : state -> ident -> circuit -> state + val state_get_opt : state -> ident -> circuit option + val state_get : state -> ident -> circuit + val state_bindings : state -> (ident * circuit) list + val state_lambdas : state -> cinp list + val state_is_closed : state -> bool + val state_close_circuit : state -> circuit -> circuit + val map_state_var : (ident -> circuit -> circuit) -> state -> state + + (* Circuit lambdas, for managing inputs *) + val open_circ_lambda : state -> (ident * ctype) list -> state + val open_circ_lambda_pv : state -> ((memory * symbol) * ctype) list -> state + val close_circ_lambda : state -> state + val circ_lambda_oneshot : state -> (ident * ctype) list -> (state -> circuit) -> circuit (* FIXME: rename or redo *) + end + + module BVOps : sig + val circuit_of_bvop : EcDecl.crb_bvoperator -> circuit + val circuit_of_parametric_bvop : EcDecl.crb_bvoperator -> arg list -> circuit + end + + module ArrayOps : sig + val array_get : circuit -> int -> circuit + val array_set : circuit -> int -> circuit -> circuit + val array_oflist : circuit list -> circuit -> int -> circuit + end + + (* Circuit type utilities *) + val size_of_ctype : ctype -> int + val convert_type : ctype -> circuit -> circuit + val can_convert_input_type : ctype -> ctype -> bool + + (* Pretty Printers *) + val pp_ctype : Format.formatter -> ctype -> unit + val pp_cinp : Format.formatter -> cinp -> unit + val pp_circ : Format.formatter -> circ -> unit + val pp_circuit : Format.formatter -> circuit -> unit + + (* General utilities *) + val circ_of_zint : size:int -> zint -> circ + val circuit_of_zint : size:int -> zint -> circuit + + (* Type conversions *) + (* TODO: Redo this *) +(* + val cbool_of_circuit : ?strict:bool -> circuit -> circuit + val cbitstring_of_circuit : ?strict:bool -> circuit -> circuit + val carray_of_circuit : ?strict:bool -> circuit -> circuit + val ctuple_of_circuit : ?strict:bool -> circuit -> circuit +*) + + (* Type constructors *) + val new_cbool_inp : ?name:[`Str of string | `Idn of ident] -> unit -> circ * cinp + val new_cbitstring_inp : ?name:[`Str of string | `Idn of ident] -> int -> circ * cinp + val new_carray_inp : ?name:[`Str of string | `Idn of ident] -> int -> int -> circ * cinp + val new_ctuple_inp : ?name:[`Str of string | `Idn of ident] -> ctype list -> circ * cinp + + (* Construct an input *) + val input_of_ctype : ?name:[`Str of string | `Idn of ident | `Bad] -> ctype -> circuit + + (* Aggregation functions *) + val circuit_aggregate : circuit list -> circuit + val circuit_aggregate_inputs : circuit -> circuit + + (* Circuits representing booleans *) + val circuit_true : circuit + val circuit_false : circuit + val circuit_and : circuit -> circuit -> circuit + val circuit_or : circuit -> circuit -> circuit + val circuit_not : circuit -> circuit + + (* <=> circuit has not inputs (every input is unbound) *) + val circuit_is_free : circuit -> bool + + (* Direct circuuit constructions *) + val circuit_ite : c:circuit -> t:circuit -> f:circuit -> circuit + val circuit_eq : circuit -> circuit -> circuit + val circuit_eqs : circuit -> circuit -> circuit list + + + (* Circuit tuples *) + val circuit_tuple_proj : circuit -> int -> circuit + val circuit_tuple_of_circuits : circuit list -> circuit + val circuits_of_circuit_tuple : circuit -> circuit list + + (* Avoid nodes for uninitialized inputs *) + val circuit_uninit : ctype -> circuit + val circuit_has_uninitialized : circuit -> int option + + (* Logical reasoning over circuits *) + val circ_equiv : ?pcond:circuit -> circuit -> circuit -> bool + val circ_sat : circuit -> bool + val circ_taut : circuit -> bool + + (* Composition of circuit functions, should deal with inputs and call some backend *) + val circuit_compose : circuit -> circuit list -> circuit + + (* Computing the function given by a circuit *) + val compute : sign:bool -> circuit -> arg list -> zint option + + (* Mapreduce/Dependecy analysis related functions *) + val is_decomposable : int -> int -> circuit -> bool + val decompose : int -> int -> circuit -> circuit list + val permute : int -> (int -> int) -> circuit -> circuit + val align_inputs : circuit -> (int * int) option list -> circuit + val circuit_slice : size:int -> circuit -> int -> circuit + val circuit_slice_insert : circuit -> int -> circuit -> circuit + val fillet_circuit : circuit -> circuit list + val fillet_tauts : ?mode:[`Seq | `Quad] -> circuit list -> circuit list -> bool + val batch_checks : ?sort:bool -> ?mode:[`ByEq | `BySub ] -> circuit list -> circuit list + + (* Wraps the backend call to deal with args/inputs *) + val circuit_to_file : name:string -> circuit -> symbol + + val circuit_from_spec : ?name:symbol -> (ctype list * ctype) -> Lospecs.Ast.adef -> circuit +end + +module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface = struct + (* Module Types *) + type flatcirc = Backend.reg + type width = int + type count = int + type ctype = + CArray of {width: int; count: int; } + | CBitstring of width + | CTuple of ctype list + | CBool + type circ = { + reg: flatcirc; + type_: ctype; +} + type cinp = { + type_ : ctype; + id : int; + } + type 'a cfun = 'a * (cinp list) + type circuit = circ cfun + + (* Helper functions *) + let (|->) ((a,b)) ((f,g)) = (f a, g b) + let idnt = fun x -> x + + let pp_flatcirc fmt fc = + let r = Backend.node_list_of_reg fc in + List.iter (fun n -> + Format.fprintf fmt "%a@." Backend.pp_node n + ) r + + let circ_of_zint ~(size: int) (i: zint) : circ = + {reg = Backend.reg_of_zint ~size i; type_= CBitstring size } + + let circuit_of_zint ~(size: int) (i: zint) : circuit = + ((circ_of_zint ~size i, []) :> circuit) + + let rec size_of_ctype (t: ctype) : int = + match t with + | CBitstring n -> n + | CArray {width; count} -> width * count + | CTuple tys -> List.sum (List.map size_of_ctype tys) + | CBool -> 1 + + (* Pretty printers *) + let rec pp_ctype (fmt: Format.formatter) (t: ctype) : unit = + match t with + | CArray {width; count} -> Format.fprintf fmt "Array(%d@%d)" count width + | CBool -> Format.fprintf fmt "Bool" + | CTuple szs -> Format.fprintf fmt "Tuple(%a)" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.pp_print_string fmt ", ") pp_ctype) szs + | CBitstring w -> Format.fprintf fmt "Bitstring@%d" w + + let pp_cinp (fmt: Format.formatter) (inp: cinp) : unit = + Format.fprintf fmt "Input(id: %d, type = %a)" inp.id pp_ctype inp.type_ + + let pp_circ (fmt : Format.formatter) (c: circ) : unit = + Format.fprintf fmt "Circ(%a)" pp_ctype c.type_ + + let pp_circuit (fmt: Format.formatter) ((c, inps) : circuit) : unit = + Format.fprintf fmt "@[Circuit:@\nOut type %a@\nInputs: %a@]" + pp_circ c + (fun fmt inps -> List.iter (fun inp -> Format.fprintf fmt "%a@\n" pp_cinp inp) inps) inps + + (* arg for circuit construction *) + module CArgs = struct + type arg = + [ `Circuit of circuit + | `Constant of zint + | `Init of int -> circuit + | `List of circuit list ] + let arg_of_circuit c = + `Circuit c + let arg_of_zint z = + `Constant z + let arg_of_circuits cs = + `List cs + let arg_of_init f = + `Init f + let pp_arg fmt arg : unit = + match arg with + | `Circuit c -> Format.fprintf fmt "%a" pp_circuit c + | `Constant i -> Format.fprintf fmt "Constant: %s" (to_string i) + | `Init f -> Format.fprintf fmt "Init: Type of f(0): %a" pp_circuit (f 0) + | `List cs -> Format.fprintf fmt "@[ Circuit list: @\n%a@]" + (fun fmt cs -> List.iter (Format.fprintf fmt "%a@\n" pp_circuit) cs) cs + end + open CArgs + + module TranslationState = struct + type state = { + circs : circuit Mid.t; + lambdas : cinp list list; (* actually a stack *) + pv_ids : (ident * symbol, ident) Map.t; (* can be changed to int Msym.t if needed *) + } + + let empty_state : state = { + circs = Mid.empty; + lambdas = []; + pv_ids = Map.empty; (* can be changed to int Msym.t if needed *) + } + + let update_state_pv (st: state) (m: memory) (s: symbol) (c: circuit) : state = + match Map.find_opt (m, s) st.pv_ids with + | Some id -> {st with circs = Mid.add id c st.circs} + | None -> let id = EcIdent.create s in + { st with + pv_ids = Map.add (m, s) id st.pv_ids; + circs = Mid.add id c st.circs } + + let state_get_pv_opt (st: state) (m:memory) (s: symbol) : circuit option = + Option.bind (Map.find_opt (m, s) st.pv_ids) (fun id -> Mid.find_opt id st.circs) + + let state_get_pv : state -> memory -> symbol -> circuit = (fun st m s -> Option.get @@ state_get_pv_opt st m s) (* FIXME *) + let state_get_all_pv (st: state) : ((memory * symbol) * circuit) list = + let pvs = Map.bindings st.pv_ids in + List.filter_map (fun (pv, id) -> match Mid.find_opt id st.circs with | None -> None | Some c -> Some (pv, c)) pvs + + let state_get_all_memory (st: state) (m: memory) : (symbol * circuit) list = + List.filter_map (fun ((m_, s), c) -> if m = m_ then Some (s, c) else None) + (state_get_all_pv st) + +(* let map_state_pv : (symbol -> circuit -> circuit) -> state -> state = assert false *) + + let update_state (st: state) (id: ident) (c: circuit) : state = + { st with circs = Mid.add id c st.circs } + + let state_get_opt (st: state) (id: ident) : circuit option = Mid.find_opt id st.circs + let state_get (st: state) (id: ident) : circuit = Mid.find id st.circs + let state_bindings (st: state) : (ident * circuit) list = Mid.bindings st.circs + let state_lambdas (st: state) : cinp list = st.lambdas |> List.rev |> List.flatten + let state_is_closed (st: state) : bool = List.is_empty st.lambdas + let state_close_circuit (st: state) ((c, inps): circuit) : circuit = + c, List.fold_left (fun inps lamb -> lamb @ inps) inps st.lambdas + + let map_state_var (f: (ident -> circuit -> circuit)) (st: state) : state = + {st with circs = Mid.mapi f st.circs} + + let cinput_of_type (name: [`Idn of ident | `Str of string]) (t: ctype) : cinp * circuit = + let name = match name with + | `Idn id -> id + | `Str s -> EcIdent.create s + in + { id = name.id_tag; type_ = t}, + ({ reg = Backend.input_of_size ~id:name.id_tag (size_of_ctype t); type_ = t}, []) + + (* Circuit lambdas, for managing inputs *) + let open_circ_lambda (st: state) (bnds: (ident * ctype) list) : state = + let inps, cs = List.map (fun (id, t) -> + if debug then Format.eprintf "Opening circuit lambda for ident: (%s, %d)@." (name id) (tag id); + let inp, c = cinput_of_type (`Idn id) t + in inp, (id, c)) bnds |> List.split in + {st with + lambdas = inps::st.lambdas; + circs = List.fold_left (fun circs (id, c) -> Mid.add id c circs) st.circs cs } + + let open_circ_lambda_pv (st: state) (bnds : ((memory * symbol) * ctype) list) : state = + let st, bnds = List.fold_left_map (fun st ((m, s), t) -> + match Map.find_opt (m, s) st.pv_ids with + | Some id -> st, (id, t) + | None -> let id = EcIdent.create s in + { st with pv_ids = Map.add (m, s) id st.pv_ids}, (id, t)) st bnds + in open_circ_lambda st bnds + + (* FIXME: should we remove id from the mapping? *) + let close_circ_lambda (st: state) : state = + match st.lambdas with + | [] -> raise (CircError "no lambda to close in state") + | inps::lambdas -> + {st with lambdas = lambdas; + circs = Mid.map (fun (c, cinps) -> (c, inps @ cinps)) st.circs } + + let circ_lambda_oneshot (st: state) (bnds : (ident * ctype) list) (c: state -> circuit) : circuit = + let st' = open_circ_lambda st bnds in + let (c, inps) = c st' in + (c, (List.hd st'.lambdas) @ inps) + end + + (* Inputs helper functions *) + let merge_inputs (cs: cinp list) (ds: cinp list) : cinp list = +(* + Format.eprintf "Comparing input lists: @."; + List.iter (Format.eprintf "%a " pp_cinp) cs; + Format.eprintf "@."; + List.iter (Format.eprintf "%a " pp_cinp) ds; + Format.eprintf "@."; +*) + if List.for_all2 (fun {id=id1; type_=ct1} {id=id2; type_=ct2} -> id1 = id2 && ct1 = ct2) cs ds then cs + else cs @ ds + + let merge_inputs_list (cs: cinp list list) : cinp list = + List.fold_right (merge_inputs) cs [] + + let merge_circuit_inputs (c: circuit) (d: circuit) : cinp list = + merge_inputs (snd c) (snd d) + + let unify_inputs_renamer (target: cinp list) (inps: cinp list) : Backend.inp -> Backend.node option = + let map = List.fold_left2 (fun map inp1 inp2 -> match inp1, inp2 with + | {type_ = CBitstring w ; id=id_tgt}, + {type_ = CBitstring w'; id=id_orig} when w = w' -> + List.fold_left (fun map i -> Map.add (id_orig, i) (Backend.input_node ~id:id_tgt i) map) + map (List.init w (fun i -> i)) + | {type_ = CArray {width=w; count=n}; id=id_tgt}, + {type_ = CArray {width=w'; count=n'}; id=id_orig} when w = w' && n = n' -> + List.fold_left (fun map i -> Map.add (id_orig, i) (Backend.input_node ~id:id_tgt i) map) + map (List.init (w*n) (fun i -> i)) + | {type_ = CTuple tys ; id=id_tgt}, + {type_ = CTuple tys'; id=id_orig} when List.for_all2 (=) tys tys' -> + let w = List.sum (List.map size_of_ctype tys) in + List.fold_left (fun map i -> Map.add (id_orig, i) (Backend.input_node ~id:id_tgt i) map) + map (List.init (w) (fun i -> i)) + | {type_ = CBool; id=id_tgt}, + {type_ = CBool; id=id_orig} -> + Map.add (id_orig, 0) (Backend.input_node ~id:id_tgt 0) map + | _ -> raise (CircError (Format.asprintf "Mismatched inputs:@\nInp1: %a@\nInp2: %a@\n" pp_cinp inp1 pp_cinp inp2)) + ) Map.empty target inps in + fun inp -> Map.find_opt inp map + + (* Renames circuit2 inputs to match circuit 1 *) + let unify_inputs (target: cinp list) ((c, inps): circuit) : circ = + let map_ = unify_inputs_renamer target inps in + {c with reg = Backend.applys map_ c.reg} + + let circuit_input_compatible ?(strict = false) ((c, _): circuit) (cinp: cinp) : bool = + match c.type_, cinp with + | CBitstring n, { type_ = CBitstring n' } when n = n' -> true + | CArray {width=w; count=n}, { type_ = CArray {width=w'; count=n'}} when w = w' && n = n' -> true + | CTuple (szs), { type_ = CTuple szs' } when List.all2 (=) szs szs' -> true + | CBool, { type_ = CBool } -> true + | CBool, { type_ = CBitstring 1 } when not strict -> true + | CBitstring 1, { type_ = CBool } when not strict -> true + | _ -> false + + (* Circuit tuples *) + let circuit_tuple_proj (({reg = r; type_= CTuple tys}, inps): circuit) (i: int) = + let idx, ty = List.takedrop i tys in + let ty = List.hd ty in + let idx = List.fold_left (+) 0 (List.map size_of_ctype idx) in + {reg = (Backend.slice r idx (size_of_ctype ty)); type_ = ty}, inps + + let circuit_tuple_of_circuits (cs: circuit list) : circuit = + let tys = (List.map (fun (c : circuit) -> (fst c).type_) cs) in + let circ = Backend.flatten (List.map (fun (c : circuit) -> (fst c).reg) cs) in + let inps = List.snd cs in + {reg = circ; type_= CTuple tys}, merge_inputs_list inps + + let circuits_of_circuit_tuple (({reg= tp; type_=CTuple szs}, tpinps) : circuit) : circuit list = + snd @@ List.fold_left_map + (fun idx ty -> + let sz = (size_of_ctype ty) in + (idx + sz, + ({reg = (Backend.slice tp idx sz); type_ = ty}, tpinps))) + 0 szs + + (* Convert a circuit's output to a given circuit type *) + let convert_type (t: ctype) (({type_;_} as c, inps) as circ: circuit) : circuit = + match t, type_ with + (* When types are the same, do nothing *) + | (CArray {width=w; count=n}, CArray {width=w'; count=n'}) when w = w' && n = n' -> circ + | (CBitstring n, CBitstring n') when n = n' -> circ + | (CTuple tys, CTuple tys') when List.for_all2 (=) tys tys' -> circ + | (CBool, CBool) -> circ + + (* Bistring => Type conversions *) + | (CArray {width=w; count=n}, CBitstring n') when w * n = n' -> { c with type_ = t }, inps + | (CTuple tys, CBitstring n) when List.sum @@ List.map size_of_ctype tys = n -> { c with type_ = t}, inps + | (CBool, CBitstring 1) -> { c with type_ = t}, inps + + (* Type => Bitstring conversions *) + | (CBitstring n, CArray {width=w'; count=n'}) when n = w' * n' -> { c with type_ = t}, inps + | (CBitstring n, CTuple tys') when n = List.sum @@ List.map size_of_ctype tys' -> { c with type_ = t}, inps + | (CBitstring 1, CBool) -> {c with type_ = t}, inps + + (* Fail on everything else *) + | _ -> + raise (CircError (Format.asprintf "Failed to convert circuit %a of type %a to type %a@." + pp_circ c pp_ctype type_ pp_ctype t)) + + let can_convert_input_type (t1: ctype) (t2: ctype) : bool = + size_of_ctype t1 = size_of_ctype t2 + + let convert_input_types ((c, inps) : circuit) (tys: ctype list) : circuit = + let exception IncompatibleTypes in + c, List.map2 (fun inp ty -> + if can_convert_input_type inp.type_ ty then + { inp with type_ = ty } + else raise IncompatibleTypes + ) inps tys + + (* Input Helper Functions *) + (* FIXME: maybe change name from inp -> input? *) + let new_cbool_inp ?(name = `Str "input") () : circ * cinp = + let id, inp = match name with + | `Str name -> let id = EcIdent.create name |> tag in + id, Backend.input_node ~id 0 + | `Idn idn -> let id = tag idn in + id, Backend.input_node ~id 0 + | `Bad -> + -1, Backend.bad + in + { reg = Backend.reg_of_node inp; type_= CBool }, { type_ = CBool; id; } + + let new_cbitstring_inp ?(name = `Str "input") (sz: int) : circ * cinp = + let id, r = match name with + | `Str name -> let id = EcIdent.create name |> tag in + id, Backend.input_of_size ~id sz + | `Idn idn -> let id = tag idn in + id, Backend.input_of_size ~id sz + | `Bad -> + -1, Backend.bad_reg sz + in + { reg = r; type_ = CBitstring sz}, + { type_ = CBitstring sz; id; } + + (* TODO: maybe remove? *) + let new_cbitstring_inp_reg ?name (sz: int) : flatcirc * cinp = + let c, inp = new_cbitstring_inp ?name sz in + (c.reg, inp) + + let new_carray_inp ?(name = `Str "input") (el_sz: int) (arr_sz: int) : circ * cinp = + let id, arr = match name with + | `Str name -> let id = EcIdent.create name |> tag in + id, Backend.input_of_size ~id (el_sz * arr_sz) + | `Idn idn -> let id = tag idn in + id, Backend.input_of_size ~id (el_sz * arr_sz) + | `Bad -> + -1, Backend.bad_reg (el_sz * arr_sz) + in + { reg = arr; type_ = CArray {width=el_sz; count=arr_sz}}, + { type_ = CArray {width=el_sz; count=arr_sz}; id; } + + let new_ctuple_inp ?(name = `Str "input") (tys: ctype list) : circ * cinp = + let id, tp = match name with + | `Str name -> let id = EcIdent.create name |> tag in + id, Backend.input_of_size ~id (List.sum @@ List.map size_of_ctype tys) + | `Idn idn -> let id = tag idn in + id, Backend.input_of_size ~id (List.sum @@ List.map size_of_ctype tys) + | `Bad -> + -1, Backend.bad_reg (List.sum @@ List.map size_of_ctype tys) + in + { reg = tp; type_ = CTuple tys}, + { type_ = CTuple tys; id; } + + let input_of_ctype ?(name : [`Str of string | `Idn of ident | `Bad ] = `Str "input") (ct: ctype) : circuit = + let id, c = match name with + | `Str name -> let id = EcIdent.create name |> tag in + id, Backend.input_of_size ~id (size_of_ctype ct) + | `Idn idn -> let id = idn.id_tag in + id, Backend.input_of_size ~id (size_of_ctype ct) + | `Bad -> + -1, Backend.bad_reg (size_of_ctype ct) + in + { reg = c; type_ = ct; }, [{ id; type_ = ct; }] + + let circuit_true = {reg = Backend.reg_of_node Backend.true_; type_=CBool}, [] + let circuit_false = {reg = Backend.reg_of_node Backend.false_; type_=CBool}, [] + + let circuit_and ((c, cinps): circuit) ((d, dinps): circuit) = + if c.type_ = d.type_ then + { reg = Backend.land_ c.reg d.reg; type_ = c.type_ }, merge_inputs cinps dinps + else + raise (CircError "Cannot logical and circuits of different types ") + + let circuit_or ((c, cinps): circuit) ((d, dinps): circuit) = + if c.type_ = d.type_ then + { reg = Backend.lor_ c.reg d.reg; type_ = c.type_ }, merge_inputs cinps dinps + else + raise (CircError "Cannot logical or circuits of different types ") + + let circuit_not ((c, cinps): circuit) = + {c with reg = Backend.lnot_ c.reg}, cinps + + let circuit_is_free (f: circuit) : bool = List.is_empty @@ snd f + + let circuit_ite ~(c: circuit) ~(t: circuit) ~(f: circuit) : circuit = + assert ((circuit_is_free t) && (circuit_is_free f) && (circuit_is_free c)); + let c = match (fst c).type_ with + | CBool -> Backend.node_of_reg (fst c).reg + | _ -> assert false + in + let res_r = Backend.reg_ite c (fst t).reg (fst f).reg in + match ((fst t).type_, (fst f).type_) with + | CBitstring nt, CBitstring nf when nt = nf -> {reg = res_r; type_ = (fst t).type_}, [] + | CArray {width=wt; count=nt}, CArray {width=wf; count=nf} when wt = wf && nt = nf -> {reg = res_r; type_ = (fst t).type_}, [] + | CTuple szs_t, CTuple szs_f when List.all2 (=) szs_t szs_f -> {reg = res_r; type_ = (fst t).type_}, [] + | CBool, CBool -> {reg = res_r; type_ = (fst t).type_}, [] + | _ -> raise (CircError (Format.asprintf "Invalid arguments for circuit_ite (%a)" Format.(pp_print_list ~pp_sep:(fun fmt () -> Format.pp_print_string fmt ", ") pp_ctype) (List.map (fun (c: circuit) -> (fst c).type_) [t; f]))) + + (* TODO: type check? *) + let circuit_eq (c: circuit) (d: circuit) : circuit = + match (fst c).type_, (fst d).type_ with + | (CArray _), (CArray _) + | (CTuple _), (CTuple _) + | (CBitstring _), (CBitstring _) -> + {reg = (Backend.reg_eq (fst c).reg (fst d).reg |> Backend.reg_of_node); type_ = CBool}, merge_inputs (snd c) (snd d) + | CBool, CBool -> + {reg = (Backend.reg_eq (fst c).reg (fst d).reg |> Backend.reg_of_node); type_ = CBool}, merge_inputs (snd c) (snd d) + | CBool, CBitstring 1 -> + {reg = (Backend.reg_eq (fst c).reg (fst d).reg |> Backend.reg_of_node); type_ = CBool}, merge_inputs (snd c) (snd d) + | CBitstring 1, CBool -> + {reg = (Backend.reg_eq (fst c).reg (fst d).reg |> Backend.reg_of_node); type_ = CBool}, merge_inputs (snd c) (snd d) + | _ -> raise (CircError (Format.asprintf "Invalid arguments for circuit_eq (%a)" Format.(pp_print_list ~pp_sep:(fun fmt () -> Format.pp_print_string fmt ", ") pp_ctype) (List.map (fun (c : circuit) -> (fst c).type_) [c; d]))) + + (* Ignore types, do extensionally over bits, return the circuits evaluating to the condition *) + let circuit_eqs ((c, cinps): circuit) ((d, dinps): circuit) : circuit list = + let inps = merge_inputs cinps dinps in + assert (size_of_ctype c.type_ = size_of_ctype d.type_); + let cs = Backend.node_list_of_reg c.reg in + let ds = Backend.node_list_of_reg d.reg in + List.map2 (fun c d -> + let r = Backend.node_eq c d |> Backend.reg_of_node in + {reg = r; type_ = CBool}, inps) cs ds + + + let circuit_compose (c: circuit) (args: circuit list) : circuit = + (let exception InputIncompatible in + try + if not (List.for_all2 (fun c cinp -> circuit_input_compatible c cinp) args (snd c)) then raise InputIncompatible; + with + InputIncompatible -> + if debug then Format.eprintf "Error on application:@\nTarget:%a@\n@[Args:%a@]@\n" + pp_circuit c + (fun fmt cs -> List.iter (Format.fprintf fmt "%a@\n" pp_circuit) cs) args; + raise (CircError "Failed to compose circuits") + | Invalid_argument _ -> raise (CircError (Format.asprintf "Bad number of arguments to circuit (expected: %d, got: %d)" (List.length (snd c)) (List.length args)))); + let map = List.fold_left2 (fun map {id} c -> Map.add id c map) Map.empty (snd c) (List.fst args) in + let map_ (id, idx) = + let circ = Map.find_opt id map in + Option.bind circ (fun c -> + match c.type_ with + | CArray _ | CTuple _ | CBitstring _ -> + begin try + Some (Backend.get c.reg idx) + with Invalid_argument _ -> None + end + | CBool when idx = 0 -> Some (Backend.node_of_reg c.reg) + | _ -> None + ) + in + + let circ = {(fst c) with reg = Backend.applys map_ (fst c).reg} in + let inps = merge_inputs_list (List.snd args) in + (circ, inps) + + (* Circuit Lambda functions *) + + (* Functions for dealing with uninitialized inputs *) + let circuit_uninit (t: ctype) : circuit = + match t with + | CTuple szs -> + let ctp, cinp = new_ctuple_inp ~name:`Bad szs in + ((ctp, []) :> circuit) + | CArray {width=el_sz; count=arr_sz} -> + let carr, cinp = new_carray_inp ~name:`Bad el_sz arr_sz in + ((carr, []) :> circuit) + | CBitstring sz -> + let c, cinp = new_cbitstring_inp ~name:`Bad sz in + ((c, []) :> circuit) + | CBool -> + let c, inp = new_cbool_inp ~name:`Bad () in + ((c, []) :> circuit) + + let circuit_has_uninitialized (c: circuit) : int option = + Backend.have_bad (fst c).reg + + let circ_equiv ?(pcond:circuit option) ((c1, inps1): circuit) ((c2, inps2): circuit) : bool = +(* let () = if debug then Format.eprintf "c1: %a@\nc2: %a@\n" pp_circuit (c1, inps1) pp_circuit (c2, inps2) in *) + let pcond = Option.map (convert_type CBool) pcond in (* Try to convert to bool *) (* FIXME: duplicated check *) + let pcc = match pcond with + | Some ({reg = b; type_ = CBool}, pcinps) -> + Backend.apply (unify_inputs_renamer inps1 pcinps) (Backend.node_of_reg b) + | None -> Backend.true_ + | _ -> raise (CircError "non bool input for circuit equiv precondition") + in + (* TODO: add code to check that inputs match *) + let c2 = unify_inputs inps1 (c2, inps2) in + let inps = List.map (function + | { type_ = CBool; id } -> (id, 1) + | { type_ = CBitstring w; id } -> (id, w) + | { type_ = CArray {width=w1; count=w2}; id } -> (id, w1*w2) + | { type_ = CTuple tys; id } -> (id, List.sum @@ List.map size_of_ctype tys) + + ) inps1 in + if (c1.type_ = c2.type_) then + Backend.equiv ~inps ~pcond:pcc c1.reg c2.reg + else false + + let circ_sat ((c, inps): circuit) : bool = + if debug then Format.eprintf "Calling circ_sat on circuit: %a@." pp_circuit (c, inps); + let c = match c with + | {type_ = CBool; reg} -> Backend.node_of_reg reg + | _ -> raise (CircError "Cannot apply circ_sat on a non bool circuit") + in + let inps = List.map (function + | { type_ = CBool; id } -> (id, 1) + | { type_ = CBitstring w; id } -> (id, w) + | { type_ = CArray {width=w1; count=w2}; id } -> (id, w1*w2) + | { type_ = CTuple tys; id } -> (id, List.sum @@ List.map size_of_ctype tys) + + ) inps in + Backend.sat ~inps c + + let circ_taut ((c, inps): circuit) : bool = + if debug then Format.eprintf "Calling circ_taut on circuit: %a@." pp_circuit (c, inps); + let c = match c with + | {type_ = CBool; reg} -> Backend.node_of_reg reg + | _ -> raise (CircError "Cannot apply circ_sat on a non bool circuit") + in + let inps = List.map (function + | { type_ = CBool; id } -> (id, 1) + | { type_ = CBitstring w; id } -> (id, w) + | { type_ = CArray {width=w1; count=w2}; id } -> (id, w1*w2) + | { type_ = CTuple tys; id } -> (id, List.sum @@ List.map size_of_ctype tys) + + ) inps in + Backend.taut ~inps c + + (* Dependency analysis related functions. These assume one input/output and all bitstring types *) + (* For more complex circuits, we might be able to simulate this with a int -> (int, int) map *) + let is_decomposable (in_w: width) (out_w: width) ((r, inps) as c: circuit) : bool = + match r, inps with + | {type_= CBitstring w; reg = r}, {type_=CBitstring w'} :: [] when (w mod out_w = 0) -> + let deps = Backend.Deps.deps_of_reg r in + Backend.Deps.is_splittable in_w out_w deps && + let base, top = Backend.Deps.dep_range deps in + let () = if debug then Format.eprintf "Passed backend check, checking width of deps (top - base = %d | in_w = %d)@." (top - base) in_w in + (top - base) mod in_w = 0 + | _ -> + if debug then Format.eprintf "Failed decomposition type check@\n"; + if debug then Format.eprintf "In_w: %d | Out_w : %d | Circ: %a" in_w out_w pp_circuit c; + false + + (* TODO: Extend this for multiple inputs? *) + let align_renamer ((r, inps) : circuit) : (int * int) * cinp * (Backend.inp -> Backend.inp option) = + begin match r.type_ with + | CBitstring _ -> () + | _ -> assert false (* TODO: FIXME *) + end; + match inps with + | [{type_ = CBitstring w; id}] -> + let d = Backend.Deps.deps_of_reg r.reg in + assert (Backend.Deps.single_dep d); + let (start_idx, end_idx) as range = Backend.Deps.dep_range d in + range, + {type_ = CBitstring (end_idx - start_idx); id}, + (fun (id_, w) -> + if id <> id_ then None else + if w < start_idx || w >= end_idx then None + else Some (id_, w - start_idx)) + | _ -> raise (CircError "Failed to rename inputs in align_renamer") + + let align_inputs ((c, inps): circuit) (slcs: (int * int) option list) : circuit = + assert (List.compare_lengths inps slcs = 0); + let alignment = List.combine slcs inps in + let inps = List.map + (function + | None, inp -> inp + | Some (sz, offset), ({type_ = CBitstring w_} as inp) -> + {inp with type_ = CBitstring sz} + | Some (sz, offset), ({type_ = CArray {width=w; count=n}} as inp) -> + assert (sz mod w = 0); + {inp with type_ = CArray {width=w; count=sz / w}} + | _ -> raise (CircError "Failed to align inputs") + ) alignment + in + let aligners = + List.map + (function + | None, _ -> fun (id_, w) -> None + | Some (sz, offset), {id} -> + (fun (id_, w) -> + if debug then Format.eprintf "Aligning id=%d w=%d offset=%d sz=%d@." id_ w offset sz; + if id <> id_ then None else + if w < offset || w >= offset + sz then Some Backend.bad + else Some (Backend.input_node ~id (w - offset))) + ) alignment + in + let aligner = List.fold_left (fun f1 f2 -> + fun inp -> match f1 inp with + | Some _ as res -> res + | None -> f2 inp + ) (fun (id_, w) -> None) aligners + in + {c with reg = Backend.applys aligner c.reg}, inps + + (* Inputs mean different things depending on circuit type *) + (* FIXME PR: maybe differentiate the two functions ? *) + let circuit_slice ~(size:int) ((c, inps): circuit) (offset: int) : circuit = + assert (size >= 0); + assert (offset >= 0); + match c.type_ with + | CArray {width=w; count=n} when size mod w = 0 && offset mod w = 0 -> {reg = Backend.slice c.reg offset size; type_ = CArray {width=w; count=size}}, inps + | CArray _ -> raise (CircError "Bad array slice") + | CBitstring w -> + { reg = (Backend.slice c.reg offset size); type_ = CBitstring size}, inps + | CTuple tys -> + assert (List.length tys >= offset + size); + let offset, tys = List.takedrop offset tys in + let offset = List.sum @@ List.map size_of_ctype offset in + let tys = (List.take size tys) in + let sz = List.sum @@ List.map size_of_ctype tys in + {reg = (Backend.slice c.reg offset sz); type_ = CTuple tys}, inps + | CBool -> + raise (CircError "Cannot slice boolean circuit") + + (* Does not type check *) + let circuit_slice_insert ((orig_c, orig_inps): circuit) (idx: int) ((new_c, new_inps): circuit) : circuit = + { orig_c with reg = (Backend.insert orig_c.reg idx new_c.reg)}, merge_inputs orig_inps new_inps + + let split_renamer (n: count) (in_w: width) (inp: cinp) : (cinp array) * (Backend.inp -> Backend.node option) = + match inp with + | {type_ = CBitstring w; id} when w mod in_w = 0 -> + let ids = Array.init n (fun i -> create ("split_" ^ (string_of_int i)) |> tag) in + Array.map (fun id -> {type_ = CBitstring in_w; id}) ids, + (fun (id_, w) -> + if id <> id_ then None else + let id_idx, bit_idx = (w / in_w), (w mod in_w) in + Some (Backend.input_node ~id:ids.(id_idx) bit_idx)) + | {type_ = CBitstring w; id} -> + if debug then Format.eprintf "Failed to build split renamer for n=%d in_w=%d w=%d@." n in_w w; + raise (CircError "Failed to rename during split") + | _ -> raise (CircError "Failed to rename during split") + + let check_decomp_inputs ((c, inps): circuit) : bool = + begin match c.type_ with + | CBitstring _ -> () + | _ -> assert false (* TODO: FIXME *) + end; + let inps = List.map (function + | {type_ = CBitstring w; id} -> + (id, w) + | _ -> raise (CircError "Cannot apply mapreduce with more than one input") + ) inps in + Backend.Deps.check_inputs c.reg inps + + + (* + Takes a circuit and uses dependency analysis to separate it into + subcircuits corresponding to the output bits + + In particular, equivalence between two circuits is equivalent + to equivalence between the subcircuits + + Implicitly flattens everything to bitstrings + *) + let fillet_circuit ((c, inps) : circuit) : circuit list = + let r = c.reg |> Backend.node_list_of_reg in + List.map (fun n -> + let new_inps = List.map (fun {id;type_} -> + {id=EcIdent.create "_" |> tag; type_}) inps + in + let renamings = List.combine + (List.map (fun {id} -> id) inps) + (List.map (fun {id} -> id) new_inps) + |> List.to_seq |> Map.of_seq + in + let renamings = fun v -> Map.find_opt v renamings in + let n', shifts = Backend.Deps.excise_bit ~renamings n in + + let new_inps = List.filter_map (fun {id;type_} -> + match Map.find_opt id shifts with + | Some (low, hi) -> Some {id; type_ = CBitstring (hi - low + 1)} + | None -> None + ) new_inps in + { reg = Backend.reg_of_node n'; + type_ = CBool }, + new_inps + ) r + +(* + + Correct order is: + - Build two sided equality + - Dependency collapse (into lanes) + - Attach preconditions + - Realign inputs + - Structural equality check + - SMT check +*) + + +(* + let fillet_circuit ((c, inps) : circuit) : circuit list = + let rec collapse (acc: Backend.node list) (cur, d: Backend.node * Backend.Deps.dep) (cs: (Backend.node * Backend.Deps.dep) list) : Backend.node list = + match cs with + | [] -> (cur::acc) + | (c, d')::cs -> + if debug && false then Format.eprintf "Comparing deps:@.%a@.To deps:@.%a@." + Backend.Deps.pp_dep d + Backend.Deps.pp_dep d'; + if d = d' then + collapse acc ((Backend.band cur c), d) cs + else + collapse (cur::acc) (c, d') cs + in + + + let r = c.reg |> Backend.node_list_of_reg in + let nbits = List.length r in +(* let r = List.take nbits r in *) + + + let r = List.map (fun n -> + n, Backend.Deps.dep_of_node n) r in + + let r = match r with + | [] -> [] + | n::ns -> collapse [] n ns + in + + Format.eprintf "%d bits after collapsing (from %d initial)@." (List.length r) nbits; + + + let r = List.map Backend.Deps.excise_bit r in + let n1, s1 = List.hd r in + List.iteri (fun i (n, s) -> + Format.eprintf "Comparing node 0 to node %d => " i; + if Backend.nodes_eq n n1 then + Format.eprintf "Structurally equal@." + else + Format.eprintf "Structurally different@." + ) r; assert false + +*) + + + (* Batches circuit checks by dependencies. Assumes equivalent checks are contiguous *) + let batch_checks ?(sort = true) ?(mode : [`ByEq | `BySub] = `ByEq) (checks: circuit list) : circuit list = + (* Order by dependencies *) + let checks = if sort then begin + + let checks = List.map (fun (c, inps) -> + (c, inps), Backend.(Deps.dep_of_node (node_of_reg c.reg))) checks in + let checks = List.stable_sort (fun (_, d1) (_, d2) -> + let m1 = (Map.keys d1 |> Set.of_enum |> Set.min_elt_opt) in + let m2 = (Map.keys d2 |> Set.of_enum |> Set.min_elt_opt) in + (* FIXME: Check this *) + match m1, m2 with + | None, None -> 0 + | None, Some _ -> -1 + | Some _, None -> 1 + | Some m1, Some m2 -> + let c1 = Int.compare m1 m2 in + if c1 = 0 then (* FIXME: check default value V V *) + Int.compare (Map.find m1 d1 |> Set.min_elt_opt |> Option.default (-1)) (Map.find m1 d2 |> Set.min_elt_opt |> Option.default (-1)) + else + c1 + ) checks in + checks + end else + List.map (fun c -> + c, Backend.(Deps.dep_of_node (node_of_reg (fst c).reg))) checks + in + + let rec doit (acc: circuit list) (cur, d: circuit * Backend.Deps.dep) (cs: (circuit * Backend.Deps.dep) list) : circuit list = + match cs with + | [] -> (cur::acc) + | (c, d')::cs -> + if debug && false then Format.eprintf "Comparing deps:@.%a@.To deps:@.%a@." + Backend.Deps.pp_dep d + Backend.Deps.pp_dep d'; + begin match mode with + | `ByEq when Backend.Deps.deps_equal d d' -> + doit acc ((circuit_and cur c), d) cs + | `BySub when Backend.Deps.(dep_contained d d') -> + doit acc ((circuit_and cur c), d') cs + | `BySub when Backend.Deps.(dep_contained d' d) -> + doit acc ((circuit_and cur c), d) cs + | _ -> + Format.eprintf "Consolidated lane deps: %a@." Backend.Deps.pp_dep d; + doit (cur::acc) (c, d') cs + end + in + + match checks with + | [] -> [] + | c::cs -> doit [] c cs + + + + (* Assumes all the pre and post have been split, takes all the pres and one post *) + let fillet_taut (pres: (circuit * Backend.Deps.dep) list) ((post_circ, post_inps): circuit) : bool = + assert (List.for_all (fun ((_c, inps), _) -> inps = post_inps) pres); + assert (List.for_all (fun (({type_;reg}, _), _) -> type_ = CBool) pres); + assert (post_circ.type_ = CBool); + let d = Backend.(Deps.dep_of_node (node_of_reg post_circ.reg)) in + let compat_pres = List.filteri (fun i (c, pre_dep) -> + Backend.Deps.dep_contained pre_dep d + ) pres in + let compat_pres = List.fst compat_pres in + let node_post = Backend.node_of_reg post_circ.reg in + let nodes_pre = List.map (fun (c, _) -> Backend.node_of_reg c.reg) compat_pres in + let node_post, shifts = Backend.Deps.excise_bit node_post in + let inps = List.filter_map (fun {id; type_} -> + match Map.find_opt id shifts with + | Some (low, hi) -> Some {id; type_ = CBitstring (hi - low + 1)} + | None -> None + ) post_inps in + let inp_map = fun (id, v) -> + match Map.find_opt id shifts with + | Some (min, max) -> + let new_id = v - min in + assert (new_id <= max); + Some (id, v - min) + | None -> assert false + in + let nodes_pre = Backend.Deps.rename_inputs inp_map (Backend.reg_of_node_list nodes_pre) in + let pre = List.fold_left Backend.band Backend.true_ (Backend.node_list_of_reg nodes_pre) |> Backend.reg_of_node in + let pre = {reg = pre; type_ = CBool}, inps in + let post = Backend.reg_of_node node_post in + let post = {reg = post; type_ = CBool}, inps in + let cond = circuit_or (circuit_not pre) post in + circ_taut cond + + + let collapse_lanes (lanes: circuit list) = + (* Circuit structural equality after renaming *) + let (===) (c1: circ) (c2: circ) : bool = + let n', _ = Backend.node_of_reg c1.reg |> Backend.Deps.excise_bit in + let n, _ = Backend.node_of_reg c2.reg |> Backend.Deps.excise_bit in + Backend.nodes_eq n n' + in + let rec collapse (acc: circuit list) (cur: circuit) (cs: circuit list) : circuit list = + match cs with + | [] -> cur::acc + | c::cs -> + if (fst c) === (fst cur) then + collapse acc cur cs + else + collapse (cur::acc) c cs + in + (* FIXME: optimize later *) + let rec doit (cs: circuit list) : circuit list = + match cs with + | [] -> [] + | c::[] -> c::[] + | c::cs -> begin try + let idx, _ = List.findi (fun _ c2 -> (fst c) === (fst c2)) cs in + let idx = idx + 1 in (* Length of the list to merge *) + if idx = 1 then + doit (collapse [] c cs) + else + if (List.length (cs) + 1) mod idx != 0 then + (Format.eprintf "Cannot correctly infer lanes, defaulting to bruteforce checking@."; + (c::cs)) + else + let cs = List.chunkify idx (c::cs) |> List.map (List.reduce circuit_and) in + doit cs + with Not_found -> + c::cs + end + in + doit lanes + + (* + - Attaches preconditions to postconditions + - Realigns inputs + - Checks for structural equality of circuits + - SMT check for any remainings ones + *) + let fillet_tauts ?(mode: [`Seq | `Quad] = `Seq) (pres: circuit list) (posts: circuit list) : bool = + (* Remove structurally equal circuits *) + (* FIXME: not working because you have to 0 align everything before *) + (* Assumes everything is single bit outputs *) +(* + let rec collapse (acc: circuit list) (cur: circuit) (cs: circuit list) : circuit list = + match cs with + | [] -> cur::acc + | (({reg;_}, _) as c)::cs -> + let n', _ = Backend.node_of_reg reg |> Backend.Deps.excise_bit in + let n, _ = Backend.node_of_reg (fst cur).reg |> Backend.Deps.excise_bit in + if Backend.nodes_eq n n' then + collapse acc cur cs + else + collapse (cur::acc) c cs + in +*) + + let posts = List.filter_map (fun ((postc, _) as post) -> + if Backend.nodes_eq (Backend.node_of_reg postc.reg) Backend.true_ then None + else Some post + ) posts in + + (* FIXME: V Testing V *) +(* + Format.eprintf "Running new collapse lanes@."; + let tm = Unix.gettimeofday () in + let new_posts = collapse_lanes posts in + Format.eprintf "Done with new collapse, took %fs and collapsed down to %d lanes@." (Unix.gettimeofday () -. tm) (List.length new_posts); +*) + + (* Quadratic check FIXME *) +(* + List.iteri (fun i post -> + let n', _ = Backend.node_of_reg (List.hd posts |> fst).reg |> Backend.Deps.excise_bit in + let n, _ = Backend.node_of_reg (fst post).reg |> Backend.Deps.excise_bit in + Format.eprintf "Checking structural equality for node 0 and node %d -> %s@." i + (if Backend.nodes_eq n n' then "EQ" else "Not EQ") + ) posts; +*) + (* FIXME: A Testing A *) + + match posts with + | [] -> true + | posts -> + assert (List.for_all (fun ({type_;reg}, _) -> type_ = CBool) pres); + assert (List.for_all (fun ({type_;reg}, _) -> type_ = CBool) posts); +(* let posts = collapse [] post posts in *) + let posts = collapse_lanes posts in + if debug then Format.eprintf "%d conditions to check after structural equality collapse@." (List.length posts); + let pres = List.map (fun ((c, _) as circ) -> circ, + Backend.Deps.dep_of_node (Backend.node_of_reg c.reg)) pres in + List.mapi (fun i post -> + if debug then Format.eprintf "Checking equivalence for bit %d@." i; (* FIXME *) + let res = fillet_taut pres post in + if not res then Format.eprintf "Failed for bit %d@." i; + res) posts |> + List.for_all identity + + + (* General Mapreduce Procedure: + Assumes: + Input bits start from 0 + Input: + Circuit: One Bitstring Input => One Bitstring Output + Lane Descriptions: Output Bit List/Set/Array /\ Input Bit Set + Output: + Circuit List: One circuit for each lane, inputs renamed to be sequential + Throws: + CircuitError + | -> When lane outputs do not fully cover circuit + | -> When lane dependency description does not fit circuit + *) + let general_decompose (r: circ) (inp: cinp) (lanes: ((int list) * (int Set.t)) list): circuit list = + let exception DependencyError in + (* Check that outputs cover the circuit *) + let outputs = List.fst lanes |> List.flatten |> List.sort (Int.compare) in + assert (List.for_all2 (fun a b -> a = b) outputs (List.init (List.length outputs) (fun i -> i))); + + (* Separate one lane *) + let doit ((outputs, inputs): int list * int Set.t) : circuit = + let c = Backend.subcirc r.reg outputs in + if not @@ Backend.Deps.forall_inputs (fun id b -> + inp.id = id && + Set.mem b inputs) c then raise DependencyError; + let _, new_inp = new_cbitstring_inp (Set.cardinal inputs) in + let bit_renames = List.mapi (fun i b -> (b, i)) (Set.to_list inputs) in + let bit_renamer = Map.of_seq (List.to_seq bit_renames) in + let renamer (id, b) = + if id = inp.id then + Option.map (fun new_b -> (new_inp.id, new_b)) (Map.find_opt b bit_renamer) + else None + in {reg = (Backend.Deps.rename_inputs renamer c); type_ = CBitstring (Backend.size_of_reg c)}, [new_inp] + (* A FIXME A *) + in + try List.map doit lanes + with DependencyError -> + raise (CircError "dep_check_general_decompose") + + let decompose (in_w: width) (out_w: width) ((c, inps): circuit) : circuit list = + let n = Backend.size_of_reg c.reg in + assert (n mod out_w = 0); + let n_lanes = n / out_w in + let inp = match inps with + | ({type_ = CBitstring n'; id} as inp)::[] when n' mod in_w = 0 && n' / in_w = n_lanes -> inp + | _ -> raise (CircError "bad inputs in circ for mapreduce") + in + let lanes = List.map (fun i -> List.init out_w (fun j -> i*out_w + j), Set.of_list (List.init in_w (fun j -> i*in_w + j))) (List.init n_lanes (fun i -> i)) in + try general_decompose c inp lanes + with CircError _ -> + let d = Backend.Deps.block_deps_of_reg out_w c.reg in + Format.eprintf "Dependencies:@.%a@." Backend.Deps.pp_block_deps d; + raise (CircError "Split dependency check failed") + + + (* FIXME: different things based on input or just fix bitstrings? *) + let permute (w: width) (perm: (int -> int)) ((r, inps): circuit) : circuit = + begin match r.type_ with + | CBitstring _ -> () + | _ -> assert false (* TODO: FIXME *) + end; + assert false; {r with reg = (Backend.permute w perm r.reg)}, inps + + let compute ~(sign: bool) ((r, inps) as c: circuit) (args: arg list) : zint option = + begin match r.type_ with + | CBitstring _ -> () + | _ -> assert false (* TODO: FIXME *) + end; + + if List.compare_lengths args inps <> 0 + then raise (CircError (Format.asprintf "Bad number of arguments for compute (expected: %d, got: %d)" (List.length inps) (List.length args))); + let args = List.map2i (fun i arg inp -> + match arg, inp with + | `Circuit c, inp when circuit_input_compatible c inp -> c + | `Constant i, {type_ = CBitstring size} -> { reg = (Backend.reg_of_zint ~size i); type_ = CBitstring size}, [] + | _ -> raise (CircError (Format.asprintf "Arg mismatch at index %d, (arg: %a, inp: %a)" i pp_arg arg pp_cinp inp)) + ) args inps + in + match circuit_compose c args with + | {reg = r; type_ = CBitstring _}, [] -> + begin try + Some (if sign + then Backend.szint_of_reg r + else Backend.uzint_of_reg r) + with Backend.NonConstantCircuit -> None + end + | _, _::_ -> raise (CircError ("Non constant circuit in compute after arg application")) + | _ -> raise (CircError ("Got non-bitstring type in compute")) + + let circuit_aggregate (cs: circuit list) : circuit = + let inps = List.snd cs in + let cs = List.map (fun c -> (fst c).reg) cs in + let c = Backend.flatten cs in + let inps = merge_inputs_list inps in + {reg = c; type_ = CBitstring (Backend.size_of_reg c)}, inps + + let input_aggregate_renamer (inps: cinp list) : cinp * (Backend.inp -> Backend.node option) = + let new_id = create "aggregated" |> tag in + let (size, map) = List.fold_left (fun (size, map) inp -> + match inp with + | { type_ = CBitstring w; id} -> + (size + w, Map.add id (size, w) map) + | { type_ = CArray {width=w; count=n}; id} -> + (size + (w*n), Map.add id (size, w*n) map) + | { type_ = CTuple tys; id} -> + let w = List.sum @@ List.map size_of_ctype tys in + (size + w, Map.add id (size, w) map) + | { type_ = CBool; id} -> + (size + 1, Map.add id (size, 1) map) + ) (0, Map.empty) inps + in + {type_ = CBitstring size; id=new_id}, + fun (id, bit) -> + let base_sz = Map.find_opt id map in + Option.bind base_sz (fun (base, sz) -> + let idx = bit + base in + if bit >= 0 && bit < sz then + Some (Backend.input_node ~id:new_id idx) + else None + ) + + let circuit_aggregate_inputs ((c, inps): circuit) : circuit = + let inp, renamer = input_aggregate_renamer inps in + {c with reg = Backend.applys renamer c.reg}, [inp] + + let circuit_to_file ~(name: string) ((c, inps) as circ : circuit) : symbol = + match c, inps with + | {reg = r; type_ = CBitstring _}, {type_ = CBitstring w; id}::[] -> (* TODO: rename inputs? *) + Backend.reg_to_file ~input_count:w ~name (Backend.applys (fun (id_, i) -> if id_ = id then Some (Backend.input_node ~id:0 (i+1)) else None) r) + | _ -> raise (CircError (Format.asprintf "Unsupported circuit for output (%a), only one bitstring input one bitstring output supported@." pp_circuit circ)) + + let circuit_from_spec ?(name: symbol option) ((arg_tys, ret_ty) : (ctype list * ctype)) (spec: Lospecs.Ast.adef) : circuit = + let c = Backend.circuit_from_spec spec in + + let name = match name with + | Some name -> name ^ "_spec_input" + | None -> "spec_input" + in + + let cinps, inps = List.mapi (fun i ty -> + let id = EcIdent.create (name ^ "_" ^ (string_of_int i)) |> tag in + let size : int = size_of_ctype ty in + (Backend.input_of_size ~id size, { type_ = ty; id = id; } ) + ) arg_tys |> List.split in + let c = c cinps in + { reg = c; type_ = ret_ty}, inps (* TODO: type checking ? *) +(* { reg = c; CBitstring c, inps) |> convert_type ret_ty *) + + module BVOps = struct + let circuit_of_parametric_bvop (op: EcDecl.crb_bvoperator) (args: arg list) : circuit = + match op with + | { kind = `ASliceGet (((_, Some n), (_, Some w)), (_, Some m)) } -> + begin match args with + (* Assume type checking from EC? *) + | [ `Circuit (({type_ = CArray _}, _) as circ) ; `Constant i ] -> + begin + match (fst circ).type_ with + | CArray {width=w'; count=n'} when n' = n && w = w' -> + circuit_slice ~size:m ({reg = (fst circ).reg; type_ = CBitstring (w' * n')}, (snd circ)) (to_int i) + | CArray {width=w'; count=n'} -> + raise (CircError (Format.asprintf "Bad array size in asliceget (expected (%d, %d), got (%d, %d))" n w n' w')) + | _ -> assert false (* Does not happen, guarded by match above *) + end + | args -> + raise (CircError (Format.asprintf "Bad arguments for asliceget, expected (arr, idx), got (%a)" Format.(pp_print_list ~pp_sep:(fun fmt () -> Format.pp_print_string fmt ", ") pp_arg) args)) + end + | { kind = `ASliceSet (((_, Some n), (_, Some w)), (_, Some m)) } -> + begin match args with + | [ `Circuit (({type_ = CArray _}, _) as arr_circ) ; `Constant i ; `Circuit (({type_ = CBitstring _}, _) as bs_circ) ] -> + begin match (fst arr_circ).type_, (fst bs_circ).type_ with + | CArray {width=w'; count=n'}, CBitstring m' when n' = n && w' = w && m = m' -> + circuit_slice_insert arr_circ (to_int i) bs_circ + | ct1, ct2 -> raise (CircError (Format.asprintf "Bad sizes in asliceget (expected arr=(%d, %d) bs=(%d), got (%a, %a))" n w m pp_ctype ct1 pp_ctype ct2)) + end + | args -> + raise (CircError (Format.asprintf "Bad arguments for asliceset, expected (arr, idx, bitstring), got (%a)" Format.(pp_print_list ~pp_sep:(fun fmt () -> Format.pp_print_string fmt ", ") pp_arg) args)) + end + + (* FIXME: what do we want for out of bounds extract? Decide later *) + | { kind = `Extract ((_, Some w_in), (_, Some w_out)) } -> + begin match args with + | [ `Circuit (({type_ = CBitstring _}, _ ) as c) ; `Constant i ] -> + circuit_slice ~size:w_out c (to_int i) + | _ -> raise (CircError (Format.asprintf "Bad arguments for extract, expected (bitstring, idx), got (%a)" Format.(pp_print_list ~pp_sep:(fun fmt () -> Format.pp_print_string fmt ", ") pp_arg) args)) + end + | { kind = `Insert ((_, Some w_orig), (_, Some w_ins)) } -> + begin match args with + | [ `Circuit (({type_ = CBitstring _}, _) as orig_c) ; `Constant i ; `Circuit (({ type_=CBitstring _}, _) as new_c) ] -> + (circuit_slice_insert orig_c (to_int i) new_c :> circuit) + | _ -> raise (CircError (Format.asprintf "Bad arguments for insert, expected (orig_bs, idx, new_bs), got (%a)" Format.(pp_print_list ~pp_sep:(fun fmt () -> Format.pp_print_string fmt ", ") pp_arg) args)) + end + + | { kind = `Map ((_, Some w_i), (_, Some w_o), (_, Some n)) } -> + begin match args with + | [ `Circuit (({ type_=CBitstring _}, [{type_=CBitstring w_i'}; _]) as cf); `Circuit ({reg = arr; type_ = CArray {width=w_i''; count=n_i''}}, arr_inps) ] when (w_i' = w_i && w_i'' = w_i') && (n_i'' = n) -> + let circs, inps = List.split @@ List.map (fun c -> + match circuit_compose cf [c] with + | { type_ = CBitstring _; reg}, inps -> reg, inps + | c, _ -> raise (CircError (Format.asprintf "Bad return from map, expected bitstring, got %a" pp_circ c)) + ) + (List.init n (fun i -> {reg = (Backend.slice arr (i*w_i) w_i); type_ = CBitstring w_i}, [])) + in + (assert (List.for_all ((=) (List.hd inps)) inps)); + let inps = List.hd inps in + let circ = { reg = (Backend.flatten circs); type_ = CArray {width=w_o; count=n}} in + (circ, inps) + | args -> raise (CircError (Format.asprintf "Bad arguments for map, expected (lane_f, arr), got (%a)" Format.(pp_print_list ~pp_sep:(fun fmt () -> Format.pp_print_string fmt ", ") pp_arg) args)) + end + | { kind = `Get (_, Some w_in) } -> + begin match args with + | [ `Circuit ({reg = bs; type_ = CBitstring _}, cinps); `Constant i ] -> + {type_ = CBool; reg = Backend.reg_of_node (Backend.get bs (to_int i))}, cinps + | _ -> raise (CircError (Format.asprintf "Bad arguments for get, expected (bs, idx), got (%a)" Format.(pp_print_list ~pp_sep:(fun fmt () -> Format.pp_print_string fmt ", ") pp_arg) args)) + end + | { kind = `AInit ((_, Some n), (_, Some w_o)) } -> + begin match args with + | [ `Init init_f ] -> + let lanes = +(* + if debug then (fun i -> (* FIXME: Debug, remove later *) + let tm = Unix.gettimeofday () in + Format.eprintf "Generating lane %d of init, " i; + let res = init_f i in + Format.eprintf "took %f seconds@." (Unix.gettimeofday () -. tm); + res + ) + else +*) + init_f + in + let circs, cinps = List.split @@ List.init n lanes in + let circs = List.map + (function + | {type_ = CBitstring _; reg = r} when Backend.size_of_reg r = w_o -> r + | ret -> raise (CircError (Format.asprintf "Bad return for init_fun, expected bitstring, got (%a)" pp_circ ret))) + circs in + (assert (List.for_all ((=) (List.hd cinps)) cinps)); + let cinps = List.hd cinps in + {type_ = CArray {width=w_o; count=n} ; reg = Backend.flatten circs}, cinps + | _ -> raise (CircError (Format.asprintf "Bad arguments for ainit, expected (init_fun), got (%a)" Format.(pp_print_list ~pp_sep:(fun fmt () -> Format.pp_print_string fmt ", ") pp_arg) args)) + end + | { kind = `Init (_, Some w) } -> + begin match args with + | [ `Init init_f ] -> + let circs, cinps = List.split @@ List.init w init_f in + let circs = List.map + (function + | {type_ = CBool; reg = b} -> Backend.node_of_reg b + | ret -> raise (CircError (Format.asprintf "Bad return for init_fun, expected bitstring, got (%a)" pp_circ ret))) circs in + (assert (List.for_all ((=) (List.hd cinps)) cinps)); + let cinps = List.hd cinps in + {type_ = CBitstring w; reg = (Backend.reg_of_node_list circs)}, cinps + | _ -> raise (CircError (Format.asprintf "Bad arguments for init, expected (init_fun), got (%a)" Format.(pp_print_list ~pp_sep:(fun fmt () -> Format.pp_print_string fmt ", ") pp_arg) args)) + end + | _ -> assert false (* Should not happen because calls should be guarded by call to op_is_parametric_bvop *) + + + let circuit_of_bvop (op: EcDecl.crb_bvoperator) : circuit = + match op with + | { kind = `Add (_, Some size) } -> + let c1, inp1 = new_cbitstring_inp_reg size in + let c2, inp2 = new_cbitstring_inp_reg size in + {type_ = CBitstring size; reg = (Backend.add c1 c2 )}, [inp1; inp2] + + | { kind = `Sub (_, Some size) } -> + let c1, inp1 = new_cbitstring_inp_reg size in + let c2, inp2 = new_cbitstring_inp_reg size in + {type_ = CBitstring size; reg = (Backend.sub c1 c2)}, [inp1; inp2] + + | { kind = `Mul (_, Some size) } -> + let c1, inp1 = new_cbitstring_inp_reg size in + let c2, inp2 = new_cbitstring_inp_reg size in + {type_ = CBitstring size; reg = (Backend.mul c1 c2)}, [inp1; inp2] + + | { kind = `Div ((_, Some size), false) } -> + let c1, inp1 = new_cbitstring_inp_reg size in + let c2, inp2 = new_cbitstring_inp_reg size in + {type_ = CBitstring size; reg = (Backend.udiv c1 c2)}, [inp1; inp2] + + | { kind = `Div ((_, Some size), true) } -> + let c1, inp1 = new_cbitstring_inp_reg size in + let c2, inp2 = new_cbitstring_inp_reg size in + {type_ = CBitstring size; reg = (Backend.sdiv c1 c2)}, [inp1; inp2] + + | { kind = `Rem ((_, Some size), false) } -> + let c1, inp1 = new_cbitstring_inp_reg size in + let c2, inp2 = new_cbitstring_inp_reg size in + {type_ = CBitstring size; reg = (Backend.umod c1 c2)}, [inp1; inp2] + + | { kind = `Rem ((_, Some size), true) } -> + let c1, inp1 = new_cbitstring_inp_reg size in + let c2, inp2 = new_cbitstring_inp_reg size in + {type_ = CBitstring size; reg = (Backend.smod c1 c2)}, [inp1; inp2] + (* Should this be mod or rem? TODO FIXME*) + + | { kind = `Shl (_, Some size) } -> + let c1, inp1 = new_cbitstring_inp_reg size in + let c2, inp2 = new_cbitstring_inp_reg size in + {type_ = CBitstring size; reg = (Backend.lshl c1 c2)}, [inp1; inp2] + + | { kind = `Shr ((_, Some size), false) } -> + let c1, inp1 = new_cbitstring_inp_reg size in + let c2, inp2 = new_cbitstring_inp_reg size in + {type_ = CBitstring size; reg = (Backend.lshr c1 c2)}, [inp1; inp2] + + | { kind = `Shr ((_, Some size), true) } -> + let c1, inp1 = new_cbitstring_inp_reg size in + let c2, inp2 = new_cbitstring_inp_reg size in + {type_ = CBitstring size; reg = (Backend.ashr c1 c2)}, [inp1; inp2] + + | { kind = `Shls ((_, Some size1), (_, Some size2)) } -> + let c1, inp1 = new_cbitstring_inp_reg size1 in + let c2, inp2 = new_cbitstring_inp_reg size2 in + {type_ = CBitstring size1; reg = (Backend.lshl c1 c2)}, [inp1; inp2] + + | { kind = `Shrs ((_, Some size1), (_, Some size2), false) } -> + let c1, inp1 = new_cbitstring_inp_reg size1 in + let c2, inp2 = new_cbitstring_inp_reg size2 in + {type_ = CBitstring size1; reg = (Backend.lshr c1 c2)}, [inp1; inp2] + + | { kind = `Shrs ((_, Some size1), (_, Some size2), true) } -> + let c1, inp1 = new_cbitstring_inp_reg size1 in + let c2, inp2 = new_cbitstring_inp_reg size2 in + {type_ = CBitstring size1; reg = (Backend.ashr c1 c2)}, [inp1; inp2] + + | { kind = `Rol (_, Some size) } -> + let c1, inp1 = new_cbitstring_inp_reg size in + let c2, inp2 = new_cbitstring_inp_reg size in + {type_ = CBitstring size; reg = (Backend.rol c1 c2)}, [inp1; inp2] + + | { kind = `Ror (_, Some size) } -> + let c1, inp1 = new_cbitstring_inp_reg size in + let c2, inp2 = new_cbitstring_inp_reg size in + {type_ = CBitstring size; reg = (Backend.ror c1 c2)}, [inp1; inp2] + + | { kind = `And (_, Some size) } -> + let c1, inp1 = new_cbitstring_inp_reg size in + let c2, inp2 = new_cbitstring_inp_reg size in + {type_ = CBitstring size; reg = (Backend.land_ c1 c2)}, [inp1; inp2] + + | { kind = `Or (_, Some size) } -> + let c1, inp1 = new_cbitstring_inp_reg size in + let c2, inp2 = new_cbitstring_inp_reg size in + {type_ = CBitstring size; reg = (Backend.lor_ c1 c2)}, [inp1; inp2] + + | { kind = `Xor (_, Some size) } -> + let c1, inp1 = new_cbitstring_inp_reg size in + let c2, inp2 = new_cbitstring_inp_reg size in + {type_ = CBitstring size; reg = (Backend.lxor_ c1 c2)}, [inp1; inp2] + + | { kind = `Not (_, Some size) } -> + let c1, inp1 = new_cbitstring_inp_reg size in + {type_ = CBitstring size; reg = (Backend.lnot_ c1)}, [inp1] + + | { kind = `Opp (_, Some size) } -> + let c1, inp1 = new_cbitstring_inp_reg size in + {type_ = CBitstring size; reg = (Backend.opp c1)}, [inp1] + + | { kind = `Lt ((_, Some size), false) } -> + let c1, inp1 = new_cbitstring_inp_reg size in + let c2, inp2 = new_cbitstring_inp_reg size in + {type_ = CBool; reg = Backend.reg_of_node (Backend.ult c1 c2)}, [inp1; inp2] + + | { kind = `Lt ((_, Some size), true) } -> + let c1, inp1 = new_cbitstring_inp_reg size in + let c2, inp2 = new_cbitstring_inp_reg size in + {type_ = CBool; reg = Backend.reg_of_node (Backend.slt c1 c2)}, [inp1; inp2] + + | { kind = `Le ((_, Some size), false) } -> + let c1, inp1 = new_cbitstring_inp_reg size in + let c2, inp2 = new_cbitstring_inp_reg size in + {type_ = CBool; reg = Backend.reg_of_node (Backend.ule c1 c2)}, [inp1; inp2] + + | { kind = `Le ((_, Some size), true) } -> + let c1, inp1 = new_cbitstring_inp_reg size in + let c2, inp2 = new_cbitstring_inp_reg size in + {type_ = CBool; reg = Backend.reg_of_node (Backend.sle c1 c2)}, [inp1; inp2] + + | { kind = `Extend ((_, Some size), (_, Some out_size), false) } -> + (* assert (size <= out_size); *) + let c1, inp1 = new_cbitstring_inp_reg size in + {type_ = CBitstring out_size; reg = (Backend.uext c1 out_size)}, [inp1] + + | { kind = `Extend ((_, Some size), (_, Some out_size), true) } -> + (* assert (size <= out_size); *) + let c1, inp1 = new_cbitstring_inp_reg size in + {type_ = CBitstring out_size; reg = (Backend.sext c1 out_size)}, [inp1] + + | { kind = `Truncate ((_, Some size), (_, Some out_sz)) } -> + (* assert (size >= out_sz); *) + let c1, inp1 = new_cbitstring_inp_reg size in + {type_ = CBitstring out_sz; reg = (Backend.trunc c1 out_sz)}, [inp1] + + | { kind = `Concat ((_, Some sz1), (_, Some sz2), (_, Some szo)) } -> + (* assert (sz1 + sz2 = szo); *) + let c1, inp1 = new_cbitstring_inp_reg sz1 in + let c2, inp2 = new_cbitstring_inp_reg sz2 in + {type_ = CBitstring szo; reg = (Backend.concat c1 c2)}, [inp1; inp2] + + | { kind = `A2B (((_, Some w), (_, Some n)), (_, Some m))} -> + (* assert (n * w = m); *) + let c1, inp1 = new_carray_inp w n in + {c1 with type_ = CBitstring m}, [inp1] + + | { kind = `B2A ((_, Some m), ((_, Some w), (_, Some n)))} -> + (* assert (n * w = m); *) + let c1, inp1 = new_cbitstring_inp m in + {c1 with type_ = CArray {width=w; count=n}}, [inp1] + + | { kind = `ASliceGet _ | `ASliceSet _ | `Extract _ | `Insert _ | `Map _ | `AInit _ | `Get _ | `Init _ } + | _ + -> assert false (* Should be guarded by call to op_is_bvop *) + + (* | _ -> raise @@ CircError "Failed to generate op" *) + end + + module ArrayOps = struct + let array_get (({reg = c; type_ = CArray {width=w; count=n}}, inps) : circuit) (i: int) : circuit = + try + { type_ = CBitstring w; reg = (Backend.slice c (i*w) w)}, inps + with Invalid_argument _ -> + raise (CircError (Format.asprintf "Bad index for bitstring get (expected < %d, got %d)" n i)) + + let array_set (({reg = arr; type_ = CArray {width=w; count=n}}, inps) : circuit) (pos: int) (({reg = bs; type_ = CBitstring w'}, cinps): circuit) : circuit = + let exception SizeMismatch in + try + assert (w = w'); + { type_ = CArray {width=w; count=n}; reg = (Backend.insert arr (pos * w) bs)}, + merge_inputs inps cinps + with Invalid_argument _ -> + raise (CircError (Format.asprintf "Bad index for bitstring set (expected < %d, got %d)" n pos)) + | SizeMismatch -> + raise (CircError (Format.asprintf "Array set size mismatch (expected: %d, got: %d)" w (Backend.size_of_reg bs))) + + (* FIXME: review this functiono | FIXME: Not axiomatized in QFABV.ec file *) + let array_oflist (circs : circuit list) (dfl: circuit) (len: int) : circuit = + let circs, inps = List.split circs in + let dif = len - List.length circs in assert (dif >= 0); + (* if debug then Format.eprintf "Len, Dif in array_oflist: %d, %d@." len dif; *) + let circs = circs @ (List.init dif (fun _ -> fst dfl)) in + let inps = if dif > 0 then inps @ [snd dfl] else inps in + let circs = List.map + (function + | {type_ = CBitstring _; reg = r} -> r + | list_elem -> raise (CircError (Format.asprintf "Bad circuit type for list element in array of_list, expected bitstring got %a" pp_circ list_elem)) + ) circs + in + { type_ = CArray {width=Backend.size_of_reg (List.hd circs); count=len}; reg = (Backend.flatten circs)}, merge_inputs_list inps + end +end + +include MakeCircuitInterfaceFromCBackend(LospecsBack) +include CArgs +include TranslationState +include BVOps +include ArrayOps + +let reset_backend_state () = + C.HCons.clear (); + HL.reset_state () diff --git a/src/ecLowPhlGoal.ml b/src/ecLowPhlGoal.ml index 97fe5f0b46..75b103beba 100644 --- a/src/ecLowPhlGoal.ml +++ b/src/ecLowPhlGoal.ml @@ -194,16 +194,16 @@ let is_program_logic (f : form) (ks : hlkind list) = let tc1_get_stmt side tc = let concl = FApi.tc1_goal tc in match side, concl.f_node with - | None, FhoareS hs -> hs.hs_m, hs.hs_s - | None, FeHoareS hs -> hs.ehs_m, hs.ehs_s - | None, FbdHoareS hs -> hs.bhs_m, hs.bhs_s + | None, FhoareS hs -> (hs.hs_m, hs.hs_s) + | None, FeHoareS hs -> (hs.ehs_m, hs.ehs_s) + | None, FbdHoareS hs -> (hs.bhs_m, hs.bhs_s) | Some _ , (FhoareS _ | FbdHoareS _) -> tc_error_noXhl ~kinds:[`Hoare `Stmt; `PHoare `Stmt] !!tc - | Some `Left, FequivS es -> es.es_ml, es.es_sl - | Some `Right, FequivS es -> es.es_mr, es.es_sr + | Some `Left, FequivS es -> (es.es_ml, es.es_sl) + | Some `Right, FequivS es -> (es.es_mr, es.es_sr) | None, FequivS _ -> tc_error_noXhl ~kinds:[`Equiv `Stmt] !!tc - | _ -> + | _ -> tc_error_noXhl ~kinds:(hlkinds_Xhl_r `Stmt) !!tc (* ------------------------------------------------------------------ *) diff --git a/src/ecOptions.ml b/src/ecOptions.ml index 4997737964..f02cf63a80 100644 --- a/src/ecOptions.ml +++ b/src/ecOptions.ml @@ -25,11 +25,13 @@ and cmp_option = { cmpo_tstats : string option; cmpo_noeco : bool; cmpo_script : bool; + cmpo_specs : spec_options; } and cli_option = { clio_emacs : bool; clio_provers : prv_options; + clio_specs : spec_options; } and run_option = { @@ -39,6 +41,7 @@ and run_option = { runo_provers : prv_options; runo_jobs : int option; runo_rawargs : string list; + runo_specs : spec_options; } and doc_option = { @@ -59,6 +62,10 @@ and prv_options = { prvo_why3server : string option; } +and spec_options = { + files : string list; +} + and ldr_options = { ldro_idirs : (string option * string * bool) list; ldro_boot : bool; @@ -80,6 +87,7 @@ type ini_options = { ini_timeout : int option; ini_idirs : (string option * string) list; ini_rdirs : (string option * string) list; + ini_specs : string list; } type ini_context = { @@ -98,6 +106,8 @@ module Ini : sig val get_provers : ini_context -> string list + val get_specs : ini_context -> string list + val get_timeout : ini_context -> int option val get_idirs : ini_context -> (string option * string) list @@ -113,6 +123,8 @@ module Ini : sig val get_all_provers : ini_context list -> string list + val get_all_specs : ini_context list -> string list + val get_all_timeout : ini_context list -> int option val get_all_idirs : ini_context list -> (string option * string) list @@ -144,6 +156,10 @@ end = struct let get_provers (ini : ini_context) = ini.inic_ini.ini_provers + let get_specs (ini : ini_context) = + List.map (absolute ?root:ini.inic_root) + ini.inic_ini.ini_specs + let get_timeout (ini : ini_context) = ini.inic_ini.ini_timeout @@ -170,6 +186,9 @@ end = struct let get_all_provers (ini : ini_context list) = List.flatten (List.map get_provers ini) + let get_all_specs (ini : ini_context list) = + List.flatten (List.map get_specs ini) + let get_all_timeout (ini : ini_context list) = List.find_map_opt get_timeout ini @@ -505,9 +524,14 @@ let prv_options_of_values ini values = prvo_why3server = get_string "why3server" values; } +let spec_options_of_values ini values = + { files = (Ini.get_all_specs ini) @ (get_strings "spec" values); } + let cli_options_of_values ini values = { clio_emacs = get_flag "emacs" values; - clio_provers = prv_options_of_values ini values; } + clio_provers = prv_options_of_values ini values; + clio_specs = spec_options_of_values ini values; + } let cmp_options_of_values ini values input = { cmpo_input = input; @@ -516,7 +540,9 @@ let cmp_options_of_values ini values input = cmpo_compact = get_int "compact" values; cmpo_tstats = get_string "tstats" values; cmpo_noeco = get_flag "no-eco" values; - cmpo_script = get_flag "script" values; } + cmpo_script = get_flag "script" values; + cmpo_specs = spec_options_of_values ini values; + } let runtest_options_of_values ini values (input, scenarios) = { runo_input = input; @@ -524,7 +550,9 @@ let runtest_options_of_values ini values (input, scenarios) = runo_report = get_string "report" values; runo_provers = prv_options_of_values ini values; runo_jobs = get_int "jobs" values; - runo_rawargs = get_strings "raw-args" values; } + runo_rawargs = get_strings "raw-args" values; + runo_specs = spec_options_of_values ini values; + } let doc_options_of_values values input = { doco_input = input; @@ -682,7 +710,9 @@ let read_ini_file (filename : string) = ini_provers = trylist "provers" ; ini_timeout = tryint "timeout" ; ini_idirs = List.map parse_idir (trylist "idirs"); - ini_rdirs = List.map parse_idir (trylist "rdirs"); } in + ini_rdirs = List.map parse_idir (trylist "rdirs"); + ini_specs = trylist "spec"; + } in { ini_ppwidth = ini.ini_ppwidth; ini_why3 = omap expand ini.ini_why3; @@ -690,4 +720,6 @@ let read_ini_file (filename : string) = ini_provers = ini.ini_provers; ini_timeout = ini.ini_timeout; ini_idirs = ini.ini_idirs; - ini_rdirs = ini.ini_rdirs; } + ini_rdirs = ini.ini_rdirs; + ini_specs = ini.ini_specs; + } diff --git a/src/ecOptions.mli b/src/ecOptions.mli index 5ba1d0f63a..c4abd74c30 100644 --- a/src/ecOptions.mli +++ b/src/ecOptions.mli @@ -21,11 +21,13 @@ and cmp_option = { cmpo_tstats : string option; cmpo_noeco : bool; cmpo_script : bool; + cmpo_specs : spec_options; } and cli_option = { clio_emacs : bool; clio_provers : prv_options; + clio_specs : spec_options; } and run_option = { @@ -35,6 +37,7 @@ and run_option = { runo_provers : prv_options; runo_jobs : int option; runo_rawargs : string list; + runo_specs : spec_options; } and doc_option = { @@ -55,6 +58,10 @@ and prv_options = { prvo_why3server : string option; } +and spec_options = { + files : string list; +} + and ldr_options = { ldro_idirs : (string option * string * bool) list; ldro_boot : bool; @@ -76,6 +83,7 @@ type ini_options = { ini_timeout : int option; ini_idirs : (string option * string) list; ini_rdirs : (string option * string) list; + ini_specs : string list; } type ini_context = { diff --git a/src/ecPV.ml b/src/ecPV.ml index 6d5c5bd5e1..e2da2dfa2a 100644 --- a/src/ecPV.ml +++ b/src/ecPV.ml @@ -111,6 +111,8 @@ module Mpv = struct check_glob env mp m; raise Not_found + let pvs { s_pv } = s_pv + type esubst = (expr, unit) t let rec esubst env (s : esubst) e = diff --git a/src/ecPV.mli b/src/ecPV.mli index 6821864cc3..0e9df4354b 100644 --- a/src/ecPV.mli +++ b/src/ecPV.mli @@ -53,10 +53,12 @@ module Mpv : sig val find_glob : env -> mpath -> ('a,'b) t -> 'b - val esubst : env -> (expr, unit) t -> expr -> expr + val pvs : ('a,'b) t -> 'a Mnpv.t + + val esubst : env -> (expr, unit) t -> expr -> expr val issubst : env -> (expr, unit) t -> instr list -> instr list - val isubst : env -> (expr, unit) t -> instr -> instr - val ssubst : env -> (expr, unit) t -> stmt -> stmt + val isubst : env -> (expr, unit) t -> instr -> instr + val ssubst : env -> (expr, unit) t -> stmt -> stmt end (* -------------------------------------------------------------------- *) diff --git a/src/ecParser.mly b/src/ecParser.mly index 46205d02b5..e8298f19b1 100644 --- a/src/ecParser.mly +++ b/src/ecParser.mly @@ -379,10 +379,12 @@ %token ABSTRACT %token ADMIT %token ADMITTED +%token AIG %token ALGNORM %token ALIAS %token AMP %token APPLY +%token ARRAY %token AS %token ASSERT %token ASSUMPTION @@ -393,7 +395,11 @@ %token AXIOMATIZED %token BACKS %token BACKSLASH +%token BDEP +%token BDEPEQ %token BETA +%token BITSTRING +%token BIND %token BY %token BYEQUIV %token BYPHOARE @@ -402,6 +408,7 @@ %token BYUPTO %token CALL %token CASE +%token CIRCUIT %token CBV %token CEQ %token CFOLD @@ -449,6 +456,7 @@ %token EXLIM %token EXPECT %token EXPORT +%token EXTENS %token FAIL %token FEL %token FIRST @@ -659,7 +667,10 @@ _lident: | x=LIDENT { x } | ABORT { "abort" } | ADMITTED { "admitted" } +| ARRAY { "array" } | ASYNC { "async" } +| BIND { "bind" } +| BITSTRING { "bitstring" } | DEBUG { "debug" } | DUMP { "dump" } | EXPECT { "expect" } @@ -715,6 +726,7 @@ _lident: %inline sword: | n=word { n } +| PLUS n=word { n } | MINUS n=word { -n } (* -------------------------------------------------------------------- *) @@ -2572,7 +2584,7 @@ codepos: codepos_range: | LBRACKET cps=codepos DOTDOT cpe=codepos RBRACKET { (cps, `Base cpe) } -| LBRACKET cps=codepos MINUS cpe=codepos1 RBRACKET { (cps, `Offset cpe) } +| LBRACKET cps=codepos PLUS cpe=codepos1 RBRACKET { (cps, `Offset cpe) } codepos_or_range: | cp=codepos { (cp, `Offset (0, `ByPos 0)) } @@ -3061,8 +3073,11 @@ interleave_info: | FUSION s=side? o=codepos NOT i=word AT d1=word COMMA d2=word { Pfusion (s, o, (i, (d1, d2))) } -| UNROLL b=boption(FOR) s=side? o=codepos - { Punroll (s, o, b) } +| UNROLL s=side? o=codepos + { Punroll (s, o, `While) } + +| UNROLL FOR b=boption(STAR) s=side? o=codepos + { Punroll (s, o, `For b) } | SPLITWHILE s=side? o=codepos COLON c=expr %prec prec_tactic { Psplitwhile (c, s, o) } @@ -3196,8 +3211,8 @@ interleave_info: | LOSSLESS { Plossless } -| PROC CHANGE side=side? pos=loc(codepos_or_range) COLON s=brace(stmt) - { Pchangestmt (side, (unloc pos), s) } +| PROC CHANGE side=side? pos=loc(codepos_or_range) COLON b=option(bracket(ptybindings)) s=brace(stmt) + { Pchangestmt (side, b, (unloc pos), s) } | PROC REWRITE side=side? pos=codepos f=pterm { Pprocrewrite (side, pos, `Rw f) } @@ -3205,9 +3220,88 @@ interleave_info: | PROC REWRITE side=side? pos=codepos SLASHEQ { Pprocrewrite (side, pos, `Simpl) } +| PROC CHANGE CIRCUIT b=option(bracket(ptybindings)) o=codepos PLUS w=word s=brace(stmt) + { Prwprgm (`Change (o, b, w, s)) } + | IDASSIGN o=codepos x=lvalue_var { Prwprgm (`IdAssign (o, x)) } +bd_var: +| s=lident LBRACKET t=qoident COLON j=uint RBRACKET + { `Slice (s, (t,j)) :> bdepvar } + +| s=lident + { `Var s :> bdepvar } + +bd_vars: +| vs=plist0(bd_var, SEMICOLON) + { vs } + +| v=lident COLON w=word + { [(`VarRange (v, w) :> bdepvar)] } + +bdepeq_out_info: +| m=word COLON LBRACKET outvs_l=bd_vars TILD outvs_r=bd_vars RBRACKET + { (m, outvs_l, outvs_r) } + +%public phltactic: +| BDEP + n=word + m=word + invs=bracket(bd_vars) + inpvs=bracket(bd_vars) + outvs=bracket(bd_vars) + lane=oident + pcond=oident + perm=oident? + debug=AIG? + { Pbdep { n; m; invs; inpvs; outvs; pcond; lane; perm; debug = Option.is_some debug} } + +| BDEP STAR + in_ty=bracket(loc(simpl_type_exp)) + invs=bracket(bd_vars) + inpvs=bracket(bd_vars) + out_ty=bracket(loc(simpl_type_exp)) + outvs=bracket(bd_vars) + lane=oident + range=sform + { Pbdepeval { in_ty; out_ty; invs; inpvs; outvs; lane; range; sign=false } } + +| BDEP STAR STAR + in_ty=bracket(loc(simpl_type_exp)) + invs=bracket(bd_vars) + inpvs=bracket(bd_vars) + out_ty=bracket(loc(simpl_type_exp)) + outvs=bracket(bd_vars) + lane=oident + range=sform + { Pbdepeval { in_ty; out_ty; invs; inpvs; outvs; lane; range; sign=true } } + +| BDEPEQ + n=word + inpvs_l=bracket(bd_vars) + inpvs_r=bracket(bd_vars) + out_blocks=brace(plist0(bdepeq_out_info, SEMICOLON)) + pcond=oident? + { Pbdepeq { n; inpvs_l; inpvs_r; out_blocks; pcond; preprocess=false} } + +| BDEPEQ STAR + n=word + inpvs_l=bracket(bd_vars) + inpvs_r=bracket(bd_vars) + out_blocks=brace(plist0(bdepeq_out_info, SEMICOLON)) + pcond=oident? + { Pbdepeq { n; inpvs_l; inpvs_r; out_blocks; pcond; preprocess=true} } + +| CIRCUIT STAR f=bracket(form) v=lident + { Pcirc (f, (`Var v :> bdepvar)) } + +| CIRCUIT + { Pcircuit (`Solve ) } + +| CIRCUIT SIMPLIFY + { Pcircuit (`Simplify ) } + bdhoare_split: | b1=sform b2=sform b3=sform? { BDH_split_bop (b1,b2,b3) } @@ -3271,9 +3365,9 @@ eqobs_in_eqpost: eqobs_in: | pos=eqobs_in_pos? i=eqobs_in_eqinv p=eqobs_in_eqpost? { - { sim_pos = pos; - sim_hint = i; - sim_eqs = p; } + { psim_pos = pos; + psim_hint = i; + psim_eqs = p; } } pgoptionkw: @@ -3375,6 +3469,9 @@ tactic_core_r: { Pcase (odfl false eq, odfl [] opts, { pr_view = vw; pr_rev = gp; } ) } +| EXTENS v=option(bracket(lident)) COLON t=tactic_core + { Pextens (t, v) } + | PROGRESS opts=pgoptions? t=tactic_core? { Pprogress (odfl [] opts, t) } @@ -3821,6 +3918,35 @@ user_red_option: (Some ("invalid option: " ^ (unloc x))) } +(* -------------------------------------------------------------------- *) +(* Circuit & bo bindings *) + +(* FIXME:merge-bdep generic option parser *) + +spec_binding: +| op=qoident LARROW circ=loc(STRING) + { (op, circ) } + +cr_binding_r: +| BIND BITSTRING from_=qoident to_=qoident touint=qoident tosint=qoident ofint=qoident type_=loc(simpl_type_exp) size=sform + { CRB_Bitstring { from_; to_; touint; tosint; ofint; type_; size; } } + +| BIND ARRAY get=qoident set=qoident tolist=qoident oflist=qoident type_=qoident size=sform + { CRB_Array { get; set; tolist; oflist; type_; size; } } + +| BIND OP type_=qident operator=qoident name=loc(STRING) + { CRB_BvOperator { types = [type_]; operator; name; } } + +| BIND OP types=bracket(plist1(qident, AMP)) operator=qoident name=loc(STRING) + { CRB_BvOperator { types; operator; name; } } + +| BIND CIRCUIT bindings=plist1(spec_binding, COMMA) + { CRB_Circuit { bindings } } + +%inline cr_binding: +| locality=is_local binding=cr_binding_r + { { locality; binding; }} + (* -------------------------------------------------------------------- *) (* Search pattern *) %inline search: x=sform_h { x } @@ -3859,6 +3985,7 @@ global_action: | gprover_info { Gprover_info $1 } | addrw { Gaddrw $1 } | hint { Ghint $1 } +| cr_binding { Gcrbinding $1 } | x=loc(proofend) { Gsave x } | PRINT p=print { Gprint p } | SEARCH x=search+ { Gsearch x } diff --git a/src/ecParsetree.ml b/src/ecParsetree.ml index 0189383d61..df6ae1c14a 100644 --- a/src/ecParsetree.ml +++ b/src/ecParsetree.ml @@ -694,10 +694,10 @@ type conseq_info = type conseq_ppterm = ((pformula option pair) * (conseq_info) option) gppterm (* -------------------------------------------------------------------- *) -type sim_info = { - sim_pos : pcodepos1 pair option; - sim_hint : (pgamepath option pair * pformula) list * pformula option; - sim_eqs : pformula option +type psim_info = { + psim_pos : pcodepos1 pair option; + psim_hint : (pgamepath option pair * pformula) list * pformula option; + psim_eqs : pformula option } (* -------------------------------------------------------------------- *) @@ -722,6 +722,38 @@ type matchmode = [ | `SSided of side ] +(* -------------------------------------------------------------------- *) +type bdepvar = [`Var of psymbol | `VarRange of psymbol * int | `Slice of psymbol * (pqsymbol * zint)] + +type bdep_info = + { n : int + ; m : int + ; invs : bdepvar list + ; inpvs : bdepvar list + ; outvs : bdepvar list + ; pcond : psymbol + ; lane : psymbol + ; perm : psymbol option + ; debug : bool } + +type bdep_eval_info = + { in_ty : pty + ; out_ty : pty + ; invs : bdepvar list + ; inpvs : bdepvar list + ; outvs : bdepvar list + ; lane : psymbol + ; range : pformula + ; sign : bool } + +type bdepeq_info = + { n : int + ; inpvs_l : bdepvar list + ; inpvs_r : bdepvar list + ; out_blocks : (int * bdepvar list * bdepvar list) list + ; pcond : psymbol option + ; preprocess : bool } + (* -------------------------------------------------------------------- *) type prrewrite = [`Rw of ppterm | `Simpl] @@ -737,7 +769,7 @@ type phltactic = | Pasyncwhile of async_while_info | Pfission of (oside * pcodepos * (int * (int * int))) | Pfusion of (oside * pcodepos * (int * (int * int))) - | Punroll of (oside * pcodepos * bool) + | Punroll of (oside * pcodepos * [`While | `For of bool]) | Psplitwhile of (pexpr * oside * pcodepos) | Pcall of oside * call_info gppterm | Pcallconcave of (pformula * call_info gppterm) @@ -771,14 +803,13 @@ type phltactic = | Pfel of (pcodepos1 * fel_info) | Phoare | Pprbounded - | Psim of crushmode option* sim_info + | Psim of crushmode option* psim_info | Ptrans_stmt of trans_info | Prw_equiv of rw_eqv_info | Psymmetry | Pbdhoare_split of bdh_split - | Prwprgm of rwprgm | Pprocrewrite of side option * pcodepos * prrewrite - | Pchangestmt of side option * pcodepos_range * pstmt + | Pchangestmt of side option * ptybindings option * pcodepos_range * pstmt (* Eager *) @@ -797,8 +828,24 @@ type phltactic = | Pauto | Plossless + (* Map-reduce *) + | Pbdep of bdep_info + | Pbdepeval of bdep_eval_info + | Pbdepeq of bdepeq_info + | Pcirc of (pformula * bdepvar) + | Pcircuit of circuit_mode + + (* Program rewriting *) + | Prwprgm of rwprgm + and rwprgm = [ | `IdAssign of pcodepos * pqsymbol + | `Change of pcodepos * ptybindings option * int * pstmt +] + +and circuit_mode = [ + | `Simplify + | `Solve ] (* -------------------------------------------------------------------- *) @@ -1016,6 +1063,7 @@ and ptactic_core_r = | Por of ptactic * ptactic | Pseq of ptactics | Pcase of (bool * pcaseoptions * prevertv) + | Pextens of (ptactic_core * psymbol option) | Plogic of logtactic | PPhl of phltactic | Pprogress of ppgoptions * ptactic_core option @@ -1274,6 +1322,45 @@ type puserred = type threquire = psymbol option * (psymbol * psymbol option) list * [`Import|`Export] option +(* -------------------------------------------------------------------- *) +type pbind_bitstring = + { from_ : pqsymbol + ; to_ : pqsymbol + ; touint : pqsymbol + ; tosint : pqsymbol + ; ofint : pqsymbol + ; type_ : pty + ; size : pformula } + +(* -------------------------------------------------------------------- *) +type pbind_array = + { get : pqsymbol + ; set : pqsymbol + ; tolist : pqsymbol + ; oflist : pqsymbol + ; type_ : pqsymbol + ; size : pformula } + +(* -------------------------------------------------------------------- *) +type pbind_bvoperator = + { name : string located + ; types : pqsymbol list + ; operator : pqsymbol } + +(* -------------------------------------------------------------------- *) +type pbind_circuit = + { bindings : (pqsymbol * string located) list } + +(* -------------------------------------------------------------------- *) +type pcrbinding_r = + | CRB_Bitstring of pbind_bitstring + | CRB_Array of pbind_array + | CRB_BvOperator of pbind_bvoperator + | CRB_Circuit of pbind_circuit + +(* -------------------------------------------------------------------- *) +type pcrbinding = { locality : is_local; binding : pcrbinding_r } + (* -------------------------------------------------------------------- *) type global_action = | Gmodule of pmodule_def_or_decl @@ -1313,6 +1400,7 @@ type global_action = | Gpragma of psymbol | Goption of (psymbol * [`Bool of bool | `Int of int]) | GdumpWhy3 of string + | Gcrbinding of pcrbinding type global = { gl_action : global_action located; diff --git a/src/ecPrinting.ml b/src/ecPrinting.ml index b37fbe2283..30b1b4c74d 100644 --- a/src/ecPrinting.ml +++ b/src/ecPrinting.ml @@ -3652,6 +3652,80 @@ let rec pp_theory ppe (fmt : Format.formatter) (path, cth) = level (odfl "" base) (pp_list "@ " (pp_axhnt ppe)) axioms + | EcTheory.Th_crbinding (binding, lc) -> begin + match binding with + | CRB_Bitstring bs -> + Format.fprintf fmt "%abind bitstring %a %a %a %a%s." + pp_locality lc + (pp_opname ppe) bs.to_ + (pp_opname ppe) bs.from_ + (pp_tyname ppe) bs.type_ + (pp_form ppe) (fst bs.size) + (if Option.is_some (snd bs.size) then " (concrete)" else " (abstract)") + + | CRB_Array ba -> + Format.fprintf fmt "%abind array %a %a %a %a %a %a%s." + pp_locality lc + (pp_tyname ppe) ba.type_ + (pp_opname ppe) ba.get + (pp_opname ppe) ba.set + (pp_opname ppe) ba.tolist + (pp_opname ppe) ba.oflist + (pp_form ppe) (fst ba.size) + (if Option.is_some (snd ba.size) then " (concrete)" else " (abstract)") + + | CRB_BvOperator op -> + let kind = + match op.kind with + | `Add _ -> "add" + | `Sub _ -> "sub" + | `Mul _ -> "mul" + | `Div (_, false) -> "udiv" + | `Div (_, true ) -> "sdiv" + | `Rem (_, false) -> "urem" + | `Rem (_, true ) -> "srem" + | `Shl _ -> "shl" + | `Shls _ -> "shls" + | `Rol _ -> "rol" + | `Ror _ -> "ror" + | `Shr (_, false) -> "shr" + | `Shr (_, true ) -> "ashr" + | `Shrs (_, _, false) -> "shrs" + | `Shrs (_, _, true ) -> "ashrs" + | `Not _ -> "not" + | `Opp _ -> "opp" + | `And _ -> "and" + | `Or _ -> "or" + | `Xor _ -> "xor" + | `Lt (_, false) -> "ult" + | `Lt (_, true ) -> "slt" + | `Le (_, false) -> "ule" + | `Le (_, true ) -> "sle" + | `Init _ -> "init" + | `Get _ -> "get" + | `AInit _ -> "ainit" + | `Extend (_, _, false) -> "zextend" + | `Extend (_, _, true ) -> "sextend" + | `Extract _ -> "extract" + | `Insert _ -> "insert" + | `Concat _ -> "concat" + | `Truncate _ -> "truncate" + | `A2B _ -> "a2b" + | `B2A _ -> "b2a" + | `Map _ -> "map" + | `ASliceGet _ -> "asliceget" + | `ASliceSet _ -> "asliceset" + in + Format.fprintf fmt "%abind op [%a] %a \"%s\"." + pp_locality lc + (pp_list " & " (pp_tyname ppe)) op.types + (pp_opname ppe) op.operator + kind + + | CRB_Circuit cr -> + Format.fprintf fmt "%abind circuit %a \"%s\"." + pp_locality lc (pp_opname ppe) cr.operator cr.name + end | EcTheory.Th_alias (name, target) -> Format.fprintf fmt "theory %s = %a." name (pp_thname ~alias:false ppe) target diff --git a/src/ecScope.ml b/src/ecScope.ml index d8a4676f14..8713f8756d 100644 --- a/src/ecScope.ml +++ b/src/ecScope.ml @@ -339,6 +339,7 @@ type scope = { sc_options : GenOptions.options; sc_globdoc : string list; sc_locdoc : docstate; + sc_specs : string list; } and docstate = { @@ -449,7 +450,8 @@ let empty (gstate : EcGState.gstate) = sc_pr_uc = None; sc_options = GenOptions.freeze (); sc_globdoc = []; - sc_locdoc = DocState.empty; } + sc_locdoc = DocState.empty; + sc_specs = []; } (* -------------------------------------------------------------------- *) let env (scope : scope) = @@ -570,7 +572,8 @@ let for_loading (scope : scope) = sc_pr_uc = None; sc_options = GenOptions.for_loading scope.sc_options; sc_globdoc = []; - sc_locdoc = DocState.empty; } + sc_locdoc = DocState.empty; + sc_specs = scope.sc_specs; } (* FIXME: is this correct? *) (* -------------------------------------------------------------------- *) let subscope (scope : scope) (mode : EcTheory.thmode) (name : symbol) lc = @@ -587,6 +590,7 @@ let subscope (scope : scope) (mode : EcTheory.thmode) (name : symbol) lc = sc_options = GenOptions.for_subscope scope.sc_options; sc_globdoc = []; sc_locdoc = DocState.empty; + sc_specs = scope.sc_specs; } (* -------------------------------------------------------------------- *) @@ -2260,7 +2264,11 @@ module Ty = struct record.ELI.rc_tparams, `Record (scheme, record.ELI.rc_fields) in - bind scope (unloc name, { tyd_params; tyd_type; tyd_loca; }) + let tydecl = + { tyd_params; tyd_type; tyd_loca; + tyd_clinline = false; } in + + bind scope (unloc name, tydecl) (* ------------------------------------------------------------------ *) let add_subtype (scope : scope) ({ pl_desc = subtype } : psubtype located) = @@ -2269,9 +2277,10 @@ module Ty = struct let scope = let decl = EcDecl.{ - tyd_params = []; - tyd_type = `Abstract Sp.empty; - tyd_loca = `Global; (* FIXME:SUBTYPE *) + tyd_params = []; + tyd_type = `Abstract Sp.empty; + tyd_loca = `Global; (* FIXME:SUBTYPE *) + tyd_clinline = false; (* FIXME: tyd_clinline PR *) } in bind scope (unloc subtype.pst_name, decl) in let carrier = @@ -2364,9 +2373,10 @@ module Ty = struct let asty = let body = ofold (fun p tc -> Sp.add p tc) Sp.empty uptc in - { tyd_params = []; - tyd_type = `Abstract body; - tyd_loca = (lc :> locality); } in + { tyd_params = []; + tyd_type = `Abstract body; + tyd_loca = (lc :> locality); + tyd_clinline = false; } in let scenv = EcEnv.Ty.bind name asty scenv in (* Check for duplicated field names *) @@ -2691,7 +2701,517 @@ module Ty = struct failwith "unsupported" (* FIXME *) end -(* -------------------------------------------------------------------- *)module Search = struct +(* -------------------------------------------------------------------- *) +module Circuit = struct + type preoperator = [`Path of path | `Form of pformula] + + type clone = { + path : EcPath.path; + name : symbol; + local : is_local; + theories : (symbol * path) list; + types_ : (symbol * path) list; + operators : (symbol * preoperator) list; + proofs : symbol list; + } + + let doclone (scope : scope) (clone : clone) = + let loced x = mk_loc _dummy x in + let env = env scope in + + let evclone = + let do_type ((x, type_) : symbol * path) : symbol * ty_override located = + (x, loced (`ByPath type_, `Inline `Keep)) in + + let do_operator ((x, operator) : symbol * preoperator) : symbol * op_override located = + let operator = + match operator with + | `Path name -> `ByPath name + | `Form f -> + `BySyntax + { opov_tyvars = None + ; opov_args = [] + ; opov_retty = loced PTunivar + ; opov_body = f } + in (x, loced (operator, `Inline `Keep)) + in + + let do_theory (x : symbol) (theory : path) : EcThCloning.evclone = + let thenv = EcEnv.Theory.env_of_theory clone.path env in + let atheory = EcEnv.Theory.by_path (pqname clone.path x) thenv in + + List.fold_left (fun (evc : EcThCloning.evclone) (item : EcTheory.theory_item) -> + match item.ti_item with + | Th_operator (x, opdecl) -> begin + match opdecl.op_kind with + | OB_oper None -> + let ovrd = (`ByPath (pqname theory x), `Inline `Clear) in + { evc with evc_ops = Msym.add x (loced ovrd) evc.evc_ops } + | _ -> evc + end + | Th_type (x, _) -> + let ovrd = (`ByPath (pqname theory x), `Inline `Clear) in + { evc with evc_types = Msym.add x (loced ovrd) evc.evc_types } + | Th_axiom (x, _) -> + let evc_lemmas = + let proof = loced (EcPath.toqsymbol (pqname theory x)) in + let proof = Papply (`ExactType proof, None) in + let proof = loced (Plogic proof) in + let proof = (Some proof, `Inline `Clear, false) in + { evc.evc_lemmas with + ev_bynames = Msym.add x proof evc.evc_lemmas.ev_bynames } + in { evc with evc_lemmas } + | _ -> assert false + ) EcThCloning.evc_empty atheory.cth_items in + + { EcThCloning.evc_empty with + (* FIXME: PR: what to do here? *) + evc_types = (Msym.of_list (List.map do_type clone.types_) :> (EcThCloning.xty_override located MSym.t)); + (* FIXME: PR: what to do here? *) + evc_ops = (Msym.of_list (List.map do_operator clone.operators) :> (EcThCloning.xop_override located MSym.t)); + evc_ths = Msym.of_list (List.map (fun (x, th) -> (x, (do_theory x th, false))) clone.theories); (* FIXME PR: is the false here correct? *) + evc_lemmas = { + ev_bynames = + clone.proofs + |> List.map (fun name -> (name, (Some (loced (Ptry (loced (Pby None)))), `Alias, false))) + |> Msym.of_list; + ev_global = + (* FIXME PR: get this to work *) + [ +(* (Some (loced (Pby None)), Some [`Include, "bydone"]) *) + (None, None) + ; (None, None) ]; } } in + + let npath = EcPath.pqname (EcEnv.root env) clone.name in + let theory = EcEnv.Theory.by_path clone.path env in + + let (proofs, scope) = + EcTheoryReplay.replay (Cloning.hooks ~override_locality:(Some clone.local)) + ~abstract:false ~override_locality:(Some clone.local) ~incl:false + ~clears:Sp.empty ~renames:[] ~opath:clone.path ~npath + evclone scope (EcPath.basename npath, false, theory.cth_items, clone.local) (* FIXME PR: check extra arguments here *) + in + + let proofs = Cloning.replay_proofs scope `Check proofs in + + (proofs, scope) + + let add_bitstring (scope : scope) (local : is_local) (bs : pbind_bitstring) : scope = + let env = env scope in + + let type_ = + let ue = EcUnify.UniEnv.create None in + let ty = EcTyping.transty tp_tydecl env ue bs.type_ in + assert (EcUnify.UniEnv.closed ue); + ty_subst (Tuni.subst (EcUnify.UniEnv.close ue)) ty in + + let bspath = + match (EcEnv.ty_hnorm type_ env).ty_node with + | Tconstr (p, []) -> p + | _ -> + hierror ~loc:(bs.type_.pl_loc) + "bit-string type must be a monomorphic named type" in + + let from_, _ = EcEnv.Op.lookup bs.to_.pl_desc env in + let to_ , _ = EcEnv.Op.lookup bs.from_.pl_desc env in + let touint, _ = EcEnv.Op.lookup bs.touint.pl_desc env in + let tosint, _ = EcEnv.Op.lookup bs.tosint.pl_desc env in + let ofint, _ = EcEnv.Op.lookup bs.ofint.pl_desc env in + let name = String.concat "_" ("BVA" :: EcPath.tolist bspath) (* FIXME: not stable*) in + + let preclone = + { path = EcPath.fromqsymbol (["Top"; "QFABV"], "BV") + ; name = name + ; local = local + ; theories = [] + ; types_ = ["bv", bspath] + ; operators = + [ ("size" , `Form bs.size) + ; ("tolist", `Path to_) + ; ("oflist", `Path from_) + ; ("touint", `Path touint) + ; ("tosint", `Path tosint) + ; ("ofint" , `Path ofint) ] + ; proofs = [] } in + + let proofs, scope = doclone scope preclone in + + let size_f = EcTyping.trans_form env (EcUnify.UniEnv.create None) bs.size tint in + let size_i = try + Some (EcCallbyValue.norm_cbv EcReduction.full_red (EcEnv.LDecl.init env []) size_f |> destr_int |> BI.to_int) + with + | DestrError "destr_int" -> None + | EcEnv.NotReducible -> None + in + + let item = CRB_Bitstring + { from_; to_; touint; tosint; ofint; + type_ = bspath; + size = (size_f, size_i); + theory = pqname (EcEnv.root env) name; } in + + let item = EcTheory.mkitem ~import:true (EcTheory.Th_crbinding (item, local)) in + + let scope = { scope with sc_env = EcSection.add_item item scope.sc_env } in + + Ax.add_defer scope proofs + + let add_array (scope : scope) (local : is_local) (ba : pbind_array) : scope = + let env = env scope in + + let bspath = + match EcEnv.Ty.lookup_opt (unloc ba.type_) env with + | None -> + hierror ~loc:(loc ba.type_) + "cannot find named type: `%s'" + (string_of_qsymbol (unloc ba.type_)) + + | Some (path, decl) -> (* FIXME: normalize? *) + if List.length decl.tyd_params <> 1 then + hierror ~loc:(loc ba.type_) + "type constructor should take exactly one parameter: `%s'" + (string_of_qsymbol (unloc ba.type_)); + path in + + let get , _ = EcEnv.Op.lookup ba.get.pl_desc env in + let set , _ = EcEnv.Op.lookup ba.set.pl_desc env in + let tolist, _ = EcEnv.Op.lookup ba.tolist.pl_desc env in + let oflist, _ = EcEnv.Op.lookup ba.oflist.pl_desc env in + let name = String.concat "_" ("BVA" :: EcPath.tolist bspath) in + + let preclone = + { path = EcPath.fromqsymbol (["Top"; "QFABV"], "A") + ; name = name + ; local = local + ; theories = [] + ; types_ = ["t", bspath] + ; operators = + [ ("size" , `Form ba.size) + ; ("get" , `Path get) + ; ("set" , `Path set) + ; ("to_list", `Path tolist) + ; ("of_list", `Path oflist) ] + ; proofs = [] } in + + let proofs, scope = doclone scope preclone in + + let size_f = EcTyping.trans_form env (EcUnify.UniEnv.create None) ba.size tint in + let size_i = try + Some (EcCallbyValue.norm_cbv EcReduction.full_red (EcEnv.LDecl.init env []) size_f |> destr_int |> BI.to_int) + with + | DestrError "destr_int" -> None + | EcEnv.NotReducible -> None + in + + let item = CRB_Array + { get; set; tolist; oflist; + type_ = bspath; + size = (size_f, size_i); + theory = pqname (EcEnv.root env) name; } in + + let item = EcTheory.mkitem ~import:true (Th_crbinding (item, local)) in + + let scope = { scope with sc_env = EcSection.add_item item scope.sc_env } in + + Ax.add_defer scope proofs + + let add_bvoperator (scope : scope) (local : is_local) (op : pbind_bvoperator) : scope = + let env = env scope in + + let (kind, sig_, subname) : (_ -> EcDecl.bv_opkind) * _ * _ = + match unloc op.name with + | "add" -> (fun sz -> `Add (as_seq1 sz )), [`BV None], "Add" + | "sub" -> (fun sz -> `Sub (as_seq1 sz )), [`BV None], "Sub" + | "mul" -> (fun sz -> `Mul (as_seq1 sz )), [`BV None], "Mul" + | "udiv" -> (fun sz -> `Div (as_seq1 sz, false)), [`BV None], "UDiv" + | "sdiv" -> (fun sz -> `Div (as_seq1 sz, true )), [`BV None], "SDiv" + | "urem" -> (fun sz -> `Rem (as_seq1 sz, false)), [`BV None], "URem" + | "srem" -> (fun sz -> `Rem (as_seq1 sz, true )), [`BV None], "SRem" + | "shl" -> (fun sz -> `Shl (as_seq1 sz )), [`BV None], "SHL" + | "rol" -> (fun sz -> `Rol (as_seq1 sz )), [`BV None], "ROL" + | "ror" -> (fun sz -> `Ror (as_seq1 sz )), [`BV None], "ROR" + | "shr" -> (fun sz -> `Shr (as_seq1 sz, false)), [`BV None], "SHR" + | "ashr" -> (fun sz -> `Shr (as_seq1 sz, true )), [`BV None], "ASHR" + | "and" -> (fun sz -> `And (as_seq1 sz )), [`BV None], "And" + | "or" -> (fun sz -> `Or (as_seq1 sz )), [`BV None], "Or" + | "xor" -> (fun sz -> `Xor (as_seq1 sz )), [`BV None], "Xor" + | "not" -> (fun sz -> `Not (as_seq1 sz )), [`BV None], "Not" + | "opp" -> (fun sz -> `Opp (as_seq1 sz )), [`BV None], "Opp" + + | "ult" -> (fun sz -> `Lt (snd (as_seq2 sz), false)), [`BV (Some 1); `BV None], "ULt" + | "slt" -> (fun sz -> `Lt (snd (as_seq2 sz), true )), [`BV (Some 1); `BV None], "SLt" + | "ule" -> (fun sz -> `Le (snd (as_seq2 sz), false)), [`BV (Some 1); `BV None], "ULe" + | "sle" -> (fun sz -> `Le (snd (as_seq2 sz), true )), [`BV (Some 1); `BV None], "SLe" + + | "init" -> (fun sz -> `Init (snd (as_seq2 sz))), [`BV (Some 1); `BV None], "Init" + | "get" -> (fun sz -> `Get (fst (as_seq2 sz))), [`BV None; `BV (Some 1)], "Get" + + | "ainit" -> (fun sz -> `AInit (as_seq2 (sz |> List.rev))), [`BV None; `A], "AInit" + + | "shls" -> + let mk sz = let sz1, sz2 = as_seq2 sz in `Shls (sz1, sz2) in + mk, [`BV None; `BV None], "SHLS" + + | "shrs" -> + let mk sz = let sz1, sz2 = as_seq2 sz in `Shrs (sz1, sz2, false) in + mk, [`BV None; `BV None], "SHRS" + + | "ashrs" -> + let mk sz = let sz1, sz2 = as_seq2 sz in `Shrs (sz1, sz2, true) in + mk, [`BV None; `BV None], "ASHRS" + + | "zextend" -> + let mk sz = let sz1, sz2 = as_seq2 sz in `Extend (sz1, sz2, false) in + mk, [`BV None; `BV None], "ZExtend" + + | "sextend" -> + let mk sz = let sz1, sz2 = as_seq2 sz in `Extend (sz1, sz2, true) in + mk, [`BV None; `BV None], "SExtend" + + | "truncate" -> + let mk sz = let sz1, sz2 = as_seq2 sz in `Truncate (sz1, sz2) in + mk, [`BV None; `BV None], "Truncate" + + | "insert" -> + let mk sz = let sz1, sz2 = as_seq2 sz in `Insert (sz1, sz2) in + mk, [`BV None; `BV None], "Insert" + + | "extract" -> + let mk sz = let sz1, sz2 = as_seq2 sz in `Extract (sz1, sz2) in + mk, [`BV None; `BV None], "Extract" + + | "asliceget" -> + let mk sz = let sz1, sz2, arr_sz = as_seq3 sz in `ASliceGet ((arr_sz, sz1), sz2) in + mk, [`BV None; `BV None; `A], "ASliceGet" + + | "asliceset" -> + let mk sz = let sz1, sz2, arr_sz = as_seq3 sz in `ASliceSet ((arr_sz, sz1), sz2) in + mk, [`BV None; `BV None; `A], "ASliceSet" + + | "concat" -> + let mk sz = let sz1, sz2, sz3 = as_seq3 sz in `Concat (sz1, sz2, sz3) in + mk, [`BV None; `BV None; `BV None], "Concat" + + | "a2b" -> + let mk sz = + let sz1, sz2, asz = as_seq3 sz in `A2B ((sz2, asz), sz1) in + mk, [`BV None; `BV None; `A], "A2B" + + | "b2a" -> + let mk sz = + let sz1, sz2, asz = as_seq3 sz in `B2A (sz1, (sz2, asz)) in + mk, [`BV None; `BV None; `A], "B2A" + + | "map" -> + let mk sz = + let sz1, sz2, asz = as_seq3 sz in `Map (sz1, sz2, asz) in + mk, [`BV None; `BV None; `A], "Map" + + | _ -> + hierror ~loc:(loc op.name) + "invalid bv operator name: %s" (unloc op.name) in + + if List.compare_lengths sig_ op.types <> 0 then + hierror ~loc:(loc op.operator) + "%d type(s) should be provided" (List.length sig_); + + let check_type (mode : [`BV of int option | `A]) (ty : pqsymbol) = + let path = + match EcEnv.Ty.lookup_opt (unloc ty) env, mode with + | None, _ -> + hierror ~loc:(loc ty) + "cannot find named type: `%s'" + (string_of_qsymbol (unloc ty)) + + | Some (path, decl), `BV _ -> (* FIXME: normalize? *) + if List.length decl.tyd_params <> 0 then + hierror ~loc:(loc ty) + "a bit-string type must be a monomorphic named type"; + path + + | Some (path, decl), `A -> + if List.length decl.tyd_params <> 1 then + hierror ~loc:(ty.pl_loc) + "an array type must be a 1-polymorphic named type"; + path + in + + let (size, theory) = + match mode with + | `BV osize -> begin + match EcEnv.Circuit.lookup_bitstring_path env path with + | None -> + hierror ~loc:(ty.pl_loc) + "this type is not bound to a bitstring type" + | Some {size = (_ , Some csize) as size; theory} -> + osize |> Option.iter (fun osize -> + if osize <> csize then + hierror ~loc:(ty.pl_loc) + "this type is not bound to a bitstring type of size %d (but of size %d)" + osize csize + ); + (size, theory) + | Some { size = (_, None) as size; theory} -> + osize |> Option.iter (fun osize -> + hierror ~loc:(ty.pl_loc) + "This type is not bound to a concrete bitstring of size %d (it is abstract)" + osize + ); + (size, theory) + end + | `A -> begin + match EcEnv.Circuit.lookup_array_path env path with + | None -> + hierror ~loc:(ty.pl_loc) + "this type is not bound to an array type" + | Some ba -> (ba.size, ba.theory) + end + in (path, size, (mode, theory)) + + in + + let types = List.map2 check_type sig_ op.types in + let subname = "BV" ^ subname in + + let operator, _ = EcEnv.Op.lookup op.operator.pl_desc env in + let name = + let suffix = List.map (EcPath.tolist |- proj3_1) types in + let suffix = List.flatten suffix in + String.concat "_" ("BVA" :: unloc op.name :: suffix) (* FIXME: not stable*) in + + let _, cltheories = + let string_of_mode = function `A -> "A" | `BV -> "BV" in + let strip_mode_arg = function `A -> `A | `BV _ -> `BV in + + let counts0 = + [`A; `BV] + |> List.to_seq + |> Seq.map (fun mode -> (mode, 0)) + |> BatMap.of_seq in + + let maxs = + List.fold_left (fun counts mode -> + let mode = strip_mode_arg mode in + BatMap.modify mode ((+) 1) counts + ) counts0 sig_ in + + List.fold_left_map (fun counts (_, _, (mode, theory)) -> + let mode = strip_mode_arg mode in + let prefix = string_of_mode mode in + + let counts, name = + if BatMap.find mode maxs < 2 then + (counts, prefix) + else + let counts = BatMap.modify mode ((+) 1) counts in + let name = Format.sprintf "%s%d" prefix (BatMap.find mode counts) in + (counts, name) + in (counts, (name, theory)) + ) counts0 types in + + let preclone = + { path = EcPath.fromqsymbol (["Top"; "QFABV"; "BVOperators"], subname) + ; name = name + ; local = local + ; theories = cltheories + ; types_ = [] + ; operators = ["bv" ^ unloc op.name, `Path operator] + ; proofs = [] } in + + let proofs, scope = doclone scope preclone in + + let item = CRB_BvOperator + { kind = kind (List.map proj3_2 types); + types = List.map proj3_1 types; + operator = operator; + theory = EcPath.pqname (EcEnv.root env) subname; } in + + let item = EcTheory.mkitem ~import:true (Th_crbinding (item, local)) in + + let scope = + { scope with sc_env = EcSection.add_item item scope.sc_env } in + + Ax.add_defer scope proofs + + let add_circuit1 (scope : scope) (local : is_local) ((op, circ) : (pqsymbol * string located)) : scope = + let env = env scope in + let operator, opdecl = EcEnv.Op.lookup op.pl_desc env in + + if not (List.is_empty opdecl.op_tparams) then + hierror ~loc:(loc op) "operator must be monomorphic"; + + let matches = List.filteri_map (fun i filename -> + EcEnv.Circuit.get_specification_by_name ~filename env (unloc circ)) scope.sc_specs + in + + match matches with + | [] -> + hierror ~loc:(loc circ) + "unknown circuit: %s" (unloc circ) + + | circuit::[] -> + let sig_ = List.map (fun (_, `W i) -> i) circuit.arguments in + let ret = Lospecs.Ast.get_size circuit.rettype in + let dom, codom = EcEnv.Ty.decompose_fun opdecl.op_ty env in + + if List.length dom <> List.length sig_ then + hierror ~loc:(loc op) + "the given operator must take %d arguments" + (List.length sig_); + + List.iteri (fun position (ty, size) -> + match EcEnv.Circuit.lookup_bitstring env ty with + | Some {size = (_, Some bs_size)} when bs_size = size -> () + | Some {size = (_, bs_size)} -> + let ppe = EcPrinting.PPEnv.ofenv env in + hierror ~loc:(loc op) + "%d-th argument (of type %a) must be a bitstring of size %d, not %s" + (position + 1) (EcPrinting.pp_type ppe) ty + size (Option.value (Option.map string_of_int bs_size) ~default:("abstract")) + | None -> + let ppe = EcPrinting.PPEnv.ofenv env in + hierror ~loc:(loc op) + "%d-th argument (of type %a) must be a bitstring" + (position + 1) (EcPrinting.pp_type ppe) ty + ) (List.combine dom sig_); + + begin + match EcEnv.Circuit.lookup_bitstring env codom with + | Some {size = (_, Some bs_size)} when bs_size = ret -> () + | Some {size = (_, bs_size)} -> + let ppe = EcPrinting.PPEnv.ofenv env in + hierror ~loc:(loc op) + "operator return type (%a) must be a bitstring of size %d, not %s" + (EcPrinting.pp_type ppe) codom ret + (Option.value (Option.map string_of_int bs_size) ~default:("abstract")) + | None -> + let ppe = EcPrinting.PPEnv.ofenv env in + hierror ~loc:(loc op) + "operator return type (%a) must be a bitstring of size %d" + (EcPrinting.pp_type ppe) codom ret + end; + + let item = CRB_Circuit { operator; circuit; name = unloc circ; } in + + let item = + EcTheory.mkitem ~import:true + (EcTheory.Th_crbinding (item, local)) in + { scope with sc_env = EcSection.add_item item scope.sc_env } + | circs -> Format.eprintf "Multiple matches found (%d) for circuit %s" (List.length circs) (unloc circ); assert false + (* FIXME *) + + (* FIXME: Decide if we want set or append here *) + let register_spec_files (scope : scope) (files : string list) : scope = + { scope with sc_specs = files } + + let add_circuits (scope : scope) (local : is_local) (binds : pbind_circuit) : scope = + List.fold_left (fun scope bnd -> + add_circuit1 scope local bnd) + scope binds.bindings +end + +(* -------------------------------------------------------------------- *) +module Search = struct let search (scope : scope) qs = let env = env scope in let paths = diff --git a/src/ecScope.mli b/src/ecScope.mli index d64007674c..9f916b17c9 100644 --- a/src/ecScope.mli +++ b/src/ecScope.mli @@ -266,6 +266,16 @@ module Reduction : sig val add_reduction : scope -> puserred -> scope end +(* -------------------------------------------------------------------- *) +module Circuit : sig + val add_bitstring : scope -> EcTypes.is_local -> pbind_bitstring -> scope + val add_array : scope -> EcTypes.is_local -> pbind_array -> scope + val add_bvoperator : scope -> EcTypes.is_local -> pbind_bvoperator -> scope + val add_circuits : scope -> EcTypes.is_local -> pbind_circuit -> scope + + val register_spec_files : scope -> string list -> scope +end + (* -------------------------------------------------------------------- *) module Cloning : sig val clone : scope -> Ax.proofmode -> theory_cloning -> scope diff --git a/src/ecSection.ml b/src/ecSection.ml index 3534ea58e2..c7410f87df 100644 --- a/src/ecSection.ml +++ b/src/ecSection.ml @@ -22,6 +22,7 @@ type cbarg = [ | `ModuleType of path | `Typeclass of path | `Instance of tcinstance + | `Crbind of crbinding * is_local ] type cb = cbarg -> unit @@ -50,11 +51,20 @@ let pp_cbarg env fmt (who : cbarg) = Format.fprintf fmt "module type %a" (EcPrinting.pp_modtype1 ppe) mty | `Typeclass p -> Format.fprintf fmt "typeclass %a" (EcPrinting.pp_tcname ppe) p - | `Instance tci -> + | `Instance tci -> begin match tci with | `Ring _ -> Format.fprintf fmt "ring instance" | `Field _ -> Format.fprintf fmt "field instance" | `General _ -> Format.fprintf fmt "instance" + end + | `Crbind (CRB_Bitstring _, _) -> + Format.fprintf fmt "bitstring binding" + | `Crbind (CRB_Array _, _) -> + Format.fprintf fmt "array binding" + | `Crbind (CRB_BvOperator _, _) -> + Format.fprintf fmt "bitstring operator binding" + | `Crbind (CRB_Circuit _, _) -> + Format.fprintf fmt "circuit binding" let pp_locality fmt = function | `Local -> Format.fprintf fmt "local" @@ -515,6 +525,7 @@ let locality (env : EcEnv.env) (who : cbarg) = | _ -> `Global end | `ModuleType p -> ((EcEnv.ModTy.by_path p env).tms_loca :> locality) + | `Crbind (_, lc) -> (lc :> locality) | `Instance _ -> assert false (* -------------------------------------------------------------------- *) @@ -786,7 +797,8 @@ let generalize_tydecl to_gen prefix (name, tydecl) = let to_gen = { to_gen with tg_subst} in let tydecl = { tyd_params; tyd_type; - tyd_loca = `Global; } in + tyd_loca = `Global; + tyd_clinline = tydecl.tyd_clinline; } in to_gen, Some (Th_type (name, tydecl)) | `Declare -> @@ -1027,6 +1039,13 @@ let generalize_auto to_gen auto_rl = else to_gen, Some (Th_auto {auto_rl with axioms}) +let generalize_crbinding (to_gen : to_gen) ((bd, lc) : crbinding * is_local) = + (* FIXME: not complete? *) + let bd = EcSubst.subst_crbinding to_gen.tg_subst bd in + let item = + if lc = `Local then None else Some (Th_crbinding (bd, lc)) + in to_gen, item + (* --------------------------------------------------------------- *) let get_locality scenv = scenv.sc_loca @@ -1053,6 +1072,7 @@ let rec set_lc_item lc_override item = | Th_baserw (s,lc) -> Th_baserw (s, set_lc lc_override lc) | Th_addrw (p,ps,lc) -> Th_addrw (p, ps, set_lc lc_override lc) | Th_reduction r -> Th_reduction r + | Th_crbinding (bd, lc) -> Th_crbinding (bd, set_lc lc_override lc) | Th_auto auto_rl -> Th_auto {auto_rl with locality=set_lc lc_override auto_rl.locality} | Th_alias alias -> Th_alias alias @@ -1065,7 +1085,6 @@ and set_local_th lc_override th = let sc_decl_mod (id,mt) = SC_decl_mod (id,mt) (* ---------------------------------------------------------------- *) - let is_abstract_ty = function | `Abstract _ -> true | _ -> false @@ -1118,16 +1137,16 @@ let cd_glob = d_tc = [`Global]; } -let can_depend (cd : can_depend) = function +let can_depend (cd : can_depend) (who : cbarg) = + match who with | `Type _ -> cd.d_ty | `Op _ -> cd.d_op | `Ax _ -> cd.d_ax - | `Sc _ -> cd.d_sc | `Module _ -> cd.d_mod | `ModuleType _ -> cd.d_modty | `Typeclass _ -> cd.d_tc | `Instance _ -> assert false - + | `Crbind _ -> assert false (* FIXME *) let cb scenv from cd who = let env = scenv.sc_env in @@ -1310,6 +1329,47 @@ let check_instance scenv ty tci lc = let cd = { cd_glob with d_ty = [`Declare; `Global]; } in on_instance (cb scenv from cd) ty tci +let check_crb_bitstring (scenv : scenv) ((bs, lc) : crb_bitstring * is_local) = + let from = (lc :> locality), `Crbind (CRB_Bitstring bs, lc) in + if lc = `Local then + check_section scenv from + else if scenv.sc_insec then begin + List.iter (fun p -> cb scenv from cd_glob (`Op p)) [bs.from_; bs.to_]; + cb scenv from cd_glob (`Type bs.type_) + end + +let check_crb_array (scenv : scenv) ((ba, lc) : crb_array * is_local) = + let from = (lc :> locality), `Crbind (CRB_Array ba, lc) in + if lc = `Local then + check_section scenv from + else if scenv.sc_insec then begin + List.iter (fun p -> cb scenv from cd_glob (`Op p)) [ba.get; ba.set; ba.tolist; ba.oflist]; + cb scenv from cd_glob (`Type ba.type_) + end + +let check_crb_bvoperator (scenv : scenv) ((op, lc) : crb_bvoperator * is_local) = + let from = (lc :> locality), `Crbind (CRB_BvOperator op, lc) in + if lc = `Local then + check_section scenv from + else if scenv.sc_insec then begin + cb scenv from cd_glob (`Op op.operator); + List.iter (fun ty -> cb scenv from cd_glob (`Type ty)) op.types + end + +let check_crb_circuit (scenv : scenv) ((cr, lc) : crb_circuit * is_local) = + let from = (lc :> locality), `Crbind (CRB_Circuit cr, lc) in + if lc = `Local then + check_section scenv from + else if scenv.sc_insec then + cb scenv from cd_glob (`Op cr.operator) + +let check_crbinding (scenv : scenv) ((crb, lc) : crbinding * is_local) = + match crb with + | CRB_Bitstring bs -> check_crb_bitstring scenv (bs, lc) + | CRB_Array ba -> check_crb_array scenv (ba, lc) + | CRB_BvOperator op -> check_crb_bvoperator scenv (op, lc) + | CRB_Circuit cr -> check_crb_circuit scenv (cr, lc) + (* -----------------------------------------------------------*) let enter_theory (name:symbol) (lc:is_local) (mode:thmode) scenv : scenv = if not scenv.sc_insec && lc = `Local then @@ -1348,6 +1408,8 @@ let add_item_ ?(override_locality=None) (item : theory_item) (scenv:scenv) = | Th_module me -> EcEnv.Mod.bind ~import me.tme_expr.me_name me env | Th_typeclass(s,tc) -> EcEnv.TypeClass.bind ~import s tc env | Th_export (p, lc) -> EcEnv.Theory.export p lc env + | Th_crbinding (bd, lc) -> EcEnv.Circuit.bind_crbinding lc bd env + | Th_theory _ -> assert false | Th_instance (tys,i,lc) -> EcEnv.TypeClass.add_instance ~import tys i lc env (*FIXME: import? *) | Th_baserw (s,lc) -> EcEnv.BaseRw.add ~import s lc env | Th_addrw (p,ps,lc) -> EcEnv.BaseRw.addto ~import p ps lc env @@ -1356,7 +1418,6 @@ let add_item_ ?(override_locality=None) (item : theory_item) (scenv:scenv) = auto.axioms auto.locality env | Th_alias (n,p) -> EcEnv.Theory.alias ~import n p env | Th_reduction r -> EcEnv.Reduction.add ~import r env - | _ -> assert false in (item, { scenv with sc_env = env; @@ -1370,6 +1431,7 @@ let add_th ~import (cth : EcEnv.Theory.compiled_theory) scenv = let rec generalize_th_item (to_gen : to_gen) (prefix : path) (th_item : theory_item) = let to_gen, item = match th_item.ti_item with + | Th_crbinding (bd, lc) -> generalize_crbinding to_gen (bd, lc) | Th_type tydecl -> generalize_tydecl to_gen prefix tydecl | Th_operator opdecl -> generalize_opdecl to_gen prefix opdecl | Th_axiom ax -> generalize_axiom to_gen prefix ax @@ -1384,7 +1446,6 @@ let rec generalize_th_item (to_gen : to_gen) (prefix : path) (th_item : theory_i | Th_reduction rl -> generalize_reduction to_gen rl | Th_auto hints -> generalize_auto to_gen hints | Th_alias _ -> (to_gen, None) (* FIXME:ALIAS *) - in let scenv = @@ -1491,7 +1552,8 @@ let check_item scenv item = | Th_auto { locality } -> if (locality = `Local && not scenv.sc_insec) then hierror "local hint can only be declared inside section"; - | Th_reduction _ -> () + | Th_reduction _ -> () (* FIXME *) + | Th_crbinding (crb, lc) -> check_crbinding scenv (crb, lc) | Th_theory _ -> assert false | Th_alias _ -> () (* FIXME:ALIAS *) diff --git a/src/ecSubst.ml b/src/ecSubst.ml index 45bdb2f747..6b22d34938 100644 --- a/src/ecSubst.ml +++ b/src/ecSubst.ml @@ -862,9 +862,10 @@ let subst_tydecl (s : subst) (tyd : tydecl) = let s, tparams = fresh_tparams s tyd.tyd_params in let body = subst_tydecl_body s tyd.tyd_type in - { tyd_params = tparams; - tyd_type = body; - tyd_loca = tyd.tyd_loca; } + { tyd_params = tparams; + tyd_type = body; + tyd_loca = tyd.tyd_loca; + tyd_clinline = tyd.tyd_clinline; } (* -------------------------------------------------------------------- *) let rec subst_op_kind (s : subst) (kind : operator_kind) = @@ -1014,6 +1015,101 @@ let subst_tc (s : subst) tc = let tc_axs = List.map (snd_map (subst_form s)) tc.tc_axs in { tc_prt; tc_ops; tc_axs; tc_loca = tc.tc_loca } +let subst_binding_size ?(red: (form -> int option) option) (s: subst) (bsize: binding_size) = + (* FIXME: add reduction? *) + let fsize = subst_form s (fst bsize) in + let csize = match red with + | Some red when Option.is_none (snd bsize) -> red fsize + | _ -> (snd bsize) + in (fsize, csize) + +let subst_bv_opkind ?(red: (form -> int option) option) (s: subst) (opk: bv_opkind) = + let ssize = subst_binding_size ?red s in + match opk with + | `Extend (s1, s2, sgn) -> `Extend (ssize s1, ssize s2, sgn) + | `Rem (s, sgn) -> `Rem (ssize s, sgn) + | `Div (s, sgn) -> `Div (ssize s, sgn) + | `Add (s) -> `Add (ssize s) + | `Lt (s, sgn) -> `Lt (ssize s, sgn) + | `Shl (s) -> `Shl (ssize s) + | `Shls (s1, s2) -> `Shls (ssize s1, ssize s2) + | `ASliceSet ((s1, s2), s3) -> `ASliceSet ((ssize s1, ssize s2), ssize s3) + | `And s -> `And (ssize s) + | `Extract (s1, s2) -> `Extract (ssize s1, ssize s2) + | `Map (s1, s2, s3) -> `Map (ssize s1, ssize s2, ssize s3) + | `AInit (s1, s2) -> `AInit (ssize s1, ssize s2) + | `Sub s -> `Sub (ssize s) + | `Get s -> `Get (ssize s) + | `Ror s -> `Ror (ssize s) + | `Le (s, sgn) -> `Le (ssize s, sgn) + | `Concat (s1, s2, s3) -> `Concat (ssize s1, ssize s2, ssize s3) + | `Truncate (s1, s2) -> `Truncate (ssize s1, ssize s2) + | `Not (s) -> `Not (ssize s) + | `Opp (s) -> `Opp (ssize s) + | `Or (s) -> `Or (ssize s) + | `Init (s) -> `Init (ssize s) + | `Insert (s1, s2) -> `Insert (ssize s1, ssize s2) + | `Xor (s) -> `Xor (ssize s) + | `Shr (s, sgn) -> `Shr (ssize s, sgn) + | `Shrs (s1, s2, sgn) -> `Shrs (ssize s1, ssize s2, sgn) + | `Mul (s) -> `Mul (ssize s) + | `Rol (s) -> `Rol (ssize s) + | `A2B ((s1, s2), s3) -> `A2B ((ssize s1, ssize s2), ssize s3) + | `ASliceGet ((s1, s2), s3) -> `ASliceGet ((ssize s1, ssize s2), ssize s3) + | `B2A (s1, (s2, s3)) -> `B2A (ssize s1, (ssize s2, ssize s3)) + +(* -------------------------------------------------------------------- *) +let subst_crbinding ?(red: (form -> int option) option) (s : subst) (crb : crbinding) = + match crb with + | CRB_Bitstring bs -> + assert (not (Mp.mem bs.type_ s.sb_tydef)); + assert (not (Mp.mem bs.from_ s.sb_def)); + assert (not (Mp.mem bs.to_ s.sb_def)); + assert (not (Mp.mem bs.touint s.sb_def)); + assert (not (Mp.mem bs.tosint s.sb_def)); + assert (not (Mp.mem bs.ofint s.sb_def)); + (* FIXME : maybe add an assert here? *) + CRB_Bitstring { + type_ = subst_path s bs.type_; + from_ = subst_path s bs.from_; + to_ = subst_path s bs.to_; + touint = subst_path s bs.touint; + tosint = subst_path s bs.tosint; + ofint = subst_path s bs.ofint; + size = subst_binding_size ?red s bs.size; + theory = subst_path s bs.theory; } + + | CRB_Array ba -> + assert (not (Mp.mem ba.type_ s.sb_tydef)); + assert (not (Mp.mem ba.get s.sb_def)); + assert (not (Mp.mem ba.set s.sb_def)); + assert (not (Mp.mem ba.tolist s.sb_def)); + assert (not (Mp.mem ba.oflist s.sb_def)); + CRB_Array { + type_ = subst_path s ba.type_; + get = subst_path s ba.get; + set = subst_path s ba.set; + tolist = subst_path s ba.tolist; + oflist = subst_path s ba.oflist; + size = subst_binding_size ?red s ba.size; + theory = subst_path s ba.theory } + + | CRB_BvOperator op -> + assert (List.for_all (fun ty -> not (Mp.mem ty s.sb_tydef)) op.types); + assert (not (Mp.mem op.operator s.sb_def)); + CRB_BvOperator { + kind = subst_bv_opkind ?red s op.kind; + types = List.map (subst_path s) op.types; + operator = subst_path s op.operator; + theory = subst_path s op.theory; } + + | CRB_Circuit cr -> + assert (not (Mp.mem cr.operator s.sb_def)); + CRB_Circuit { + name = cr.name; + circuit = cr.circuit; + operator = subst_path s cr.operator; } + (* -------------------------------------------------------------------- *) (* SUBSTITUTION OVER THEORIES *) let rec subst_theory_item_r (s : subst) (item : theory_item_r) = @@ -1060,6 +1156,9 @@ let rec subst_theory_item_r (s : subst) (item : theory_item_r) = Th_auto { auto_rl with axioms = List.map (fst_map (subst_path s)) axioms } + | Th_crbinding (bd, lc) -> + Th_crbinding (subst_crbinding s bd, lc) + | Th_alias (name, target) -> Th_alias (name, subst_path s target) @@ -1181,4 +1280,4 @@ let ss_inv_forall_ml_ts_inv menvl inv = let ss_inv_forall_mr_ts_inv menvr inv = let inv' = f_forall_mems [menvr] (ts_inv_rebind_right inv (fst menvr)).inv in - { inv=inv'; m=inv.ml } \ No newline at end of file + { inv=inv'; m=inv.ml } diff --git a/src/ecSubst.mli b/src/ecSubst.mli index 5ad3879db6..85c2e6a6ff 100644 --- a/src/ecSubst.mli +++ b/src/ecSubst.mli @@ -79,6 +79,11 @@ val subst_ss_inv : subst -> ss_inv -> ss_inv val subst_ts_inv : subst -> ts_inv -> ts_inv val subst_inv : subst -> inv -> inv +(* -------------------------------------------------------------------- *) +val subst_crbinding : ?red:(form -> int option) -> subst -> crbinding -> crbinding +val subst_bv_opkind : ?red:(form -> int option) -> subst -> bv_opkind -> bv_opkind +val subst_binding_size : ?red:(form -> int option) -> subst -> binding_size -> binding_size + (* -------------------------------------------------------------------- *) val open_oper : operator -> ty list -> ty * operator_kind val open_tydecl : tydecl -> ty list -> ty_body diff --git a/src/ecThCloning.ml b/src/ecThCloning.ml index a2f24e593a..57898d338b 100644 --- a/src/ecThCloning.ml +++ b/src/ecThCloning.ml @@ -466,6 +466,7 @@ end = struct | Th_addrw _ -> (proofs, evc) | Th_reduction _ -> (proofs, evc) | Th_auto _ -> (proofs, evc) + | Th_crbinding _ -> (proofs, evc) | Th_alias _ -> (proofs, evc) and doit prefix (proofs, evc) dth = diff --git a/src/ecTheory.ml b/src/ecTheory.ml index ffe226d060..4aac273460 100644 --- a/src/ecTheory.ml +++ b/src/ecTheory.ml @@ -31,6 +31,7 @@ and theory_item_r = | Th_baserw of symbol * is_local | Th_addrw of EcPath.path * EcPath.path list * is_local | Th_reduction of (EcPath.path * rule_option * rule option) list + | Th_crbinding of crbinding * is_local | Th_auto of auto_rule | Th_alias of (symbol * path) (* FIXME: currently, only theories *) diff --git a/src/ecTheory.mli b/src/ecTheory.mli index f246ee3f40..d16cb910ca 100644 --- a/src/ecTheory.mli +++ b/src/ecTheory.mli @@ -28,6 +28,7 @@ and theory_item_r = | Th_addrw of EcPath.path * EcPath.path list * is_local (* reduction rule does not survive to section so no locality *) | Th_reduction of (EcPath.path * rule_option * rule option) list + | Th_crbinding of crbinding * is_local | Th_auto of auto_rule | Th_alias of (symbol * path) diff --git a/src/ecTheoryReplay.ml b/src/ecTheoryReplay.ml index 0eff5cf66d..222bc5f64e 100644 --- a/src/ecTheoryReplay.ml +++ b/src/ecTheoryReplay.ml @@ -407,6 +407,41 @@ let rename ?(fold = true) ove subst (kind, name) = with Not_found -> (subst, name) +(* -------------------------------------------------------------------- *) +exception InvInstPath + +(* -------------------------------------------------------------------- *) +let forpath ~(opath : EcPath.path) ~(npath : EcPath.path) ~(ops : _ Mp.t) (p : EcPath.path) = + match EcPath.remprefix ~prefix:opath ~path:p |> omap List.rev with + | None | Some [] -> None + | Some (x::px) -> + let q = EcPath.fromqsymbol (List.rev px, x) in + + match Mp.find_opt q ops with + | None -> + Some (EcPath.pappend npath q) + | Some (op, alias) -> + match alias with + | true -> Some (EcPath.pappend npath q) + | false -> + match op.EcDecl.op_kind with + | OB_pred _ + | OB_nott _ -> assert false + | OB_oper None -> None + | OB_oper (Some (OP_Constr _)) + | OB_oper (Some (OP_Record _)) + | OB_oper (Some (OP_Proj _)) + | OB_oper (Some (OP_Fix _)) + | OB_oper (Some (OP_TC )) -> + Some (EcPath.pappend npath q) + | OB_oper (Some (OP_Plain f)) -> + match f.f_node with + | Fop (r, _) -> Some r + | _ -> raise InvInstPath + +let forpath ~opath ~npath ~ops p = + odfl p (forpath ~opath ~npath ~ops p) + (* -------------------------------------------------------------------- *) let rec replay_tyd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, otyd) = let scenv = ove.ovre_hooks.henv scope in @@ -428,18 +463,29 @@ let rec replay_tyd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, otyd let ue = EcUnify.UniEnv.create (Some nargs) in let ntyd = EcTyping.transty EcTyping.tp_tydecl env ue ntyd in let decl = - { tyd_params = nargs; - tyd_type = `Concrete ntyd; - tyd_loca = otyd.tyd_loca; } + { tyd_params = nargs; + tyd_type = `Concrete ntyd; + tyd_loca = otyd.tyd_loca; + tyd_clinline = (mode <> `Alias); } in (decl, ntyd) | `ByPath p -> begin match EcEnv.Ty.by_path_opt p env with | Some reftyd -> - let tyargs = List.map (fun (x, _) -> EcTypes.tvar x) reftyd.tyd_params in - let body = tconstr p tyargs in - let decl = { reftyd with tyd_type = `Concrete body; } in + let body = + if reftyd.tyd_clinline then + (match reftyd.tyd_type with + | `Concrete body -> body + | _ -> assert false) + else + let tyargs = + List.map (fun (x, _) -> EcTypes.tvar x) reftyd.tyd_params in + tconstr p tyargs in + let decl = + { reftyd with + tyd_type = `Concrete body; + tyd_clinline = (mode <> `Alias); } in (decl, body) | _ -> assert false @@ -448,10 +494,11 @@ let rec replay_tyd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, otyd | `Direct ty -> begin assert (List.is_empty otyd.tyd_params); let decl = - { tyd_params = []; - tyd_type = `Concrete ty; - tyd_loca = otyd.tyd_loca; } - + { tyd_params = []; + tyd_type = `Concrete ty; + tyd_loca = otyd.tyd_loca; + tyd_clinline = false; (* FIXME: check value here tyd_clinline PR *) + } in (decl, ty) end in @@ -998,39 +1045,7 @@ and replay_instance = let opath = ove.ovre_opath in let npath = ove.ovre_npath in - - let module E = struct exception InvInstPath end in - - let forpath (p : EcPath.path) = - match EcPath.remprefix ~prefix:opath ~path:p |> omap List.rev with - | None | Some [] -> None - | Some (x::px) -> - let q = EcPath.fromqsymbol (List.rev px, x) in - - match Mp.find_opt q ops with - | None -> - Some (EcPath.pappend npath q) - | Some (op, alias) -> - match alias with - | true -> Some (EcPath.pappend npath q) - | false -> - match op.EcDecl.op_kind with - | OB_pred _ - | OB_nott _ -> assert false - | OB_oper None -> None - | OB_oper (Some (OP_Constr _)) - | OB_oper (Some (OP_Record _)) - | OB_oper (Some (OP_Proj _)) - | OB_oper (Some (OP_Fix _)) - | OB_oper (Some (OP_TC )) -> - Some (EcPath.pappend npath q) - | OB_oper (Some (OP_Plain f)) -> - match f.f_node with - | Fop (r, _) -> Some r - | _ -> raise E.InvInstPath - in - - let forpath p = odfl p (forpath p) in + let forpath = forpath ~npath ~opath ~ops in try let (typ, ty) = EcSubst.subst_genty subst (typ, ty) in @@ -1066,9 +1081,144 @@ and replay_instance let scope = ove.ovre_hooks.hadd_item scope ~import (Th_instance ((typ, ty), tc, lc)) in (subst, ops, proofs, scope) - with E.InvInstPath -> + with InvInstPath -> + (subst, ops, proofs, scope) + +(* -------------------------------------------------------------------- *) +and replay_crb_bitstring (ove : _ ovrenv) (subst, ops, proofs, scope) (import, bs, lc) = + let opath = ove.ovre_opath in + let npath = ove.ovre_npath in + let forpath = forpath ~npath ~opath ~ops in + + let env = EcSection.env (ove.ovre_hooks.henv scope) in + let hyps = EcEnv.LDecl.init env [] in + let red f = try + Some (EcCallbyValue.norm_cbv EcReduction.full_red hyps f |> EcCoreFol.destr_int |> BI.to_int) + with + | EcCoreFol.DestrError "destr_int" -> None + | EcEnv.NotReducible -> None + in + + try + let to_ = forpath bs.to_ in + let from_ = forpath bs.from_ in + let touint = forpath bs.touint in + let tosint = forpath bs.tosint in + let ofint = forpath bs.ofint in + let type_ = match (EcSubst.subst_ty subst (tconstr bs.type_ [])).ty_node with + | Tconstr (p, []) -> p + | _ -> forpath bs.type_ (* FIXME: fallback *) + in + let theory = EcSubst.subst_path subst bs.theory in (* FIXME *) + let size = EcSubst.subst_binding_size ~red subst bs.size in + + let bs = CRB_Bitstring { to_; from_; touint; tosint; ofint; type_; theory; size; } in + let scope = ove.ovre_hooks.hadd_item scope ~import (Th_crbinding (bs, lc)) in + + (subst, ops, proofs, scope) + + with InvInstPath -> + (subst, ops, proofs, scope) + +(* -------------------------------------------------------------------- *) +and replay_crb_array (ove : _ ovrenv) (subst, ops, proofs, scope) (import, ba, lc) = + let opath = ove.ovre_opath in + let npath = ove.ovre_npath in + let forpath = forpath ~npath ~opath ~ops in + + let env = EcSection.env (ove.ovre_hooks.henv scope) in + let hyps = EcEnv.LDecl.init env [] in + let red f = try + Some (EcCallbyValue.norm_cbv EcReduction.full_red hyps f |> EcCoreFol.destr_int |> BI.to_int) + with + | EcCoreFol.DestrError "destr_int" -> None + | EcEnv.NotReducible -> None + in + + try + let get = forpath ba.get in + let set = forpath ba.set in + let tolist = forpath ba.tolist in + let oflist = forpath ba.oflist in + let type_ = match (EcSubst.subst_ty subst (tconstr ba.type_ [tint])).ty_node with (* FIXME: hack *) + | Tconstr (p, x::[]) -> p + | _ -> assert false; forpath ba.type_ + in + let size = EcSubst.subst_binding_size ~red subst ba.size in + let theory = EcSubst.subst_path subst ba.theory in (* FIXME *) + + let ba = CRB_Array { get; set; tolist; oflist; type_; size; theory; } in + let scope = ove.ovre_hooks.hadd_item scope ~import (Th_crbinding (ba, lc)) in + + (subst, ops, proofs, scope) + + + with InvInstPath -> + (subst, ops, proofs, scope) + +(* -------------------------------------------------------------------- *) +and replay_crb_bvoperator (ove : _ ovrenv) (subst, ops, proofs, scope) (import, op, lc) = + let opath = ove.ovre_opath in + let npath = ove.ovre_npath in + let forpath = forpath ~npath ~opath ~ops in + + let env = EcSection.env (ove.ovre_hooks.henv scope) in + let hyps = EcEnv.LDecl.init env [] in + let red f = try + Some (EcCallbyValue.norm_cbv EcReduction.full_red hyps f |> EcCoreFol.destr_int |> BI.to_int) + with + | EcCoreFol.DestrError "destr_int" -> None + | EcEnv.NotReducible -> None + in + + try + let kind = EcSubst.subst_bv_opkind ~red subst op.kind in + let operator = forpath op.operator in + let types = List.map forpath op.types in (* FIXME *) + let theory = forpath op.theory in (* FIXME *) + + let op = CRB_BvOperator { kind; operator; types; theory; } in + let scope = ove.ovre_hooks.hadd_item scope ~import (Th_crbinding (op, lc)) in + + (subst, ops, proofs, scope) + + with InvInstPath -> (subst, ops, proofs, scope) +(* -------------------------------------------------------------------- *) +and replay_crb_circuit (ove : _ ovrenv) (subst, ops, proofs, scope) (import, cr, lc) = + let opath = ove.ovre_opath in + let npath = ove.ovre_npath in + let forpath = forpath ~npath ~opath ~ops in + + try + let name = cr.name in + let circuit = cr.circuit in + let operator = forpath cr.operator in + + let cr = CRB_Circuit { name; circuit; operator; } in + let scope = ove.ovre_hooks.hadd_item scope ~import (Th_crbinding (cr, lc)) in + + (subst, ops, proofs, scope) + + with InvInstPath -> + (subst, ops, proofs, scope) + +(* -------------------------------------------------------------------- *) +and replay_crbinding (ove : _ ovrenv) (subst, ops, proofs, scope) (import, binding, lc) = + match binding with + | CRB_Bitstring bs -> + replay_crb_bitstring ove (subst, ops, proofs, scope) (import, bs, lc) + + | CRB_Array ba -> + replay_crb_array ove (subst, ops, proofs, scope) (import, ba, lc) + + | CRB_BvOperator op -> + replay_crb_bvoperator ove (subst, ops, proofs, scope) (import, op, lc) + + | CRB_Circuit cr -> + replay_crb_circuit ove (subst, ops, proofs, scope) (import, cr, lc) + (* -------------------------------------------------------------------- *) and replay_alias (ove : _ ovrenv) (subst, ops, proofs, scope) (import, name, target) @@ -1133,6 +1283,9 @@ and replay1 (ove : _ ovrenv) (subst, ops, proofs, scope) (hidden, item) = | Th_alias (name, target) -> replay_alias ove (subst, ops, proofs, scope) (item.ti_import, name, target) + | Th_crbinding (binding, lc) -> + replay_crbinding ove (subst, ops, proofs, scope) (item.ti_import, binding, lc) + | Th_theory (ox, cth) -> begin let thmode = cth.cth_mode in let (subst, x) = rename ove subst (`Theory, ox) in diff --git a/src/ecTypes.ml b/src/ecTypes.ml index 87efc57bee..11a6912b7d 100644 --- a/src/ecTypes.ml +++ b/src/ecTypes.ml @@ -65,13 +65,14 @@ let tfun t1 t2 = mk_ty (Tfun (t1, t2)) let tglob m = mk_ty (Tglob m) (* -------------------------------------------------------------------- *) -let tunit = tconstr EcCoreLib.CI_Unit .p_unit [] -let tbool = tconstr EcCoreLib.CI_Bool .p_bool [] -let tint = tconstr EcCoreLib.CI_Int .p_int [] -let txint = tconstr EcCoreLib.CI_xint .p_xint [] +let tunit = tconstr EcCoreLib.CI_Unit.p_unit [] +let tbool = tconstr EcCoreLib.CI_Bool.p_bool [] +let tint = tconstr EcCoreLib.CI_Int.p_int [] +let txint = tconstr EcCoreLib.CI_xint.p_xint [] let tdistr ty = tconstr EcCoreLib.CI_Distr.p_distr [ty] let toption ty = tconstr EcCoreLib.CI_Option.p_option [ty] +let tlist ty = tconstr EcCoreLib.CI_List.p_list [ty] let treal = tconstr EcCoreLib.CI_Real .p_real [] let tcpred ty = tfun ty tbool @@ -87,6 +88,18 @@ let ttuple lt = let toarrow dom ty = List.fold_right tfun dom ty +exception TyDestrError of string + +let tfrom_tlist ty = + match ty.ty_node with + | Tconstr (p, [ty]) when EcPath.p_equal p EcCoreLib.CI_List.p_list -> ty + | _ -> raise (TyDestrError "list") + +let tfrom_tfun2 ty = + match ty.ty_node with + | Tfun (a, b) -> (a, b) + | _ -> raise (TyDestrError "fun") + let tpred t = tfun t tbool (* -------------------------------------------------------------------- *) diff --git a/src/ecTypes.mli b/src/ecTypes.mli index 34b7b4cbf2..62fc9d4107 100644 --- a/src/ecTypes.mli +++ b/src/ecTypes.mli @@ -45,12 +45,18 @@ val txint : ty val treal : ty val tdistr : ty -> ty val toption : ty -> ty +val tlist : ty -> ty val tcpred : ty -> ty val toarrow : ty list -> ty -> ty val trealp : ty val txreal : ty +exception TyDestrError of string + +val tfrom_tlist : ty -> ty +val tfrom_tfun2 : ty -> ty * ty + val tytuple_flat : ty -> ty list val tyfun_flat : ty -> (dom * ty) diff --git a/src/ecTypesafeFol.ml b/src/ecTypesafeFol.ml new file mode 100644 index 0000000000..23b0df017b --- /dev/null +++ b/src/ecTypesafeFol.ml @@ -0,0 +1,152 @@ +open EcUtils +open EcAst +open EcTypes +open EcCoreFol +open EcUnify +open EcSubst +open EcEnv + +module Map = Batteries.Map + +module BI = EcBigInt +module Mp = EcPath.Mp +module Sp = EcPath.Sp +module Sm = EcPath.Sm +module Sx = EcPath.Sx +module UE = EcUnify.UniEnv + +type form = EcAst.form +type f_node = EcAst.f_node +type ty = EcTypes.ty + +let (%) f g = fun x -> f (g x) + +exception InsufficientArguments + +let tfrom_tlist ty = + let p_list = EcCoreLib.CI_List.p_list in + match ty.ty_node with + | Tconstr (p, [ty]) when p = p_list -> ty + | _ -> assert false + +let tfrom_tfun2 ty = + match ty.ty_node with + | Tfun (a, b) -> (a, b) + | _ -> assert false + +let unroll_ftype (ty:ty) : ty list * ty = + let rec doit (tys: ty list) (ty: ty) : ty list * ty = + match ty.ty_node with + | Tfun _ -> let t1, t2 = tfrom_tfun2 ty in doit (t1::tys) t2 + | _ -> (List.rev tys, ty) + in + + doit [] ty + +let ty_var_from_ty (ty:ty) : ty list = + match ty.ty_node with + | Tconstr (_, args) -> args + | _ -> assert false (* FIXME: how to handle this case ? *) + +(* Returned list is (tyvar, ty) *) +let rec match_ty_tyargs (ty: ty) (tyargs: ty) : (ty * ty) list = + match (ty.ty_node, tyargs.ty_node) with + | (Tconstr (p1, args1), Tconstr (p2, args2)) when p1 = p2 && (List.compare_lengths args1 args2 = 0) -> + List.flatten @@ List.map2 match_ty_tyargs args1 args2 + | (Ttuple args1, Ttuple args2) when (List.compare_lengths args1 args2 = 0) -> + List.flatten @@ List.map2 match_ty_tyargs args1 args2 + | (Tfun (ty11, ty12), Tfun (ty21, ty22)) -> + (match_ty_tyargs ty11 ty21) @ (match_ty_tyargs ty12 ty22) + | (_, Tvar _) -> [(ty, tyargs)] + | (_, Tunivar _) -> [(ty, tyargs)] + | _ -> assert false + +let rec sub_ty_tyargs (vals: (ty, ty) Map.t) (ty: ty) : ty = + match ty.ty_node with + | (Tconstr (p1, args1)) -> tconstr p1 (List.map (sub_ty_tyargs vals) args1) + | (Ttuple args1) -> ttuple (List.map (sub_ty_tyargs vals) args1) + | (Tfun (ty_arg, ty_ret)) -> tfun (sub_ty_tyargs vals ty_arg) (sub_ty_tyargs vals ty_ret) + | (Tvar _) -> Map.find ty vals + | (Tunivar _) -> Map.find ty vals + | (Tglob _) -> assert false + +let open_oper_ue op ue = + (* Maybe list map works fine because ue is imperative? *) + let open EcDecl in + let ue, tys = List.fold_left_map (fun ue _ -> (ue, EcUnify.UniEnv.fresh ue)) ue op.op_tparams in + (tys, open_oper op tys) + +let fop_from_path (env: env) (f: EcPath.path) : form = + let ue = UE.create None in + let p_f, o_f = EcEnv.Op.lookup (EcPath.toqsymbol f) env in + let tvars,(newt, _f_kind) = open_oper_ue o_f ue in + f_op f tvars newt + +let f_app_safe ?(full=true) (env: env) (f: EcPath.path) (args: form list) = + let ue = UE.create None in + let p_f, o_f = EcEnv.Op.lookup (EcPath.toqsymbol f) env in + let tvars,(newt,f_kind) = open_oper_ue o_f ue in + let rty = UE.fresh ue in + let fty = toarrow (List.map (fun f -> f.f_ty) args) rty in + let () = begin + try + (EcUnify.unify env ue fty newt) + with + | UnificationFailure (`TcCtt (ty, sp)) -> raise (UnificationFailure (`TcCtt (ty, sp))) + | UnificationFailure (`TyUni (ty1, ty2)) -> + let pp_type = (EcPrinting.pp_type (EcPrinting.PPEnv.ofenv env)) in + Format.eprintf "Failed to unify types (%a, %a) in call to %s@." pp_type ty1 pp_type ty2 + (let h,t = EcPath.toqsymbol f in List.fold_right (fun a b -> a ^ "." ^ b) h t); + raise (UnificationFailure (`TyUni (ty1, ty2))) + end + in + let uidmap = UE.assubst ue in + let subst = EcCoreSubst.Tuni.subst uidmap in + let rty = EcCoreSubst.ty_subst subst rty in + let newt = EcCoreSubst.ty_subst subst newt in + let tvars = List.map (EcCoreSubst.ty_subst subst) tvars in + let op = f_op p_f tvars newt in + if full then + match rty.ty_node with + | Tfun _ -> Format.eprintf "op: %a@.args: " (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) op; + List.iter (fun a -> Format.eprintf "%a, " (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) a) args; Format.eprintf "@."; + raise InsufficientArguments + | _ -> f_app op args rty + else + f_app op args rty + +let rec fapply_safe ?(redmode = EcReduction.full_red) (hyps: LDecl.hyps) (f: form) (fs: form list) : form = +(* + Format.eprintf "Applying forms:@.%a@.To form: %a@." + (fun fmt fs -> List.iter (fun f -> Format.fprintf fmt "%a@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv (LDecl.toenv hyps))) f) fs) fs + (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv (LDecl.toenv hyps))) f; +*) + match f.f_node with + | Fop (pth, _) -> + f_app_safe ~full:false (LDecl.toenv hyps) pth fs |> EcCallbyValue.norm_cbv redmode hyps + | Fapp (fop, args) -> + (* let new_args = args @ fs in *) + (* let pp_form = EcPrinting.pp_form (EcPrinting.PPEnv.ofenv (LDecl.toenv hyps)) in *) + (* let pp_forms fmt = List.iter (Format.fprintf fmt "%a, " pp_form) in *) + (* Format.eprintf "new_args: %a@." pp_forms new_args; *) + fapply_safe ~redmode hyps fop (args @ fs) + | Fquant (Llambda, binds, f) -> + assert (List.compare_lengths binds fs >= 0); + let subst_bnds, rem_bnds = List.takedrop (List.length fs) binds in + let subst = + List.fold_left2 + (fun subst b f -> EcSubst.add_flocal subst (fst b) f) EcSubst.empty subst_bnds fs + in + let f = f_quant Llambda rem_bnds (EcSubst.subst_form subst f) in + EcCallbyValue.norm_cbv redmode hyps f + | Fquant (qtf, _, _) -> assert false + | Fif (f, ft, ff) -> assert false + | Fmatch (f, fs, t) -> assert false + | Flet (lpat, f, fb) -> assert false + | Fint (i) -> assert false + | Flocal (id) -> assert false + | Fpvar (pv, m) -> assert false + | Fglob (id, m) -> assert false + | Ftuple (fs) -> assert false + | Fproj (f, i) -> assert false + | _ -> assert false diff --git a/src/ecTyping.ml b/src/ecTyping.ml index 75f594a105..6ed6a52eb6 100644 --- a/src/ecTyping.ml +++ b/src/ecTyping.ml @@ -3637,6 +3637,13 @@ and trans_cp_base ?(memory : memory option) (env : EcEnv.env) (p : pcp_base) : c and trans_codepos1 ?(memory : memory option) (env : EcEnv.env) (p : pcodepos1) : codepos1 = snd_map (trans_cp_base ?memory env) p +(* -------------------------------------------------------------------- *) +(* FIXME: PR: Should this be kept? *) +and trans_codeoffset1 ?(memory : memory option) (env : EcEnv.env) (o : pcodeoffset1) : codeoffset1 = + match o with + | `ByOffset i -> `ByOffset i + | `ByPosition p -> `ByPosition (trans_codepos1 ?memory env p) + (* -------------------------------------------------------------------- *) and trans_codepos_brsel (bs : pbranch_select) : codepos_brsel = match bs with @@ -3665,11 +3672,6 @@ and trans_codepos_range ?(memory : memory option) (env : EcEnv.env) ((cps, cpe) and trans_dcodepos1 ?(memory : memory option) (env : EcEnv.env) (p : pcodepos1 doption) : codepos1 doption = DOption.map (trans_codepos1 ?memory env) p -and trans_codeoffset1 ?(memory: memory option) (env : EcEnv.env) (o : pcodeoffset1) : codeoffset1 = - match o with - | `ByOffset i -> `ByOffset i - | `ByPosition p -> `ByPosition (trans_codepos1 ?memory env p) - (* -------------------------------------------------------------------- *) let get_instances (tvi, bty) env = let inst = List.pmap diff --git a/src/ecTyping.mli b/src/ecTyping.mli index bf2da3aa21..0ac79f41be 100644 --- a/src/ecTyping.mli +++ b/src/ecTyping.mli @@ -231,6 +231,7 @@ val trans_codepos1 : ?memory:EcMemory.memory -> env -> pcodepos1 -> codepos1 val trans_codepos : ?memory:EcMemory.memory -> env -> pcodepos -> codepos val trans_dcodepos1 : ?memory:EcMemory.memory -> env -> pcodepos1 doption -> codepos1 doption val trans_codeoffset1 : ?memory:EcMemory.memory -> env -> pcodeoffset1 -> codeoffset1 +(* FIXME: trans_codeoffset to remove? *) (* -------------------------------------------------------------------- *) type ptnmap = ty EcIdent.Mid.t ref diff --git a/src/ecUtils.ml b/src/ecUtils.ml index e852ce0c09..801fcfe574 100644 --- a/src/ecUtils.ml +++ b/src/ecUtils.ml @@ -236,7 +236,8 @@ let oif (test : 'a -> bool) (x : 'a option) = let oget ?exn (x : 'a option) = match x, exn with - | None , None -> assert false + | None , None -> (* FIXME PR: Remove before merge *) + Printexc.get_callstack 100 |> Printexc.print_raw_backtrace stderr; assert false | None , Some exn -> raise exn | Some x, _ -> x @@ -600,6 +601,21 @@ module List = struct let has_dup ?(cmp = Stdlib.compare) (xs : 'a list) = Option.is_some (find_dup ~cmp xs) + let collapse ?(eq : 'a -> 'a -> bool = (=)) (xs : 'a list) = + match xs with + | [] -> None + | x :: xs -> if List.for_all (eq x) xs then Some x else None + + (* List of size n*w into list of n lists of size w *) + let chunkify (w : int) = + let rec doit (acc : 'a list list) (xs : 'a list) = + if is_empty xs then + rev acc + else + let hd, tl = takedrop w xs in + doit (hd :: acc) tl + in fun (xs : 'a list) -> doit [] xs + (* Separate list into a prefix for which p is true and the rest *) let takedrop_while (p: 'a -> bool) (xs : 'a list) = let rec doit (acc: 'a list) (xs : 'a list) = @@ -608,7 +624,6 @@ module List = struct | x::xs -> if p x then doit (x::acc) xs else (List.rev acc, x::xs) in doit [] xs - type 'a interruptible = [`Interrupt | `Continue of 'a] let fold_left_map_while (f : 'a -> 'b -> ('a * 'c) interruptible) = diff --git a/src/ecUtils.mli b/src/ecUtils.mli index 7d0a4c3c80..0098f0d4d3 100644 --- a/src/ecUtils.mli +++ b/src/ecUtils.mli @@ -300,6 +300,8 @@ module List : sig val reduce1 : ('a list -> 'a) -> 'a list -> 'a val find_dup : ?cmp:('a -> 'a -> int) -> 'a list -> 'a option val has_dup : ?cmp:('a -> 'a -> int) -> 'a list -> bool + val collapse : ?eq:('a -> 'a -> bool) -> 'a list -> 'a option + val chunkify : int -> 'a list -> 'a list list val takedrop_while : ('a -> bool) -> 'a list -> 'a list * 'a list diff --git a/src/phl/ecPhlBDep.ml b/src/phl/ecPhlBDep.ml new file mode 100644 index 0000000000..145c82542a --- /dev/null +++ b/src/phl/ecPhlBDep.ml @@ -0,0 +1,1662 @@ +(* -------------------------------------------------------------------- *) +open EcUtils +open EcIdent +open EcSymbols +open EcLocation +open EcParsetree +open EcAst +open EcEnv +open EcTypes +open EcCoreGoal +open EcFol +open EcLowCircuits +open EcCircuits +open LDecl + +(* -------------------------------------------------------------------- *) +module Map = Batteries.Map +module Hashtbl = Batteries.Hashtbl +module Set = Batteries.Set +module Option = Batteries.Option + +(* -------------------------------------------------------------------- *) +module C = struct + include Lospecs.Aig + include Lospecs.Circuit + include Lospecs.Circuit_spec +end + +module HL = struct + include Lospecs.Hlaig +end + +exception BDepError of string Lazy.t +exception BDepUninitializedInputs + +(* TODO: Refactor error printing and checking? Lots of duplicated code *) + +let int_of_form (hyps: hyps) (f: form) : BI.zint = + match f.f_node with + | Fint i -> i + | _ -> begin try destr_int @@ EcCallbyValue.norm_cbv EcReduction.full_red hyps f with + | DestrError _ -> let err = lazy (Format.asprintf "Failed to reduce form to int: %a@." + (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv (toenv hyps))) f) in + raise (BDepError err) + end + +let time (env: env) (t: float) (msg: string) : float = + let new_t = Unix.gettimeofday () in + EcEnv.notify ~immediate:true env `Info "[W] %s, took %f s@." msg (new_t -. t); + new_t + +let circ_of_qsymbol (hyps: hyps) (qs: qsymbol) : hyps * circuit = + try + let env = toenv hyps in + + let fpth, _fo = EcEnv.Op.lookup qs env in + let f = EcTypesafeFol.fop_from_path env fpth in + let f = EcCallbyValue.norm_cbv (EcCircuits.circ_red hyps) hyps f in + + let hyps, fc = circuit_of_form hyps f in + let fc = circuit_flatten fc in + let fc = circuit_aggregate_inps fc in + hyps, fc + with CircError err -> + raise (BDepError err) + +(* + f => arr_t.init (fun i => f.(i + offset)) + Assumes f has an array type binding + Assumes f has enough positions so that + arr_t.size + offset < size f (as array) +*) +let array_init_from_form (env: env) (f: form) ((arr_t, offset): qsymbol * BI.zint) : form = + let ppe = EcPrinting.PPEnv.ofenv env in + let tpath = match EcEnv.Ty.lookup_opt arr_t env with + | None -> raise (BDepError (lazy "Failed to lookup type for input slice")) + | Some (path, decl) when List.length decl.tyd_params = 1 -> + path + | Some ((_path, decl) as tdecl) -> + raise (BDepError (lazy (Format.asprintf "Type given to input slice (%a) does not look like an array type" EcPrinting.(pp_typedecl ppe) tdecl))) + in + let get = match EcEnv.Circuit.lookup_array env f.f_ty with + | Some { get } -> get + | None -> raise (BDepError (lazy (Format.asprintf "Failed to lookup array binding for type %a" EcPrinting.(pp_type ppe) f.f_ty))) + in + let init = EcEnv.Op.lookup_path (fst (tpath |> EcPath.toqsymbol), "init") env in + let idx = create "i" in + let f = f_lambda [(idx, GTty tint)] + (EcTypesafeFol.f_app_safe env get [f; f_int_add (f_local idx tint) (f_int offset)]) + in EcTypesafeFol.f_app_safe env init [f] + +(* -------------------------------------------------------------------- *) +let mapreduce + ?(debug: bool = false) + (hyps : hyps) + ((mem, mt): memenv) + (proc: stmt) + ((invs, n): (variable * (int * int) option) list * int) + ((outvs, m) : (variable * (int * int) option) list * int) + (f: psymbol) + (pcond: psymbol) + (perm: (int -> int) option) + : unit = + + let tm = Unix.gettimeofday () in + let env = toenv hyps in + + (* ------------------------------------------------------------------ *) + let hyps, fc = try + circ_of_qsymbol hyps ([], f.pl_desc) + with BDepError err -> + raise (BDepError (lazy ("Lane function circuit generation failed with error:\n" ^ (Lazy.force err)))) + in + if debug then EcEnv.notify ~immediate:true env `Warning "Writing lane function to file %s...@." @@ circuit_to_file ~name:"lane_function" fc; + + let tm = time env tm "Lane function circuit generation done" in + + (* ------------------------------------------------------------------ *) + let hyps, pcondc = try + circ_of_qsymbol hyps ([], pcond.pl_desc) + with BDepError err -> + raise (BDepError (lazy ("Precondition circuit generation failed with error:\n" ^ (Lazy.force err)))) + in + if debug then EcEnv.notify ~immediate:true env `Warning "Writing precondition function to file %s...@." @@ circuit_to_file ~name:"pcond" pcondc; + + let tm = time env tm "Precondition circuit generation done" in + + (* ------------------------------------------------------------------ *) + let pinvs = List.fst invs in + let hyps, st = try + EcCircuits.state_of_prog hyps mem proc.s_node pinvs + with CircError err -> + raise (BDepError err) + in + + let tm = time env tm "Program circuit generation done" in + + begin + let circs = List.map (function + | {v_name}, None -> Option.get (state_get_opt st mem v_name) + | {v_name}, Some (sz, offset) -> + circuit_slice (Option.get (state_get_opt st mem v_name)) sz offset + ) + outvs + in + Format.eprintf "Circs[0] = %a@." pp_circuit (List.hd circs); + Format.eprintf "Program variable names registered:@."; + List.iter (fun ((m, v), _) -> Format.eprintf "%s{%s}@." v (name m)) (state_get_all_pv st); + List.iteri (fun i c -> match circuit_has_uninitialized c with + | Some j -> EcEnv.notify ~immediate:true env `Critical "Bit %d of input %d has a dependency on an unititialized input@." j i; raise BDepUninitializedInputs + | None -> ()) circs; + + (* This is required for now as we do not allow mapreduce with multiple arguments *) + (* assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs = 1); *) + + (* ------------------------------------------------------------------ *) + let circs = + try + let slcs = List.snd invs in +(* if List.for_all Option.is_none slcs then circs else *) + List.map (fun c -> + (circuit_align_inputs c slcs) + ) circs + with CircError _ -> + raise (BDepError (lazy "Failed to align inputs to slice")) + in + + + (* ------------------------------------------------------------------ *) + let c = try + (circuit_aggregate circs) + with CircError _ -> + raise (BDepError (lazy "Failed to concatenate outputs")) + in + + let c = try + (circuit_aggregate_inps c) + with CircError _ -> + raise (BDepError (lazy "Failed to concatenate outputs")) + in + + if debug then EcEnv.notify ~immediate:true env `Info "Writing program circuit before mapreduce to file %s...@." @@ circuit_to_file ~name:"prog_no_mapreduce" c; + + (* ------------------------------------------------------------------ *) + let cs = try + circuit_mapreduce ?perm c n m + with CircError err -> + raise (BDepError err) + in + + let tm = time env tm "circuit dependecy analysis + splitting done" in + + if debug then EcEnv.notify ~immediate:true env `Info "Writing lane 0 circit to file %s...@." @@ circuit_to_file ~name:"lane_0" (List.hd cs); + + (* ------------------------------------------------------------------ *) + List.iteri (fun i c -> + if debug then EcEnv.notify ~immediate:true env `Info "Writing lane %d circit to file %s...@." (i+1) @@ circuit_to_file ~name:("lane_" ^ (string_of_int (i+1))) c; + if circ_equiv ~pcond:pcondc (List.hd cs) c + then () + else let err = lazy (Format.sprintf "Equivalence check failed between lanes 0 and %d" (i+1)) + in raise (BDepError err)) + (List.tl cs); + + let tm = time env tm "Program lanes equivs done" in + + (* ------------------------------------------------------------------ *) + if circ_equiv ~pcond:pcondc (List.hd cs) fc then () + else raise (BDepError (lazy "Equivalence failed between lane 0 and lane function")); + + let _tm = time env tm "Program to lane func equiv done" in + + EcEnv.notify ~immediate:true env `Info "Success@." + end + + +(* -------------------------------------------------------------------- *) +let prog_equiv_prod + (hyps : hyps) + ((meml, mtl), (memr, mtr): memenv * memenv) + (proc_l, proc_r: stmt * stmt) + ((invs_l, invs_r, n): (variable list * variable list * int)) + ((outvs_l, outvs_r, m) : (variable list * variable list * int)) + (pcond : form option) + (preprocess : bool ): unit = + + let env = toenv hyps in + + (* ------------------------------------------------------------------ *) + let hyps, pcond = match pcond with + | Some pcond -> begin try + let hyps, c = circuit_of_form hyps pcond in + hyps, Some c + with CircError err -> + raise (BDepError err) + end + | None -> hyps, None + in + let tm = Unix.gettimeofday () in + + (* ------------------------------------------------------------------ *) + let (hyps, st_l) : hyps * state = try + EcCircuits.state_of_prog hyps meml proc_l.s_node invs_l + with CircError err -> + raise (BDepError err) + in + let tm = time env tm "Left program generation done" in + + (* ------------------------------------------------------------------ *) + let (hyps, st_r) : hyps * state = try + EcCircuits.state_of_prog hyps memr proc_r.s_node invs_l + with CircError err -> + raise (BDepError err) + in + let tm = time env tm "Right program generation done" in + + begin + (* ------------------------------------------------------------------ *) + let circs_l = List.map (fun v -> state_get st_l meml v) + (List.map (fun v -> v.v_name) outvs_l) in + let circs_r = List.map (fun v -> state_get st_r memr v) + (List.map (fun v -> v.v_name) outvs_r) in + + List.iteri (fun i c -> match circuit_has_uninitialized c with + | Some j -> EcEnv.notify ~immediate:true env `Critical "Bit %d of input %d of the left program has a dependency on an unititialized input@." j i; raise BDepUninitializedInputs + | None -> ()) circs_l; + + List.iteri (fun i c -> match circuit_has_uninitialized c with + | Some j -> EcEnv.notify ~immediate:true env `Critical "Bit %d of input %d of the right program has a dependency on an unititialized input@." j i; raise BDepUninitializedInputs + | None -> ()) circs_r; + + (*assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs_l = 1); *) + (*assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs_r = 1);*) + (* ------------------------------------------------------------------ *) + let c_l = try + (circuit_aggregate circs_l) + with CircError _err -> + raise (BDepError (lazy "Failed to aggregate left program outputs")) + in + let c_r = try + (circuit_aggregate circs_r) + with CircError _err -> + raise (BDepError (lazy "Failed to aggregate right program outputs")) + in + (* ------------------------------------------------------------------ *) + let c_r = try + (circuit_aggregate_inps c_r) + with CircError _err -> + raise (BDepError (lazy "Failed to aggregate right program inputs")) + in + let c_l = try + (circuit_aggregate_inps c_l) + with CircError _err -> + raise (BDepError (lazy "Failed to aggregate right program inputs")) + in + + let tm = time env tm "Preprocessing for mapreduce done" in + + (* ------------------------------------------------------------------ *) + let lanes_l = try + circuit_mapreduce c_l n m + with CircError err -> + raise (BDepError (lazy ("Left program split step failed with error:\n" ^ (Lazy.force err)))) + in + let tm = time env tm "Left program deps + split done" in + + let lanes_r = try + circuit_mapreduce c_r n m + with CircError err -> + raise (BDepError (lazy ("Right program split step failed with error:\n" ^ (Lazy.force err)))) + in + let tm = time env tm "Right program deps + split done" in + + if preprocess then + begin + (* ------------------------------------------------------------------ *) + (List.iteri (fun i c -> + if circ_equiv ?pcond (List.hd lanes_l) c + then () + else let err = lazy (Format.sprintf "Left program lane equiv failed between lanes 0 and %d@." (i+i)) + in raise (BDepError err)) + (List.tl lanes_l)); + let tm = time env tm "Left program lanes equiv done" in + + (List.iteri (fun i c -> + if circ_equiv ?pcond (List.hd lanes_r) c + then () + else let err = lazy (Format.sprintf "Right program lane equiv failed between lanes 0 and %d@." (i+i)) + in raise (BDepError err)) + (List.tl lanes_r)); + let tm = time env tm "Right program lanes equiv done" in + + (* ------------------------------------------------------------------ *) + if (circ_equiv ?pcond (List.hd lanes_l) (List.hd lanes_r)) + then + time env tm "First lanes equiv done" |> ignore + else + raise (BDepError (lazy "Lane equiv failed between first lane of left and right programs")) + end + else + begin + List.iter2i (fun i c_l c_r -> + if circ_equiv ?pcond c_l c_r + then () + else let err = lazy (Format.sprintf "Lane equivalence failed between programs for lane %d@." i) in + raise (BDepError err)) lanes_l lanes_r; + time env tm "Program lane equivs done" |> ignore + end + end + +(* + Input: pstate -> Map from program variables to circuits, possibly empty + hyps + form -> form to be processed + Output: + Form with equalities between bitstring replaced by true + if both sides are equivalent as circuits + or false otherwise +*) + +let circ_form_eval_plus_equiv + ~(me: memenv) + (hyps: hyps) + (proc: stmt) + (f: form) + (v : variable) + : bool = + + (* ------------------------------------------------------------------ *) + assert(f.f_ty = tbool); + + (* ------------------------------------------------------------------ *) + let env = toenv hyps in + let env = EcEnv.Memory.push_active_ss me env in + let mem, mt = me in + let redmode = circ_red hyps in + let (@@!) = EcTypesafeFol.f_app_safe env in + + (* ------------------------------------------------------------------ *) + let size, of_int = match EcEnv.Circuit.lookup_bitstring env v.v_type with + | Some {size=(_, Some size); ofint} -> size, ofint + | Some {size=(_, None); ofint} -> + raise (BDepError (lazy "No concrete binding bitstring size")) + | None -> + let err = lazy (Format.asprintf "Binding not found for variable %s of type %a@." + v.v_name (EcPrinting.pp_type (EcPrinting.PPEnv.ofenv env)) v.v_type) + in raise (BDepError err) + in + + let rec test_values (size: int) (cur: BI.zint) : bool = + if debug then EcEnv.notify ~immediate:true env `Info "Testing for var = %s@." (BI.to_string cur); + + (* ------------------------------------------------------------------ *) + if Z.numbits (BI.to_zt cur) > size then true else (* If we reach the maximum value jump out *) + + (* ------------------------------------------------------------------ *) + let cur_bs = of_int @@! [f_int cur] in (* Current testing value as a bitstring *) + + let insts = List.map (fun i -> + match i.i_node with + | Sasgn (lv, e) -> + let f = (ss_inv_of_expr mem e) in + let fi = EcPV.PVM.subst1 env (PVloc v.v_name) mem cur_bs f.inv in + let fi = EcCallbyValue.norm_cbv redmode hyps fi in + let e = try expr_of_ss_inv {f with inv=fi} + with CannotTranslate -> + Format.eprintf "Failed on form : %a@." + EcPrinting.(pp_form PPEnv.(ofenv env)) fi; + raise CannotTranslate + in + EcCoreModules.i_asgn (lv, e) + | _ -> i + ) proc.s_node + in + + let st = circuit_state_of_memenv ~st:empty_state env me in + + let hyps, st = try + EcCircuits.state_of_prog ~st hyps mem insts [] + with CircError err -> + raise (BDepError err) + in + + (* ------------------------------------------------------------------ *) + (* FIXME: Suspicious *) + let f = EcPV.PVM.subst1 env (PVloc v.v_name) mem cur_bs f in + let hyps, pres = match state_get_opt st mem v.v_name with + | Some circ -> + begin try + Option.may (fun i -> + EcEnv.notify ~immediate:true env `Critical + "Bit %d of precondition circuit has dependency on uninitialized inputs@." i; + raise (BDepError (lazy "Uninitialized input for circuit")) ) + (circuit_has_uninitialized circ); + let hyps, c = (circuit_of_form hyps cur_bs) in + hyps, Some (circuit_eqs circ c) + with CircError err -> + raise (CircError err) +(* raise (BDepError (lazy ("Failed to generate circuit for current value precondition with error:\n" ^ (Lazy.force err)))) *) + end + | None -> hyps, None + in + + (* FIXME: how many times to reduce here ? *) + (* ------------------------------------------------------------------ *) + let f = EcCallbyValue.norm_cbv redmode hyps f in + let f = EcCircuits.circ_simplify_form_bitstring_equality ~st ?pres hyps f in + let f = EcCallbyValue.norm_cbv (EcReduction.full_red) hyps f in + + if f <> f_true then + (EcEnv.notify ~immediate:true env `Critical + "Got %a after reduction@." + (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f; + false) + else + test_values size (BI.(add cur one)) + in + test_values size (BI.zero) + +(* -------------------------------------------------------------------- *) +let mapreduce_eval + (hyps : hyps) + ((mem, mt): memenv) + (proc: stmt) + ((invs, n): variable list * int) + ((outvs, m) : variable list * int) + (f: psymbol) + (range: form list) + (sign: bool) + : unit = + + + let tm = Unix.gettimeofday () in + + (* ------------------------------------------------------------------ *) + let env = toenv hyps in + let fc = EcEnv.Op.lookup ([], f.pl_desc) env |> fst in + let (@@!) = EcTypesafeFol.f_app_safe env in + + let tm = time env tm "Lane function circuit generation done" in + + (* ------------------------------------------------------------------ *) + let hyps, st = try + EcCircuits.state_of_prog hyps mem proc.s_node invs + with CircError err -> + raise (BDepError err) + in + + let tm = time env tm "Program circuit generation done" in + + begin + let circs = List.map (fun v -> state_get st mem v) (List.map (fun v -> v.v_name) outvs) in + + List.iteri (fun i c -> match circuit_has_uninitialized c with + | Some j -> EcEnv.notify ~immediate:true env `Critical "Bit %d of input %d has a dependency on an unititialized input@." j i; raise BDepUninitializedInputs + | None -> ()) circs; + + (* ------------------------------------------------------------------ *) + let c = try + (circuit_aggregate circs) + with CircError _err -> + raise (BDepError (lazy "Failed to concatenate program outputs")) + in + let c = try + (circuit_aggregate_inps c) + with CircError _err -> + raise (BDepError (lazy "Failed to concatenate program outputs")) + in + + (* ------------------------------------------------------------------ *) + let cs = try + circuit_mapreduce c n m + with CircError err -> + raise (BDepError (lazy ("Split step failed with error:\n" ^ (Lazy.force err)))) + in + + let tm = time env tm "circuit dependecy analysis + splitting done" in + + (* ------------------------------------------------------------------ *) + List.iteri (fun i c -> + if circ_equiv (List.hd cs) c + then () + else let err = lazy (Format.sprintf "Equivalence failed between program lanes 0 and %d@." (i + 1)) in + raise (BDepError err) + ) + (List.tl cs); + + let tm = time env tm "Program lanes equivs done" in + + (* ------------------------------------------------------------------ *) + List.iter (fun v -> + let fv = v in + let v = destr_int v in + let lane_val = fc @@! [fv] in + let lane_val = int_of_form hyps lane_val in + let circ_val = compute ~sign (List.hd cs) [v] in + if circ_val = lane_val then () + else let err = + lazy (Format.sprintf "Error on input %s@.Circ val:%s | Lane val: %s@." + (BI.to_string v) + (BI.to_string circ_val) + (BI.to_string lane_val)) + in raise (BDepError err) + ) range; + + time env tm "Program to lane func equiv done" |> ignore + end + +let w2bits (env: env) (ty: ty) (arg: form) : form = + let tb = match EcEnv.Circuit.lookup_bitstring env ty with + | Some {to_=tb; _} -> tb + | _ -> let err = lazy (Format.asprintf "No w2bits for type %a@." (EcPrinting.pp_type (EcPrinting.PPEnv.ofenv env)) ty) in + raise (BDepError err) + in EcTypesafeFol.f_app_safe env tb [arg] + +let bits2w (env: env) (ty: ty) (arg: form) : form = + let fb = match EcEnv.Circuit.lookup_bitstring env ty with + | Some {from_=fb; _} -> fb + | _ -> let err = lazy (Format.asprintf "No bits2w for type %a@." (EcPrinting.pp_type (EcPrinting.PPEnv.ofenv env)) ty) in + raise (BDepError err) + in EcTypesafeFol.f_app_safe env fb [arg] + +let w2bits_op (env: env) (ty: ty) : form = + let tb = match EcEnv.Circuit.lookup_bitstring env ty with + | Some {to_=tb; _} -> tb + | _ -> let err = lazy (Format.asprintf "No bits2w for type %a@." (EcPrinting.pp_type (EcPrinting.PPEnv.ofenv env)) ty) in + raise (BDepError err) + in let tbp, tbo = EcEnv.Op.lookup (EcPath.toqsymbol tb) env in + f_op tb [] tbo.op_ty + +let bits2w_op (env: env) (ty: ty) : form = + let fb = match EcEnv.Circuit.lookup_bitstring env ty with + | Some {from_=fb; _} -> fb + | _ -> let err = lazy (Format.asprintf "No bits2w for type %a@." (EcPrinting.pp_type (EcPrinting.PPEnv.ofenv env)) ty) in + raise (BDepError err) + in let fbp, fbo = EcEnv.Op.lookup (EcPath.toqsymbol fb) env in + f_op fb [] fbo.op_ty + +let flatten_to_bits (env: env) (f: form) = + let (@@!) = EcTypesafeFol.f_app_safe env in + match EcEnv.Circuit.lookup_array_and_bitstring env f.f_ty with + | Some ({ tolist }, {type_; to_=tb}) -> + let base = tconstr type_ [] in + let w2bits = w2bits_op env base in + EcCoreLib.CI_List.p_flatten @@! + [EcCoreLib.CI_List.p_map @@! [w2bits; (tolist @@! [f])]] + | None -> + w2bits env f.f_ty f + +let reconstruct_from_bits (env: env) (f: form) (t: ty) = + (* Check input is a bool list *) + assert (match f.f_ty.ty_node with + | Tconstr(p, [b]) when p = EcCoreLib.CI_List.p_list -> b = tbool + | _ -> false); + let (@@!) = EcTypesafeFol.f_app_safe env in + match EcEnv.Circuit.lookup_array_and_bitstring env t with + | Some ({ oflist }, {type_; size = (_, Some size); ofint}) -> + let base = tconstr type_ [] in + oflist @@! [ ofint @@! [f_int (BI.of_int 0)]; + EcCoreLib.CI_List.p_map @@! [ bits2w_op env base; + EcCoreLib.CI_List.p_chunk @@! [(f_int (BI.of_int size)); f]]] + | Some ({ oflist }, {type_; size = (_, None); ofint}) -> + raise (BDepError (lazy "No concrete binding for type in reconstruct_from_bits")) (* FIXME: error messages *) + | _ -> + bits2w env t f + +(* FIXME: review and cleanup this section *) + +let reconstruct_from_bits_op (env: env) (t: ty) = + (* Check input is a bool list *) + match EcEnv.Circuit.lookup_array_and_bitstring env t with + | Some ({ oflist }, {type_; size = (_, Some size); ofint}) -> + let temp = create "temp" in + let bool_list = tconstr EcCoreLib.CI_List.p_list [tbool] in + f_quant Llambda [(temp, GTty bool_list)] @@ + reconstruct_from_bits env (f_local temp bool_list) t + | Some ({ oflist }, {type_; size = (_, None); ofint}) -> + raise (BDepError (lazy "No concrete binding for type in reconstruct_from_bits_op")) (* FIXME: error messages *) + | _ -> + bits2w_op env t + +let t_bdep + ?(debug: bool = false) + (n: int) + (m: int) + (inpvs: ((variable * (int * int) option) list)) + (outvs: ((variable * (int * int) option) list)) + (pcond: psymbol) + (op: psymbol) + (perm: (int -> int) option) + (tc : tcenv1) = + let () = match (FApi.tc1_goal tc).f_node with + | FhoareS sF -> if true then + begin try mapreduce ~debug (FApi.tc1_hyps tc) sF.hs_m sF.hs_s (inpvs, n) (outvs, m) op pcond perm with + | BDepError err -> tc_error (FApi.tc1_penv tc) "%s" (Lazy.force err) + end + else () + (* FIXME PR: Should these be guarded before call or do we just fail somehow if we hit it? *) + | FhoareF sH -> assert false + | FbdHoareF _ -> assert false + | FbdHoareS _ -> assert false + | FeHoareF _ -> assert false + | FeHoareS _ -> assert false + | _ -> assert false + in + FApi.close (!@ tc) VBdep + +let get_var (env: env) (v : bdepvar) (m : memenv) : (variable * ((qsymbol * BI.zint) option)) list = + let get1 (v : symbol) = + match EcMemory.lookup_me v m with + | Some (v, None, _) -> v + | _ -> let err = lazy (Format.asprintf "Couldn't locate variable %s@." v) in + raise (BDepError err) + in + match v with + | `Var v -> + [get1 (unloc v), None] + | `VarRange (v, n) -> + List.init n (fun i -> get1 (Format.sprintf "%s_%d" (unloc v) n), None) + | `Slice (v, (arr_t, off)) -> + [get1 (unloc v), Some(unloc arr_t, off)] + + +let get_vars (env: env) (vs : bdepvar list) (m : memenv) : (variable * ((qsymbol * BI.zint) option)) list = + List.flatten (List.map (fun v -> get_var env v m) vs) + +let blocks_from_vars (env: env) (vs: form list) (ty: ty) : form = + let (@@!) pth args = + try + EcTypesafeFol.f_app_safe env pth args + with EcUnify.UnificationFailure _ -> + let err = lazy (Format.sprintf "Type mismatch in pre-post generation, check your lane and precondition types@.") in + raise (BDepError err) + in + let m = try + width_of_type env ty + with CircError err -> + raise (BDepError (lazy ("Error while constructing precondition: \n" ^ (Lazy.force err)))) + in + (* let poutvs = List.map (fun v -> EcFol.f_pvar (pv_loc v.v_name) v.v_type mem) vs in *) + let poutvs = List.map (flatten_to_bits env) vs in + let poutvs = List.fold_right (fun v1 v2 -> EcCoreLib.CI_List.p_cons @@! [v1; v2]) poutvs (fop_empty (List.hd poutvs).f_ty) in + let poutvs = EcCoreLib.CI_List.p_flatten @@! [poutvs] in + let poutvs = EcCoreLib.CI_List.p_chunk @@! [f_int (BI.of_int m); poutvs] in + let poutvs = EcCoreLib.CI_List.p_map @@! [(reconstruct_from_bits_op env ty); poutvs] in + poutvs + +(* + Input: - form representing a list + - path for an operator of type int -> int representing a permutation + Output: form representing permuted list + *) +let permute_list (env: env) (perm: EcPath.path) (xs: form) : form = + let (@@!) pth args = + try + EcTypesafeFol.f_app_safe env pth args + with EcUnify.UnificationFailure _ -> + let err = lazy (Format.sprintf "Type mismatch in pre-post generation, check your lane and precondition types@.") in + raise (BDepError err) + in + let i = (create "i", GTty tint) in + let bty = tfrom_tlist xs.f_ty in + EcCoreLib.CI_List.p_mkseq @@! [ + f_lambda [i] + (EcCoreLib.CI_List.p_nth @@! [f_op EcCoreLib.CI_Witness.p_witness [bty] bty; xs; + perm @@! [f_local (fst i) tint]]); + EcCoreLib.CI_List.p_size @@! [xs] + ] + +(* FIXME: Add size checks for input and output *) +let process_bdep (bdinfo: bdep_info) (tc: tcenv1) = + let { n; invs; inpvs; m; outvs; lane; pcond; perm; debug } = bdinfo in + + let env = FApi.tc1_env tc in + let hyps = FApi.tc1_hyps tc in + let pe = FApi.tc1_penv tc in + let ppe = EcPrinting.PPEnv.ofenv env in + + let (@@!) pth args = + try + EcTypesafeFol.f_app_safe env pth args + with EcUnify.UnificationFailure _ -> + let err = Format.sprintf "Type mismatch in pre-post generation, check your lane and precondition types@." in + (* raise (BDepError err) *) + tc_error pe "%s" err + in + + (* FIXME: lookup is done twice here, should be easy to remove *) + let fperm_of_perm_op (perm: psymbol) : int -> int = + let pperm, bperm = EcEnv.Op.lookup ([], perm.pl_desc) env in + match bperm.op_kind with + | OB_oper (Some (OP_Plain + {f_node = Fquant (Llambda, bnd::[], + {f_node = Fapp ({f_node = Fop (pth, tys)}, [dfl; lst; idx])})})) when pth = EcCoreLib.CI_List.p_nth -> + if debug then Format.eprintf "[W] Taking the fast path for the permutation@."; + let elems = EcCoreFol.destr_list lst |> List.map (int_of_form hyps) |> List.map (BI.to_int) |> Array.of_list in + let idx_call = fun i -> (EcTypesafeFol.fapply_safe hyps (f_quant Llambda (bnd::[]) idx) ((f_int (BI.of_int i))::[])) |> int_of_form hyps |> BI.to_int in + let dfl = int_of_form hyps dfl |> BI.to_int in + fun i -> begin + try + elems.(idx_call i) + with Invalid_argument _ -> + dfl + end + | _ -> + if debug then Format.eprintf "[W] Taking the slow path for the permutation (op: %a)@." EcPrinting.(pp_opdecl (PPEnv.ofenv env)) (pperm, bperm); + (fun i -> + let arg = f_int (BI.of_int i) in + let call = EcTypesafeFol.f_app_safe env pperm [arg] in + let res = EcCallbyValue.norm_cbv (EcReduction.full_red) (FApi.tc1_hyps tc) call in + begin try + destr_int res |> BI.to_int + with DestrError _ -> + tc_error pe "Application of function %s failed" (EcPath.tostring pperm) + end + ) + in + + let fperm, pperm = match perm with + | None -> None, None + | Some perm -> + let pperm = EcEnv.Op.lookup ([], perm.pl_desc) env |> fst in + let fperm = fperm_of_perm_op perm in +(* + let fperm (i: int) = + let arg = f_int (BI.of_int i) in + let call = EcTypesafeFol.f_app_safe env pperm [arg] in + let res = EcCallbyValue.norm_cbv (EcReduction.full_red) (FApi.tc1_hyps tc) call in + begin try + destr_int res |> BI.to_int + with DestrError _ -> + tc_error pe "Application of function %s failed" (EcPath.tostring pperm) + end + in +*) + Some fperm, Some pperm + in + + (* DEBUG SECTION *) + (* let pp_type (fmt: Format.formatter) (ty: ty) = *) + (* Format.fprintf fmt "%a" (EcPrinting.pp_type (EcPrinting.PPEnv.ofenv env)) ty in *) + + let plane, olane = EcEnv.Op.lookup ([], lane.pl_desc) env in + let ppcond, opcond = EcEnv.Op.lookup ([], pcond.pl_desc) env in + (* FIXME: Add a check that this does not return a function type + aka lane function only have one argument *) + let inpbty, outbty = tfrom_tfun2 olane.op_ty in + + (* Refactor this *) + + let hr = EcLowPhlGoal.tc1_as_hoareS tc in + + + (* ------------------------------------------------------------------ *) + let outvs = try + get_vars env outvs hr.hs_m + with BDepError err -> + tc_error pe "get_vars (outvs) error: %s" (Lazy.force err) + in + + let out_size = List.sum (List.map + (function + | v, None -> width_of_type env (v.v_type) + | v, Some (t, _) -> + let t = match v.v_type.ty_node with + | Tconstr (_, [bsty]) -> tconstr (EcEnv.Ty.lookup_path t env) [bsty] + | _ -> tc_error pe "Failed to parse output type %a" EcPrinting.(pp_type ppe) v.v_type + in width_of_type env t ) + outvs) + in + + if (out_size mod m <> 0) then tc_error pe "Output size (%d) not divisible by lane size (%d)" out_size m; + let out_block_nr = out_size / m in + let out_block_nr = match fperm with + | None -> out_block_nr + | Some fperm -> List.init out_block_nr (fun i -> if fperm i >= 0 then 1 else 0) |> List.sum + in + + let post_form_of_pv (v: variable * (qsymbol * BI.zint) option) : form = + match v with + | v, None -> (f_pvar (pv_loc v.v_name) v.v_type (hr.hs_m |> fst)).inv + | {v_type} as v, Some (arr_t, offset) -> + let f = f_pvar (pv_loc v.v_name) v.v_type (hr.hs_m |> fst) in + array_init_from_form env f.inv (arr_t, offset) + in + + let poutvs = try + blocks_from_vars env + (List.map post_form_of_pv outvs) outbty + with BDepError err -> + tc_error pe "%s" (Lazy.force err) + in + + (* OPTIONAL PERMUTATION STEP *) + let poutvs = try + Option.apply (Option.map (permute_list env) pperm) poutvs + with BDepError err -> + tc_error pe "%s" (Lazy.force err) + in + + + (* ------------------------------------------------------------------ *) + let inpvs = try + get_vars env inpvs hr.hs_m + with BDepError err -> + tc_error pe "Error in get_vars(inpvs): %s" (Lazy.force err) + in + let in_size = List.sum (List.map + (function + | v, None -> width_of_type env (v.v_type) + | v, Some (t, _) -> + let t = match v.v_type.ty_node with + | Tconstr (_, [bsty]) -> tconstr (EcEnv.Ty.lookup_path t env) [bsty] + | _ -> tc_error pe "Failed to parse input type %a" EcPrinting.(pp_type ppe) v.v_type + in width_of_type env t) + inpvs) in + + EcEnv.notify ~immediate:true env `Info "in_size : %d | block_size: %d@." in_size n; + assert (in_size mod n = 0); + let in_block_nr = in_size / n in + EcEnv.notify ~immediate:true env `Info "in_block_nr: %d | out_block_nr: %d@." in_block_nr out_block_nr; + assert (in_block_nr = out_block_nr); + + let finpvs = List.map post_form_of_pv inpvs in + + let inpvs = List.map + (function + | v, None -> v, None + | v, Some (t, offset) -> + let asz, bsz = match EcEnv.Circuit.lookup_array_and_bitstring env v.v_type with + | Some ({size = (_, Some asz)}, { size = (_, Some bsz) }) -> asz, bsz + | Some (_, { size = (_, None) }) -> tc_error pe "non concrete binding for input bitstring type (%a)" EcPrinting.(pp_type ppe) v.v_type + | Some ({size = (_, None)}, _) -> tc_error pe "non concrete binding for input array type (%a)" EcPrinting.(pp_type ppe) v.v_type + | None -> tc_error pe "Failed to lookup array or bitstring binding for input type (%a)" EcPrinting.(pp_type ppe) v.v_type + in + (* FIXME: Run this once to check that it is equal then remove block below *) + let asz2 = match EcEnv.Circuit.lookup_array_path env (EcPath.fromqsymbol t) with + | Some {size= (_, Some size)} -> size + | Some {size= (_, None)} -> assert false + | _ -> assert false + in + assert (asz = asz2); + v, Some (bsz * asz, (BI.to_int offset) * bsz) + ) + inpvs + in + + let outvs = List.map + (function + | v, None -> v, None + | v, Some (t, offset) -> + let asz, bsz = match EcEnv.Circuit.lookup_array_and_bitstring env v.v_type with + | Some ({size = (_, Some asz)}, { size = (_, Some bsz) }) -> asz, bsz + | Some (_, { size = (_, None) }) -> tc_error pe "non concrete binding for output bitstring type (%a)" EcPrinting.(pp_type ppe) v.v_type + | Some ({size = (_, None)}, _) -> tc_error pe "non concrete binding for output array type (%a)" EcPrinting.(pp_type ppe) v.v_type + | None -> tc_error pe "Failed to lookup array or bitstring binding for output type (%a)" EcPrinting.(pp_type ppe) v.v_type + in + (* FIXME: Run this once to check that it is equal then remove block below *) + let asz2 = match EcEnv.Circuit.lookup_array_path env (EcPath.fromqsymbol t) with + | Some {size= (_, Some size)} -> size + | Some {size= (_, None)} -> assert false + | _ -> assert false + in + assert (asz = asz2); + v, Some (bsz * asz, (BI.to_int offset) * bsz) + ) + outvs + in + + + let invs = + let lookup (x : bdepvar) : ((ident * ty) * (qsymbol * BI.zint) option) list = + let get1 (v : symbol) = + EcEnv.Var.lookup_local v env + in + match x with + | `Var x -> + [get1 (unloc x), None] + | `VarRange (x, n) -> + List.init n (fun i -> get1 (Format.sprintf "%s_%d" (unloc x) i), None) + | `Slice (x, (arr_t, offset)) -> + [get1 (unloc x), Some (unloc arr_t, offset)] + in + List.map lookup invs |> List.flatten in + (* FIXME: Why was this needed? Check that all input types were equal? *) + (* let inty = match List.collapse inv_tys with + | Some ty -> ty + | None -> + let err = Format.sprintf "Failed to coallesce types for input@." + (* in raise (BDepError err) *) + in tc_error pe "%s@." err + in *) + + let post_form_of_lv (v: ((ident * ty) * ((qsymbol * BI.zint) option))) = + match v with + | (id, ty), None -> f_local id ty + | (id, ty), Some (arr_t, offset) -> + let f = f_local id ty in + array_init_from_form env f (arr_t, offset) + in + + let finvs = List.map post_form_of_lv invs in + let pinvs = try + blocks_from_vars env finvs inpbty + with BDepError err -> + tc_error pe "Error while generating input variable expression for precondition:@.%s@." (Lazy.force err) + in + let pinvs_post = EcCoreLib.CI_List.p_map @@! [(f_op plane [] olane.op_ty); pinvs] in + (* ------------------------------------------------------------------ *) + let post = f_eq pinvs_post poutvs in + let pre = EcCoreLib.CI_List.p_all @@! [(f_op ppcond [] opcond.op_ty); pinvs] in + + if (List.compare_lengths inpvs invs <> 0) + then tc_error pe "Logical variables should correspond 1-1 to program variables"; + let pre = f_ands (pre::(List.map2 (fun iv ipv -> f_eq iv ipv) finvs finpvs)) in + + (* let env, hyps, concl = FApi.tc1_eflat tc in *) + let tc = EcPhlConseq.t_hoareS_conseq_nm {inv=pre; m=(fst hr.hs_m)} {inv=post; m=(fst hr.hs_m)} tc in (* FIXME: check memory here*) + FApi.t_last (t_bdep ~debug n m inpvs outvs pcond lane fperm) tc + + + +let t_bdepeq (inpvs_l, inpvs_r: (variable list * variable list)) (n: int) (out_blocks: (int * variable list * variable list) list) (pcond: form option) (preprocess: bool) (tc : tcenv1) = + (* Run bdep and check that is works FIXME *) + let () = match (FApi.tc1_goal tc).f_node with + | FequivS sE -> begin try List.iter (fun (m, outvs_l, outvs_r) -> + prog_equiv_prod (FApi.tc1_hyps tc) (sE.es_ml, sE.es_mr) (sE.es_sl, sE.es_sr) (inpvs_l, inpvs_r, n) (outvs_l, outvs_r, m) pcond preprocess) out_blocks + with BDepError err -> + tc_error (FApi.tc1_penv tc) "Program equivalence failed with error: @.%s@." (Lazy.force err) + end + (* FIXME PR: Do we throw a decent error here or should this be guarded before the call? *) + | FequivF sE -> assert false + | FhoareF sH -> assert false + | FhoareS sF -> assert false + | FbdHoareF _ -> assert false + | FbdHoareS _ -> assert false + | FeHoareF _ -> assert false + | FeHoareS _ -> assert false + | _ -> assert false + in + FApi.close (!@ tc) VBdep + +let process_bdepeq + (bdeinfo: bdepeq_info) + (tc: tcenv1) += + + let env = FApi.tc1_env tc in + let (@@!) pth args = EcTypesafeFol.f_app_safe env pth args in + + + let inpvsl = bdeinfo.inpvs_l in + let inpvsr = bdeinfo.inpvs_r in + let n = bdeinfo.n in + let preprocess = bdeinfo.preprocess in + let pe = FApi.tc1_penv tc in + + (* DEBUG SECTION *) + + let eqS = EcLowPhlGoal.tc1_as_equivS tc in + let mem_l, mem_r = eqS.es_ml, eqS.es_mr in + + (* ------------------------------------------------------------------ *) + let process_block (outvsl: bdepvar list) (outvsr: bdepvar list) = + try + let outvsl = get_vars env outvsl mem_l |> List.fst in + let poutvsl = List.map (fun v -> (EcFol.f_pvar (pv_loc v.v_name) v.v_type (fst mem_l)).inv) outvsl in + + let outvsr = get_vars env outvsr mem_r |> List.fst in + let poutvsr = List.map (fun v -> (EcFol.f_pvar (pv_loc v.v_name) v.v_type (fst mem_r)).inv) outvsr in + List.map2 f_eq poutvsl poutvsr |> f_ands, (outvsl, outvsr) + with BDepError err -> + tc_error pe "Process block failed with error: %s@." (Lazy.force err) + in + + + let inpvsl = try + get_vars env inpvsl mem_l |> List.fst + with BDepError err -> + tc_error pe "%s" (Lazy.force err) + in + let pinpvsl = try + List.map (fun v -> (EcFol.f_pvar (pv_loc v.v_name) v.v_type (fst mem_l)).inv) inpvsl + with BDepError err -> + tc_error pe "%s" (Lazy.force err) + in + + let inpvsr = try + get_vars env inpvsr mem_r |> List.fst + with BDepError err -> + tc_error pe "%s" (Lazy.force err) + in + let pinpvsr = try + List.map (fun v -> (EcFol.f_pvar (pv_loc v.v_name) v.v_type (fst mem_r)).inv) inpvsr + with BDepError err -> + tc_error pe "%s" (Lazy.force err) + in + + let pre = List.map2 f_eq pinpvsl pinpvsr |> f_ands in + let post, outvs = List.map (fun (m, vs_l, vs_r) -> + let post, outvs = process_block vs_l vs_r in + let outvs_l, outvs_r = outvs in + post, (m, outvs_l, outvs_r)) bdeinfo.out_blocks |> List.split in + let post = f_ands post in + + + let prepcond, pcond = match bdeinfo.pcond with + | Some pcond -> + (* FIXME: generate correct precond. Is this fixed ? *) + let ppcond, opcond = EcEnv.Op.lookup ([], pcond.pl_desc) env in + let pcond = match opcond.op_kind with + | OB_oper (Some OP_Plain f) -> f + | _ -> tc_error pe "Unsupported precondition kind for bdepeq" + in + + let opinty = + match opcond.op_ty.ty_node with + | Tfun (a, b) -> a + | _ -> tc_error pe "precond should have function type" + in + + let pinpl_blocks = try + blocks_from_vars env pinpvsl opinty + with BDepError err -> + tc_error pe "%s" (Lazy.force err) + in + + let pre_l = EcCoreLib.CI_List.p_all @@! [(f_op ppcond [] opcond.op_ty); pinpl_blocks] in + + pre_l, Some pcond + | None -> f_true, None + in + + let pre = f_and pre prepcond in + + (* ------------------------------------------------------------------ *) + let tc = EcPhlConseq.t_equivS_conseq_nm {inv=pre; mr=(fst mem_r); ml=(fst mem_l)} {inv=post; mr=(fst mem_r); ml=(fst mem_l)} tc in (* FIXME: check memory *) + FApi.t_last (t_bdepeq (inpvsl, inpvsr) n outvs pcond preprocess) tc + +let t_bdep_form + (f: form) + (v: variable) + (tc : tcenv1) + : tcenv = + match (FApi.tc1_goal tc).f_node with + | FhoareS sF -> begin try + if circ_form_eval_plus_equiv ~me:sF.hs_m (FApi.tc1_hyps tc) sF.hs_s f v then + FApi.t_last (fun tc -> FApi.close (!@ tc) VBdep) (EcPhlConseq.t_hoareS_conseq_nm (hs_pr sF) {(hs_po sF) with inv=(f_and f sF.hs_po)} tc) + else tc_error (FApi.tc1_penv tc) "Supplied formula is not always true@." + with + | BDepError le -> + tc_error (FApi.tc1_penv tc) "BDepError: %s@." (Lazy.force le) + | CircError le -> + assert false + end + | _ -> tc_error (FApi.tc1_penv tc) "Goal should be a Hoare judgement with inlined code@." + +let process_bdep_form + (f: pformula) + (v: bdepvar) + (tc : tcenv1) + : tcenv = + let hr = EcLowPhlGoal.tc1_as_hoareS tc in + let hyps = FApi.tc1_hyps tc in + let env = toenv hyps in + let v = get_var env v hr.hs_m |> List.fst |> as_seq1 in + let ue = EcUnify.UniEnv.create None in + let env = (toenv hyps) in + let env = Memory.push_active_ss hr.hs_m env in + let f = EcTyping.trans_prop env ue f in + assert (EcUnify.UniEnv.closed ue); + let f = EcCoreSubst.Fsubst.f_subst (Tuni.subst (EcUnify.UniEnv.close ue)) f in + t_bdep_form f v tc + +(* FIXME: move? V *) +let form_list_from_iota (hyps: hyps) (f: form) : form list = + match f.f_node with + | Fapp ({f_node = Fop(p, _)}, [n; m]) when p = EcCoreLib.CI_List.p_iota -> + let n = int_of_form hyps n in + let m = int_of_form hyps m in + List.init (BI.to_int m) (fun i -> f_int (BI.(add n (of_int i)))) + | _ -> let err = lazy (Format.asprintf "Failed to get form list from iota expression %a@." + (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv (toenv hyps))) f) in + raise (BDepError err) + +let rec form_list_of_form ?(ppenv: EcPrinting.PPEnv.t option) (f: form) : form list = + match destr_op_app f with + | (pc, _), [h; {f_node = Fop(p, _)}] when + pc = EcCoreLib.CI_List.p_cons && + p = EcCoreLib.CI_List.p_empty -> + [h] + | (pc, _), [h; t] when + pc = EcCoreLib.CI_List.p_cons -> + h::(form_list_of_form t) + | _ -> + if debug then Option.may (fun ppenv -> Format.eprintf "Failed to destructure claimed list: %a@." (EcPrinting.pp_form ppenv) f) ppenv; + raise (BDepError (lazy "Failed to destruct list")) + +(* FIXME: move? A *) + +let t_bdep_eval + (n: int) + (m: int) + (inpvs: variable list) + (outvs: variable list) + (op: psymbol) + (range: form list) + (sign: bool) + (tc : tcenv1) = + (* Run bdep and check that is works FIXME *) + let () = match (FApi.tc1_goal tc).f_node with + | FhoareS sF -> begin try mapreduce_eval (FApi.tc1_hyps tc) sF.hs_m sF.hs_s (inpvs, n) (outvs, m) op range sign with + | BDepError err -> tc_error (FApi.tc1_penv tc) "Error in bdep eval: %s@." (Lazy.force err) + end + (* FIXME PR: Do we throw a decent error here or should this be guarded before the call? *) + | FhoareF sH -> assert false + | FbdHoareF _ -> assert false + | FbdHoareS _ -> assert false + | FeHoareF _ -> assert false + | FeHoareS _ -> assert false + | _ -> assert false + in + FApi.close (!@ tc) VBdep + +let process_bdep_eval (bdeinfo: bdep_eval_info) (tc: tcenv1) = + let { in_ty; out_ty; invs; inpvs; outvs; lane; range; sign } = bdeinfo in + + (* ------------------------------------------------------------------ *) + let env = FApi.tc1_env tc in + let hr = EcLowPhlGoal.tc1_as_hoareS tc in + let hyps = FApi.tc1_hyps tc in + let ppenv = EcPrinting.PPEnv.ofenv env in + let pe = FApi.tc1_penv tc in + + (* ------------------------------------------------------------------ *) + let (@@!) pth args = + try + EcTypesafeFol.f_app_safe env pth args + with EcUnify.UnificationFailure _ -> + tc_error pe + "Type mismatch in pre-post generation, check your lane and precondition types@.\ + Args: %a@." + (fun fmt fs -> List.iter + (fun f -> (Format.fprintf fmt "%a | "(EcPrinting.pp_form ppenv) f)) + fs) + args + in + (* ------------------------------------------------------------------ *) + let (@@!!) pth args = + try + EcTypesafeFol.f_app_safe ~full:false env pth args + with EcUnify.UnificationFailure _ -> + tc_error (FApi.tc1_penv tc) "Type mismatch in pre-post generation, check your lane and precondition types@." + in + + (* DEBUG SECTION *) + (* ------------------------------------------------------------------ *) + let pp_type (fmt: Format.formatter) (ty: ty) = + Format.fprintf fmt "%a" (EcPrinting.pp_type (EcPrinting.PPEnv.ofenv env)) ty in + + (* ------------------------------------------------------------------ *) + let plane, olane = EcEnv.Op.lookup ([], lane.pl_desc) env in + + (* ------------------------------------------------------------------ *) + let in_ty = + let ue = EcUnify.UniEnv.create None in + let ty = EcTyping.transty EcTyping.tp_tydecl env ue in_ty in + assert (EcUnify.UniEnv.closed ue); + ty_subst (Tuni.subst (EcUnify.UniEnv.close ue)) ty + in + + (* ------------------------------------------------------------------ *) + let out_ty = + let ue = EcUnify.UniEnv.create None in + let ty = EcTyping.transty EcTyping.tp_tydecl env ue out_ty in + assert (EcUnify.UniEnv.closed ue); + ty_subst (Tuni.subst (EcUnify.UniEnv.close ue)) ty + in + + (* ------------------------------------------------------------------ *) + let ue = EcUnify.UniEnv.create None in + let env = Memory.push_active_ss hr.hs_m env in + + (* ------------------------------------------------------------------ *) + let range = EcTyping.trans_form env ue range (tconstr EcCoreLib.CI_List.p_list [tint])in + assert (EcUnify.UniEnv.closed ue); + let range = EcCoreSubst.Fsubst.f_subst (Tuni.subst (EcUnify.UniEnv.close ue)) range in + + let frange = try form_list_from_iota hyps range + with BDepError err -> tc_error pe "%s" (Lazy.force err) + in + + + (* ------------------------------------------------------------------ *) + let n, in_to_uint, in_to_sint,in_of_int = + match EcEnv.Circuit.lookup_bitstring env in_ty with + | Some {size = (_, Some size); touint; tosint; ofint} -> size, touint, tosint, ofint + | Some {size = (_, None); _} -> raise (BDepError (lazy "No concrete binding for input")) + | _ -> tc_error pe "No binding for type %a@." pp_type in_ty + in + + let in_to_uint = f_op in_to_uint [] (tfun in_ty tint) in + let in_to_sint = f_op in_to_sint [] (tfun in_ty tint) in + let in_of_int = f_op in_of_int [] (tfun tint in_ty) in + + (* ------------------------------------------------------------------ *) + let m, out_of_int = match EcEnv.Circuit.lookup_bitstring env out_ty with + | Some {size = (_, Some size); ofint} -> size, ofint + | Some {size = (_, None); _} -> raise (BDepError (lazy "No concrete binding for output") ) + | _ -> tc_error pe "No binding for type %a@." pp_type out_ty + in + let out_of_int = f_op out_of_int [] (tfun tint out_ty) in + + (* ------------------------------------------------------------------ *) + let outvs = get_vars env outvs hr.hs_m |> List.fst in + let poutvs = List.map (fun v -> (EcFol.f_pvar (pv_loc v.v_name) v.v_type (fst hr.hs_m)).inv) outvs in + let poutvs = blocks_from_vars env poutvs out_ty in + + (* ------------------------------------------------------------------ *) + let inpvs = get_vars env inpvs hr.hs_m |> List.fst in + let finpvs = List.map (fun v -> (EcFol.f_pvar (pv_loc v.v_name) v.v_type (fst hr.hs_m)).inv) inpvs in + let invs, inv_tys = + let lookup (x : bdepvar) : (ident * ty) list = + let get1 (v : symbol) = + EcEnv.Var.lookup_local v env + in + match x with + | `Var x -> + [get1 (unloc x)] + | `VarRange (x, n) -> + List.init n (fun i -> get1 (Format.sprintf "%s_%d" (unloc x) i)) + | `Slice _ -> tc_error pe "Input slice not currently supported" (* FIXME PR: Do we want to implement this ? *) + + in List.map lookup invs |> List.flatten |> List.split in + + let inty = match List.collapse inv_tys with + | Some ty -> ty + | None -> tc_error pe "Failed to coallesce types for input@." + in + let finvs = List.map (fun id -> f_local id inty) invs in + let pinvs = blocks_from_vars env finvs in_ty in + let pinvs_post = if sign + then EcCoreLib.CI_List.p_map @@! [in_to_sint; pinvs] + else EcCoreLib.CI_List.p_map @@! [in_to_uint; pinvs] + in + let pinvs_post = EcCoreLib.CI_List.p_map @@! [(f_op plane [] olane.op_ty); pinvs_post] in + let pinvs_post = EcCoreLib.CI_List.p_map @@! [out_of_int; pinvs_post] in + + (* ------------------------------------------------------------------ *) + let post = f_eq pinvs_post poutvs in + let pre = EcCoreLib.CI_List.p_all @@! + [(EcCoreLib.CI_List.p_mem @@!! [ + (EcCoreLib.CI_List.p_map @@! [in_of_int; range])]); pinvs] in + + assert (List.compare_lengths inpvs invs = 0); + let pre = f_ands (pre::(List.map2 (fun iv ipv -> f_eq iv ipv) finvs finpvs)) in + + (* let env, hyps, concl = FApi.tc1_eflat tc in *) + let tc = EcPhlConseq.t_hoareS_conseq_nm {inv=pre; m=(fst hr.hs_m)} {inv=post; m=(fst hr.hs_m)} tc in + FApi.t_last (t_bdep_eval n m inpvs outvs lane frange sign) tc + +let rec destr_conj (hyps: hyps) (f: form) : form list = + let redmode = {(circ_red hyps) with zeta = false} in + let f = (EcCallbyValue.norm_cbv redmode hyps f) in + match f.f_node with + | Fapp ({f_node = Fop (p, _)}, fs) -> begin match (EcFol.op_kind p, fs) with + | Some (`And _), _ -> List.flatten @@ List.map (destr_conj hyps) fs + | (None, [f;fs]) when p = EcCoreLib.CI_List.p_all -> + let fs = form_list_from_iota hyps fs in + List.map (fun farg -> f_app f (farg::[]) tbool) fs + | _ -> f::[] + end + | _ -> f::[] + + +(* Should return a list of circuits corresponding to the atomic parts of the pre *) +(* + This means: + /\ p_i => [p_i]_i, + a = b => [a.[i] = b.[i]]_i +*) +(* Returns _open_ circuits *) +let process_pre ?(st : state option) (tc: tcenv1) (f: form) : state * circuit list = + let debug = false in + let env = FApi.tc1_env tc in + let ppe = EcPrinting.PPEnv.ofenv env in + let hyps = FApi.tc1_hyps tc in (* FIXME: should target be specified here? *) + + (* Maybe move this to be a parameter and just supply it from outside *) + let st = match st with + | Some st -> st + | None -> circuit_state_of_hyps hyps + in + + (* Takes in a form of the form /\_i f_i + and returns a list of the conjunction terms [ f_i ]*) + let destr_conj = destr_conj hyps in + + let fs = destr_conj f in + + if debug then Format.eprintf "Destructured conj, obtained:@.%a@." + (EcPrinting.pp_list ";@\n" EcPrinting.(pp_form PPEnv.(ofenv env))) fs; + + (* If f is of the form (a_ = a) (aka prog_var = log_var) + then add it to the state, otherwise do nothing *) + (* FIXME: are all the simplifications necessary ? *) + let rec process_equality (s: state) (f: form) : state = + let f = (EcCallbyValue.norm_cbv (circ_red hyps) hyps f) in + match f.f_node with + | Fapp ({f_node = Fop (p, _);_}, [a; b]) -> begin match EcFol.op_kind p, (EcCallbyValue.norm_cbv (circ_red hyps) hyps a), (EcCallbyValue.norm_cbv (circ_red hyps) hyps b) with + | Some `Eq, {f_node = Fpvar (PVloc pv, m); _}, fv + | Some `Eq, fv, {f_node = Fpvar (PVloc pv, m); _} -> + if debug then Format.eprintf "Adding equality to known information for translation: %a@." EcPrinting.(pp_form PPEnv.(ofenv env)) f; + update_state_pv s m pv (circuit_of_form ~st hyps fv |> snd) + | _ -> s + end + | _ -> s + in + + let st = List.fold_left process_equality st fs in + + (* If convertible to circuit then add to precondition conjunction. + Use state from previous as well *) + let rec process_form (f: form) : circuit list = + match f.f_node with + | Fapp ({f_node = Fop (p, _);_}, [f1; f2]) when EcFol.op_kind p = Some `Eq -> + let hyps, c1 = circuit_of_form ~st hyps (EcCallbyValue.norm_cbv (circ_red hyps) hyps f1) in + let hyps, c2 = circuit_of_form ~st hyps (EcCallbyValue.norm_cbv (circ_red hyps) hyps f2) in + circuit_eqs c1 c2 + | _ -> + begin + if debug then + Format.eprintf "Processing form: %a@.Simplified version: %a@." + EcPrinting.(pp_form ppe) f + EcPrinting.(pp_form ppe) (EcCallbyValue.norm_cbv (circ_red hyps) hyps f); + try (circuit_of_form ~st hyps (EcCallbyValue.norm_cbv (circ_red hyps) hyps f) |> snd)::[] with + e -> begin if debug then Format.eprintf "Encountered exception when converting part of the pre to circuit: %s@." (Printexc.to_string e); + [] end + end + in + + let cs = List.fold_left (fun acc f -> List.rev_append (process_form f) acc) [] fs |> List.rev in +(* + if debug then Format.eprintf "Translated as much as possible from pre to circuits, got:@.%a@\n" + (EcPrinting.pp_list "@\n@\n" pp_circuit) cs; +*) + + if debug then Format.eprintf "In the context of the following bindings in the environment:@\n%a@\n" + (EcPrinting.pp_list "@\n@\n" (fun fmt cinp -> Format.eprintf "%a@." pp_cinp cinp)) (state_lambdas st); + st, cs + +let solve_post ~(st: state) ~(pres: circuit list) (hyps: hyps) (post: form) : bool = + let destr_conj = destr_conj hyps in + let posts = destr_conj post in + + List.for_all (fun post -> + if debug then Format.eprintf "Solving post: %a@." + EcPrinting.(pp_form PPEnv.(ofenv (toenv hyps))) post; + match post.f_node with + | Fapp ({f_node= Fop(p, _); _}, [f1; f2]) -> + begin match EcFol.op_kind p with + | Some `Eq -> + circuit_simplify_equality ~st ~hyps ~pres f1 f2 + | _ -> circuit_of_form ~st hyps post |> snd |> state_close_circuit st |> circ_taut + end + | _ -> circuit_of_form ~st hyps post |> snd |> state_close_circuit st |> circ_taut + ) posts + +(* TODO: Figure out how to not repeat computations here? *) +let t_bdep_solve + (tc : tcenv1) = + let time (env: env) (t: float) (msg: string) : float = + let new_t = Unix.gettimeofday () in + EcEnv.notify ~immediate:true env `Info "[W] %s, took %f s@." msg (new_t -. t); + new_t + in + + + begin + let hyps = (FApi.tc1_hyps tc) in + let goal = (FApi.tc1_goal tc) in + match goal.f_node with + | FhoareS {hs_m; hs_pr; hs_po; hs_s} -> begin try + let tm = Unix.gettimeofday () in + let st, cpres = process_pre tc hs_pr in + let tm = time (toenv hyps) tm "Done with precondition processing" in + + let hyps, st = state_of_prog hyps (fst hs_m) ~st hs_s.s_node [] in + let _tm = time (toenv hyps) tm "Done with program circuit gen" in + let res = solve_post ~st ~pres:cpres hyps hs_po in + EcCircuits.clear_translation_caches (); + if res then + FApi.close (!@ tc) VBdep + else + raise (BDepError (lazy "Failed to verify postcondition")) + with + | BDepError le + | CircError le -> + tc_error (FApi.tc1_penv tc) "%s" (Lazy.force le) + end + | FequivS { es_ml; es_mr; es_pr; es_sl; es_sr; es_po } -> + begin + try ( + let tm = Unix.gettimeofday () in + (* FIXME: rework this *) + let st = circuit_state_of_memenv ~st:empty_state (FApi.tc1_env tc) es_ml in + let st = circuit_state_of_memenv ~st (FApi.tc1_env tc) es_mr in +(* let st = circuit_state_of_hyps ~st (FApi.tc1_hyps tc) in *) + let st, cpres = process_pre ~st tc es_pr in + let tm = time (toenv hyps) tm "Done with precondition processing" in + + (* Circuits from pvars are tagged by memory so we can just put everything in one state *) + let hyps, st = state_of_prog ~me:es_ml hyps (fst es_ml) ~st es_sl.s_node [] in + let tm = time (toenv hyps) tm "Done with left program circuit gen" in + let hyps, st = state_of_prog ~me:es_mr hyps (fst es_mr) ~st es_sr.s_node [] in + let _tm = time (toenv hyps) tm "Done with right program circuit gen" in + + (if solve_post ~st ~pres:cpres hyps es_po + then FApi.close (!@ tc) VBdep else + raise (BDepError (lazy "Failed to verify postcondition"))) + ) + with + | BDepError le + | CircError le -> + tc_error (FApi.tc1_penv tc) "%s" (Lazy.force le) + end + | _ -> + let ctxt = tohyps hyps in + assert (ctxt.h_tvar = []); + let st = circuit_state_of_hyps hyps in + let cgoal = (circuit_of_form ~st hyps goal |> snd |> state_close_circuit st) in + if debug then Format.eprintf "goal: %a@." pp_flatcirc (fst cgoal).reg; + if circ_taut cgoal then + FApi.close (!@ tc) VBdep + else + tc_error (FApi.tc1_penv tc) "Failed to solve goal through circuit reasoning@\n" + end + +let t_bdep_simplify (tc: tcenv1) = + let time (env: env) (t: float) (msg: string) : float = + let new_t = Unix.gettimeofday () in + EcEnv.notify ~immediate:true env `Info "[W] %s, took %f s@." msg (new_t -. t); + Format.eprintf "[W] %s, took %f s@." msg (new_t -. t); + new_t + in + let hyps = (FApi.tc1_hyps tc) in + let goal = (FApi.tc1_goal tc) in + let env = (FApi.tc1_env tc) in + match goal.f_node with + | FhoareS {hs_m=(m, me) as hs_m; hs_pr; hs_po; hs_s} -> +(* begin try *) + let tm = Unix.gettimeofday () in + let st = circuit_state_of_hyps ~use_mem:true hyps in + let st = circuit_state_of_memenv ~st env hs_m in + let st, pres = process_pre ~st tc hs_pr in + let tm = time env tm "Done with precondition processing" in + + + let hyps, st = try + EcCircuits.state_of_prog ~st hyps (fst hs_m) hs_s.s_node [] + with CircError (lazy err) -> + tc_error (FApi.tc1_penv tc) "CircError: @.%s" err + in + let post = EcCallbyValue.norm_cbv (circ_red hyps) hyps hs_po in + (* + if debug then Format.eprintf "[W] Post after simplify (before circuit pass):@. %a@." + EcPrinting.(pp_form PPEnv.(ofenv env)) post; + *) + let tm = time env tm "Done with first simplify" in + let f = EcCircuits.circ_simplify_form_bitstring_equality ~st ~pres hyps post in + let tm = time env tm "Done with circuit simplify" in + let f = EcCallbyValue.norm_cbv (EcReduction.full_red) hyps f in + let tm = time env tm "Done with second simplify" in + let new_goal = f_hoareS (snd hs_m) {inv=hs_pr; m} hs_s {inv=f; m} in + (* + if debug then Format.eprintf "[W] Goal after simplify:@. %a@." + EcPrinting.(pp_form PPEnv.(ofenv env)) new_goal; + *) + + FApi.mutate1 tc (fun _ -> VBdep) new_goal |> FApi.tcenv_of_tcenv1 +(* + with CircError err -> + tc_error (FApi.tc1_penv tc) "CircError: %s@." (Lazy.force err) + end +*) + | _ -> assert false (* FIXME : TODO *) + + + +(* ================ EXTENS TACTIC ==================== *) +(* FIXME: Maybe move later? *) +open FApi +let t_extens (v: string option) (tt : backward) (tc : tcenv1) = + (* Find goal shape + -> generate one goal for each value + -> solve goal by applying the tactic + *) + + let open EcAst in + + let tm = Unix.gettimeofday () in + + let solved = ref 0 in + + let rec do_all (goals: form list) = + match goals with + | [] -> None + | goal::goals -> + let new_tc = mutate1 tc (fun _ -> VBdep) goal in + match (t_try_base tt new_tc) with + | `Failure e -> + tc_error_exn (tc1_penv tc) e + | `Success new_tc -> + match tc_opened new_tc with + | [] -> + incr solved; + (* EcEnv.notify ~immediate:true (tc1_env tc) `Warning "Solved goal %d@." !solved; *) + do_all goals + | hd::_ -> + Some (get_pregoal_by_id hd (tc_penv new_tc)).g_concl + in + + let subst_pv_stmt ?(redmode: EcReduction.reduction_info option) (hyps: LDecl.hyps) (mem: memory) (sb: EcPV.PVM.subst) (s: stmt) = + let redmode = Option.default (circ_red hyps) redmode in + let env = LDecl.toenv hyps in + stmt (List.map (fun i -> + match i.i_node with + | Sasgn (lv, e) -> + let f = (ss_inv_of_expr mem e) in + let fi = EcPV.PVM.subst env sb f.inv in + let fi = EcCallbyValue.norm_cbv redmode hyps fi in + let e = try expr_of_ss_inv {f with inv=fi} + with CannotTranslate -> + Format.eprintf "Failed on form : %a@." + EcPrinting.(pp_form PPEnv.(ofenv env)) fi; + raise CannotTranslate + in + EcCoreModules.i_asgn (lv, e) + | _ -> raise (CannotTranslate) (* FIXME: Errors *) + + ) s.s_node) + in + + let goals = match (tc1_goal tc).f_node, v with + | Fapp ({f_node = Fop (p, [tint]); _}, [fpred; flist]), None when p = EcCoreLib.CI_List.p_all -> + Format.eprintf "[W] Found list all@."; + begin match flist.f_node with + | Fapp ({f_node = Fop (p, []); _}, [fstart; flen]) when p = EcCoreLib.CI_List.p_iota -> + let start = match fstart.f_node with + | Fint i -> EcBigInt.to_int i + | _ -> tc_error (tc1_penv tc) "Iota start should be constant" + in + + let len = match flen.f_node with + | Fint i -> EcBigInt.to_int i + | _ -> tc_error (tc1_penv tc) "Iota length should be constant" + in + + let goals = List.init len (fun i -> + EcTypesafeFol.fapply_safe (tc1_hyps tc) fpred [f_int EcBigInt.(of_int (i + start))] + ) in + + Format.eprintf "[w] Got iota => [%d, %d)@.Goals: %a@." start len + EcPrinting.(pp_list " | " (pp_form PPEnv.(ofenv (tc1_env tc)))) goals; + goals + + | _ -> tc_error (tc1_penv tc) "Unsupported List pattern" + end + | FhoareS ({ hs_m=(m, mt); hs_s; hs_pr; hs_po }), Some v -> + let v = match EcMemory.lookup v mt with + | Some (v, _, _) -> v + | None -> tc_error (tc1_penv tc) "Failed to find var %s in memory %s" v (EcIdent.name m) + in + (* FIXME: Assumes is not array, fix later *) + let size = match EcEnv.Circuit.lookup_bitstring_size (tc1_env tc) v.v_type with + | Some size -> size + | None -> tc_error (tc1_penv tc) "Failed to get size for type %a (is it finite and does it have a binding?)" + EcPrinting.(pp_type PPEnv.(ofenv (tc1_env tc))) v.v_type + in + let tpath = match v.v_type.ty_node with + | Tconstr (p, _ ) -> p + | _ -> tc_error (tc1_penv tc) "Failed to destructure var type" + in + let of_int = match EcEnv.Circuit.reverse_type (tc1_env tc) tpath with + | [] -> tc_error (tc1_penv tc) "No bindings found for type of var" + | `Bitstring { ofint }::_ -> ofint + | _ -> tc_error (tc1_penv tc) "FIXME: Unhandled case" + in + let ngoals = 1 lsl size in +(* let ngoals = min ngoals 5 in *) + List.init ngoals (fun i -> (* FIXME FIXME this is bad *) + let subst = EcPV.PVM.(add (tc1_env tc) (PVloc v.v_name) m + (EcTypesafeFol.f_app_safe (tc1_env tc) of_int [f_int BI.(of_int i)]) empty) + in + let s = subst_pv_stmt (tc1_hyps tc) m subst hs_s in + let subst = EcPV.PVM.subst (tc1_env tc) subst in + let pr = subst hs_pr in + let po = subst hs_po in + let goal = f_hoareS mt ({inv=pr;m}) s ({inv=po;m}) in + if debug then + ( + Format.eprintf "[W] Generated goal %d@." i; +(* + Format.eprintf "%a@." + EcPrinting.(pp_form PPEnv.(ofenv (tc1_env tc))) goal +*) + ); + goal + ) + + | _ -> tc_error (tc1_penv tc) "Wrong goal shape@." + in + + match do_all goals with + | None -> + EcEnv.notify ~immediate:true (tc1_env tc) `Warning "[W] Extens took %f seconds@." (Unix.gettimeofday () -. tm); + close (tcenv_of_tcenv1 tc) VBdep + | Some gfail -> + tc_error (tc1_penv tc) "Failed to close goal:@. %a@." + EcPrinting.(pp_form PPEnv.(ofenv (tc1_env tc))) gfail + false + + diff --git a/src/phl/ecPhlBDep.mli b/src/phl/ecPhlBDep.mli new file mode 100644 index 0000000000..222239a59b --- /dev/null +++ b/src/phl/ecPhlBDep.mli @@ -0,0 +1,30 @@ +(* -------------------------------------------------------------------- *) +open EcParsetree +open EcCoreGoal +open EcAst + +(* -------------------------------------------------------------------- *) +(* val bdep : env -> stmt -> psymbol -> variable list -> int -> variable list -> int-> psymbol -> unit *) + +val t_bdep_form : form -> variable -> tcenv1 -> tcenv + +val t_bdep : ?debug:bool -> int -> int -> (variable * (int * int) option) list -> (variable * (int * int) option) list -> psymbol -> psymbol -> (int -> int) option -> tcenv1 -> tcenv + +val t_bdepeq : variable list * variable list -> int -> (int * variable list * variable list) list -> form option -> bool -> tcenv1 -> tcenv + +val t_bdep_eval : int -> int -> variable list -> variable list -> psymbol -> form list -> bool -> tcenv1 -> tcenv + +val t_bdep_solve : tcenv1 -> tcenv + +val t_bdep_simplify : tcenv1 -> tcenv + +val t_extens : string option -> FApi.backward -> FApi.backward + +val process_bdep : bdep_info -> tcenv1 -> tcenv + +val process_bdepeq : bdepeq_info -> tcenv1 -> tcenv + +val process_bdep_form : pformula -> bdepvar -> tcenv1 -> tcenv + +val process_bdep_eval : bdep_eval_info -> tcenv1 -> tcenv + diff --git a/src/phl/ecPhlCodeTx.ml b/src/phl/ecPhlCodeTx.ml index f3ec2c513c..191344a0a8 100644 --- a/src/phl/ecPhlCodeTx.ml +++ b/src/phl/ecPhlCodeTx.ml @@ -185,6 +185,7 @@ let t_set_match_r (side : oside) (cpos : Position.codepos) (id : symbol) pattern (t_zip (set_match_stmt id pattern)) tc (* -------------------------------------------------------------------- *) +(* FIXME: have a better handling of PV *) let cfold_stmt ?(simplify = true) (pf, hyps) (me : memenv) (olen : int option) (zpr : Zpr.zipper) = let env = LDecl.toenv hyps in @@ -196,19 +197,23 @@ let cfold_stmt ?(simplify = true) (pf, hyps) (me : memenv) (olen : int option) ( e ) else identity in - let is_const_expression (e : expr) = - PV.is_empty (e_read env e) in - let for_instruction ((subst as subst0) : (expr, unit) Mpv.t) (i : instr) = let wr = EcPV.i_write env i in let i = Mpv.isubst env subst i in let (subst, asgn) = - List.fold_left_map (fun subst ((pv, _) as pvty) -> - match Mpv.find env pv subst with - | e -> Mpv.remove env pv subst, Some (pvty, e) - | exception Not_found -> subst, None - ) subst (fst (PV.elements wr)) in + List.fold_left_map (fun subst (pv, e) -> + let exception Remove in + + try + if PV.mem_pv env pv wr then raise Remove; + let rd = EcPV.e_read env e in + if PV.mem_pv env pv rd then raise Remove; + subst, None + + with Remove -> + Mpv.remove env pv subst, Some ((pv, e.e_ty), e) + ) subst (EcPV.Mnpv.bindings (Mpv.pvs subst)) in let asgn = List.filter_map identity asgn in @@ -227,7 +232,7 @@ let cfold_stmt ?(simplify = true) (pf, hyps) (me : memenv) (olen : int option) ( try match i.i_node with | Sasgn (lv, e) -> - (* We already removed the variables of `lv` from the substitution *) + (* We already removed the variables of `lv` & the rhs from the substitution *) (* We are only interested in the variables of `lv` that are in `wr` *) let es = match simplify e, lv with @@ -238,7 +243,7 @@ let cfold_stmt ?(simplify = true) (pf, hyps) (me : memenv) (olen : int option) ( let lv = lv_to_ty_list lv in let tosubst, asgn2 = List.partition (fun ((pv, _), e) -> - Mpv.mem env pv subst0 && is_const_expression e + Mpv.mem env pv subst0 ) (List.combine lv es) in let subst = @@ -292,9 +297,6 @@ let cfold_stmt ?(simplify = true) (pf, hyps) (me : memenv) (olen : int option) ( | e, _ -> [e] in let lv = lv_to_ty_list lv in - if not (List.for_all is_const_expression es) then - tc_error pf "right-values are not closed expressions"; - if not (List.for_all (is_loc |- fst) lv) then tc_error pf "left-values must be made of local variables only"; @@ -435,7 +437,14 @@ let process_case ((side, pos) : side option * pcodepos) (tc : tcenv1) = let lv, e = destr_asgn i in - let pvl = EcPV.lp_write env lv in + let pvl = (* FIXME: do we want to do this TCB-wise? *) + match lv with + | LvVar _ -> PV.empty + | LvTuple lvs -> + let lvs = List.tl (List.rev lvs) in + let lvs = Option.get (lv_of_list lvs) in + EcPV.lp_write env lvs in + let pve = EcPV.e_read env e in let lv = lv_to_list lv in @@ -446,7 +455,11 @@ let process_case ((side, pos) : side option * pcodepos) (tc : tcenv1) = match lv, e.e_node with | [_], _ -> [e] | _ , Etuple es -> es - | _ ,_ -> assert false in + | _ ,_ -> + let tys = + match (EcEnv.Ty.hnorm e.e_ty env).ty_node with + | Ttuple tys -> tys | _ -> assert false in + List.mapi (fun i ty -> e_proj e i ty) tys in let s = List.map2 (fun pv e -> i_asgn (LvVar (pv, e.e_ty), e)) lv e in diff --git a/src/phl/ecPhlEqobs.ml b/src/phl/ecPhlEqobs.ml index 8466ac9351..b13c46c607 100644 --- a/src/phl/ecPhlEqobs.ml +++ b/src/phl/ecPhlEqobs.ml @@ -1,6 +1,9 @@ (* -------------------------------------------------------------------- *) open EcUtils +open EcPath +open EcParsetree open EcAst +open EcMatching.Position open EcTypes open EcModules open EcFol @@ -13,6 +16,16 @@ open EcLowPhlGoal module TTC = EcProofTyping +(* -------------------------------------------------------------------- *) +type sim_info = { + sim_pos : codepos1 pair option; + sim_hint : (xpath option * xpath option * EcPV.Mpv2.t) list * ts_inv option; + sim_eqs : EcPV.Mpv2.t option; +} + +let empty_sim_info : sim_info = + { sim_pos = None; sim_hint = ([], None); sim_eqs = None; } + (* -------------------------------------------------------------------- *) let extend_body fsig body = let arg = pv_arg in @@ -398,7 +411,7 @@ let t_eqobs_inS_r sim eqo tc = tc_error !!tc "cannot apply sim"; let sg = List.map (mk_inv_spec env inv) sim.needed_spec in - let concl = f_equivS (snd es.es_ml) (snd es.es_mr) (es_pr es) sl sr pre in + let concl = f_equivS (snd es.es_ml) (snd es.es_mr) (es_pr es) sl sr pre in FApi.xmutate1 tc `EqobsIn (sg @ [concl]) @@ -424,53 +437,62 @@ let t_eqobs_inF_r sim eqo tc = let t_eqobs_inF = FApi.t_low2 "eqobs-in" t_eqobs_inF_r (* -------------------------------------------------------------------- *) -let process_eqs env tc f = - try - Mpv2.of_form env f - with Not_found -> - tc_error_lazy !!tc (fun fmt -> - let ppe = EcPrinting.PPEnv.ofenv env in - Format.fprintf fmt - "cannot recognize %a as a set of equalities" - (EcPrinting.pp_form ppe) f.inv) +let process_eqs (pe : proofenv) (env : env) (f : ts_inv) = + try + Mpv2.of_form env f + with Not_found -> + tc_error_lazy pe (fun fmt -> + let ppe = EcPrinting.PPEnv.ofenv env in + Format.fprintf fmt + "cannot recognize %a as a set of equalities" + (EcPrinting.pp_form ppe) f.inv) (* -------------------------------------------------------------------- *) -let process_hint ml mr tc hyps (feqs, inv) = +let process_hint ml mr (pe : proofenv) (hyps : LDecl.hyps) (feqs, inv : _ * _) = let env = LDecl.toenv hyps in let ienv = LDecl.push_active_ts (EcMemory.abstract ml) (EcMemory.abstract mr) hyps in - let doinv pf = {ml;mr;inv=TTC.pf_process_form !!tc ienv tbool pf} in - let doeq pf = process_eqs env tc (doinv pf) in + let doinv pf = {ml;mr;inv=TTC.pf_process_form pe ienv tbool pf} in + let doeq pf = process_eqs pe env (doinv pf) in let dof g = omap (EcTyping.trans_gamepath env) g in let geqs = - List.map (fun ((f1,f2),geq) -> dof f1, dof f2, doeq geq) + List.map + (fun ((f1, f2), geq) -> dof f1, dof f2, doeq geq) feqs in - let ginv = odfl {ml;mr;inv=f_true} (omap doinv inv) in + let ginv = (omap doinv inv) in (* FIXME: check *) geqs, ginv (* -------------------------------------------------------------------- *) -let process_eqobs_inS info tc = +let pre_eqobs (cm : crushmode) (tc : tcenv1) = + let dt, ts = EcHiGoal.process_crushmode cm in + EcPhlConseq.t_conseqauto ~delta:dt ?tsolve:ts tc + +(* -------------------------------------------------------------------- *) +let t_eqobs_inS_ (info : sim_info) (tc : tcenv1) = let env, hyps, _ = FApi.tc1_eflat tc in let es = tc1_as_equivS tc in - let ml, mr = fst es.es_ml, fst es.es_mr in - let spec, inv = process_hint ml mr tc hyps info.EcParsetree.sim_hint in + let spec, inv = info.sim_hint in + + let inv = match inv with + | Some inv -> inv + | None -> let ml, mr = fst es.es_ml, fst es.es_mr in + {ml;mr;inv=f_true} + in + let eqo = - match info.EcParsetree.sim_eqs with - | Some pf -> - process_eqs env tc (TTC.tc1_process_prhl_formula tc pf) - | None -> - try Mpv2.needed_eq env (es_po es) - with Not_found -> tc_error !!tc "cannot infer the set of equalities" in - let post = Mpv2.to_form_ts_inv eqo inv in + match info.sim_eqs with Some eqo -> eqo | None -> + try Mpv2.needed_eq env (es_po es) + with _ -> tc_error !!tc "cannot infer the set of equalities" in + let sim = init_sim env spec inv in + let post = Mpv2.to_form_ts_inv eqo inv in + let t_main tc = - match info.EcParsetree.sim_pos with + match info.sim_pos with | None -> FApi.t_last (FApi.t_try (FApi.t_seq EcPhlSkip.t_skip t_trivial)) (t_eqobs_inS sim eqo tc) | Some(p1,p2) -> - let p1 = EcLowPhlGoal.tc1_process_codepos1 tc (Some `Left , p1) in - let p2 = EcLowPhlGoal.tc1_process_codepos1 tc (Some `Right, p2) in let _,sl2 = s_split env p1 es.es_sl in let _,sr2 = s_split env p2 es.es_sr in let _, eqi = @@ -485,49 +507,106 @@ let process_eqobs_inS info tc = ]) tc in (EcPhlConseq.t_equivS_conseq (es_pr es) post @+ [t_trivial; - t_trivial; - t_main]) tc + t_trivial; + t_main]) tc + +(* -------------------------------------------------------------------- *) +let t_eqobs_inS (cm : crushmode option) (info : sim_info) (tc : tcenv1) = + FApi.t_last (t_eqobs_inS_ info) ((omap pre_eqobs cm |> odfl t_id) tc) (* -------------------------------------------------------------------- *) -let process_eqobs_inF info tc = - if info.EcParsetree.sim_pos <> None then - tc_error !!tc "no positions excepted"; +let process_eqobs_inS (cm : crushmode option) (info : psim_info) (tc : tcenv1) = + let env, hyps, _ = FApi.tc1_eflat tc in + let es = tc1_as_equivS tc in + let ml, mr = fst es.es_ml, fst es.es_mr in + let sim_hint = process_hint ml mr !!tc hyps info.psim_hint in + let sim_eqs = + let process pf = + process_eqs !!tc env (TTC.tc1_process_prhl_formula tc pf) + in Option.map process info.psim_eqs in + let sim_pos = + info.psim_pos + |> Option.map (pair_map (EcTyping.trans_codepos1 env)) + in + + let info = { sim_pos; sim_hint; sim_eqs; } in + + t_eqobs_inS cm info tc + +(* -------------------------------------------------------------------- *) +let t_eqobs_inF_ (info : sim_info) (tc : tcenv1) = + assert (Option.is_none info.sim_pos); + let env, hyps, _ = FApi.tc1_eflat tc in let ef = tc1_as_equivF tc in - let ml, mr = ef.ef_ml, ef.ef_mr in - let spec, inv = process_hint ml mr tc hyps info.EcParsetree.sim_hint in let fl = ef.ef_fl and fr = ef.ef_fr in + + let spec, inv = info.sim_hint in + let eqo = - match info.EcParsetree.sim_eqs with - | Some pf -> - let _,(mle,mre) = Fun.equivF_memenv ml mr fl fr env in - let hyps = LDecl.push_active_ts mle mre hyps in - process_eqs env tc {ml; mr; inv=TTC.pf_process_form !!tc hyps tbool pf} - | None -> + match info.sim_eqs with Some eqo -> eqo | None -> try Mpv2.needed_eq env (ef_po ef) with _ -> tc_error !!tc "cannot infer the set of equalities" in + let eqo = Mpv2.remove env pv_res pv_res eqo in + + let inv = match inv with + | Some inv -> inv + | None -> let ml, mr = ef.ef_ml, ef.ef_mr in + {ml;mr;inv=f_true} + in + let sim = init_sim env spec inv in let _, eqi = try f_eqobs_in fl fr sim eqo with EqObsInError -> tc_error !!tc "not able to process" in let ef' = destr_equivF (mk_inv_spec2 env inv (fl, fr, eqi, eqo)) in + (EcPhlConseq.t_equivF_conseq (ef_pr ef') (ef_po ef') @+ [ t_trivial; t_trivial; t_eqobs_inF sim eqo]) tc (* -------------------------------------------------------------------- *) -let process_eqobs_in cm info tc = - let prett cm tc = - let dt, ts = EcHiGoal.process_crushmode cm in - EcPhlConseq.t_conseqauto ~delta:dt ?tsolve:ts tc in - let tt tc = - let concl = FApi.tc1_goal tc in - match concl.f_node with - | FequivF _ -> process_eqobs_inF info tc - | FequivS _ -> process_eqobs_inS info tc - | _ -> tc_error_noXhl ~kinds:[`Equiv `Any] !!tc - in +let t_eqobs_inF (cm : crushmode option) (info : sim_info) (tc : tcenv1) = + FApi.t_last (t_eqobs_inF_ info) ((omap pre_eqobs cm |> odfl t_id) tc) + +(* -------------------------------------------------------------------- *) +let process_eqobs_inF (cm : crushmode option) (info : psim_info) (tc : tcenv1) = + if Option.is_some info.psim_pos then + tc_error !!tc "no positions excepted"; - FApi.t_last tt ((omap prett cm |> odfl t_id) tc) + let env, hyps, _ = FApi.tc1_eflat tc in + let ef = tc1_as_equivF tc in + let ml, mr = ef.ef_ml, ef.ef_mr in + let sim_hint = process_hint ml mr !!tc hyps info.psim_hint in + let fl = ef.ef_fl and fr = ef.ef_fr in + let sim_eqs = + let process pf = + let _,(mle,mre) = Fun.equivF_memenv ml mr fl fr env in + let hyps = LDecl.push_active_ts mle mre hyps in + process_eqs !!tc env {ml; mr; inv=TTC.pf_process_form !!tc hyps tbool pf} + in Option.map process info.psim_eqs in + + let info = { sim_pos = None; sim_hint; sim_eqs; } in + + t_eqobs_inF cm info tc + +(* -------------------------------------------------------------------- *) +let process_eqobs_in (cm : crushmode option) (info : psim_info) (tc : tcenv1) = + let concl = FApi.tc1_goal tc in + match concl.f_node with + | FequivF _ -> process_eqobs_inF cm info tc + | FequivS _ -> process_eqobs_inS cm info tc + | _ -> tc_error_noXhl ~kinds:[`Equiv `Any] !!tc + +(* -------------------------------------------------------------------- *) +let t_eqobs_in_r (cm : crushmode option) (info : sim_info) (tc : tcenv1) = + let concl = FApi.tc1_goal tc in + match concl.f_node with + | FequivF _ -> t_eqobs_inF cm info tc + | FequivS _ -> t_eqobs_inS cm info tc + | _ -> tc_error_noXhl ~kinds:[`Equiv `Any] !!tc + +(* -------------------------------------------------------------------- *) +let t_eqobs_in = FApi.t_low2 "eqobs-in" t_eqobs_in_r diff --git a/src/phl/ecPhlEqobs.mli b/src/phl/ecPhlEqobs.mli index d210124949..2da8b44476 100644 --- a/src/phl/ecPhlEqobs.mli +++ b/src/phl/ecPhlEqobs.mli @@ -1,7 +1,20 @@ (* -------------------------------------------------------------------- *) - +open EcUtils +open EcPath open EcParsetree +open EcAst +open EcMatching.Position open EcCoreGoal.FApi (* -------------------------------------------------------------------- *) -val process_eqobs_in : crushmode option -> sim_info -> backward +type sim_info = { + sim_pos : codepos1 pair option; + sim_hint : (xpath option * xpath option * EcPV.Mpv2.t) list * ts_inv option; + sim_eqs : EcPV.Mpv2.t option; +} + +val empty_sim_info : sim_info + +(* -------------------------------------------------------------------- *) +val t_eqobs_in : crushmode option -> sim_info -> backward +val process_eqobs_in : crushmode option -> psim_info -> backward diff --git a/src/phl/ecPhlLoopTx.ml b/src/phl/ecPhlLoopTx.ml index 434dece2ce..7eb8105155 100644 --- a/src/phl/ecPhlLoopTx.ml +++ b/src/phl/ecPhlLoopTx.ml @@ -20,7 +20,7 @@ module TTC = EcProofTyping (* -------------------------------------------------------------------- *) type fission_t = oside * pcodepos * (int * (int * int)) type fusion_t = oside * pcodepos * (int * (int * int)) -type unroll_t = oside * pcodepos * bool +type unroll_t = oside * pcodepos * [`While | `For of bool] type splitwhile_t = pexpr * oside * pcodepos (* -------------------------------------------------------------------- *) @@ -220,7 +220,7 @@ let process_splitwhile (b, side, cpos) tc = t_splitwhile b side cpos tc (* -------------------------------------------------------------------- *) -let process_unroll_for side cpos tc = +let process_unroll_for ~cfold side cpos tc = let env = FApi.tc1_env tc in let hyps = FApi.tc1_hyps tc in let (goal_m, _), c = EcLowPhlGoal.tc1_get_stmt side tc in @@ -305,7 +305,7 @@ let process_unroll_for side cpos tc = let t_conseq_nm tc = match (tc1_get_pre tc) with - | Inv_ss inv -> + | Inv_ss inv -> (EcPhlConseq.t_hoareS_conseq_nm inv {m=inv.m;inv=f_true} @+ [ t_trivial; t_trivial; EcPhlTAuto.t_hoare_true]) tc | _ -> tc_error !!tc "expecting single sided precondition" in @@ -327,16 +327,19 @@ let process_unroll_for side cpos tc = let tcenv = t_doit 0 pos zs tc in let tcenv = FApi.t_onalli doi tcenv in - let cpos = EcMatching.Position.shift ~offset:(-1) cpos in - let clen = blen * (List.length zs - 1) in + if cfold then begin + let cpos = EcMatching.Position.shift ~offset:(-1) cpos in + let clen = blen * (List.length zs - 1) in - FApi.t_last (EcPhlCodeTx.t_cfold side cpos (Some clen)) tcenv + FApi.t_last (EcPhlCodeTx.t_cfold side cpos (Some clen)) tcenv + end else tcenv (* -------------------------------------------------------------------- *) let process_unroll (side, cpos, for_) tc = - if for_ then - process_unroll_for side cpos tc - else begin + match for_ with + | `While -> let cpos = EcLowPhlGoal.tc1_process_codepos tc (side, cpos) in t_unroll side cpos tc - end + + | `For cfold -> + process_unroll_for ~cfold:(not cfold) side cpos tc diff --git a/src/phl/ecPhlLoopTx.mli b/src/phl/ecPhlLoopTx.mli index 8d619f9afd..994b447db2 100644 --- a/src/phl/ecPhlLoopTx.mli +++ b/src/phl/ecPhlLoopTx.mli @@ -13,10 +13,10 @@ val t_splitwhile : expr -> oside -> codepos -> backward (* -------------------------------------------------------------------- *) type fission_t = oside * pcodepos * (int * (int * int)) type fusion_t = oside * pcodepos * (int * (int * int)) -type unroll_t = oside * pcodepos * bool +type unroll_t = oside * pcodepos * [`While | `For of bool] type splitwhile_t = pexpr * oside * pcodepos -val process_unroll_for : oside -> pcodepos -> backward +val process_unroll_for : cfold:bool -> oside -> pcodepos -> backward val process_fission : fission_t -> backward val process_fusion : fusion_t -> backward val process_unroll : unroll_t -> backward diff --git a/src/phl/ecPhlOutline.ml b/src/phl/ecPhlOutline.ml index 2898a138b5..11b9181f03 100644 --- a/src/phl/ecPhlOutline.ml +++ b/src/phl/ecPhlOutline.ml @@ -7,12 +7,46 @@ open EcCoreGoal.FApi open EcLowPhlGoal (*---------------------------------------------------------------------------------------*) +(* FIXME PR: Remove? *) +let t_outline_stmt side start_pos end_pos s tc = + let env = FApi.tc1_env tc in + let goal = tc1_as_equivS tc in + + (* Check which memory/program we are outlining *) + let code = match side with + | `Left -> goal.es_sl + | `Right -> goal.es_sr + in + + (* Extract the program prefix and suffix *) + let rest, code_suff = s_split env end_pos code in + let code_pref, _, _ = s_split_i env start_pos (stmt rest) in + + let new_prog = s_seq (s_seq (stmt code_pref) s) (stmt code_suff) in + let tc = EcPhlTrans.t_equivS_trans_eq side new_prog tc in + + (* The middle goal, showing equivalence with the replaced code, ideally solves. *) + let tp = match side with | `Left -> 1 | `Right -> 2 in + let p = EcHiGoal.process_tfocus tc (Some [Some tp, Some tp], None) in + let tc = + t_onselect + p + (t_try ( + t_seqs [ + EcPhlInline.process_inline (`ByName (None, None, ([], None))); + EcPhlEqobs.t_eqobs_in None EcPhlEqobs.empty_sim_info; + EcPhlAuto.t_auto; + EcHiGoal.process_done; + ])) + tc + in + tc (* `by inline; sim; auto=> />` *) let t_auto_equiv_sim = t_seqs [ EcPhlInline.process_inline (`ByName (None, None, ([], None))); - EcPhlEqobs.process_eqobs_in None {sim_pos = None; sim_hint = ([], None); sim_eqs = None}; + EcPhlEqobs.process_eqobs_in None {psim_pos = None; psim_hint = ([], None); psim_eqs = None}; EcPhlAuto.t_auto; EcLowGoal.t_crush; EcHiGoal.process_done; diff --git a/src/phl/ecPhlRCond.mli b/src/phl/ecPhlRCond.mli index 87306ed994..099093971f 100644 --- a/src/phl/ecPhlRCond.mli +++ b/src/phl/ecPhlRCond.mli @@ -24,5 +24,5 @@ val t_rcond : oside -> bool -> codepos1 -> backward val process_rcond : oside -> bool -> pcodepos1 -> backward (* -------------------------------------------------------------------- *) +val t_rcond_match : oside -> symbol -> codepos1 -> backward val process_rcond_match : oside -> symbol -> pcodepos1 -> backward -val t_rcond_match : oside -> symbol -> codepos1 -> backward diff --git a/src/phl/ecPhlRewrite.ml b/src/phl/ecPhlRewrite.ml index 57abab3d51..bc28749e7c 100644 --- a/src/phl/ecPhlRewrite.ml +++ b/src/phl/ecPhlRewrite.ml @@ -1,11 +1,13 @@ (* -------------------------------------------------------------------- *) open EcParsetree +open EcUtils open EcAst open EcCoreGoal open EcEnv open EcModules open EcFol open Batteries +open EcLowPhlGoal (* -------------------------------------------------------------------- *) let t_change @@ -169,26 +171,71 @@ let process_rewrite | `Rw rw -> process_rewrite_rw side pos rw tc | `Simpl -> process_rewrite_simpl side pos tc +let rec pvtail (env: env) (pvs : EcPV.PV.t) (zp : Zpr.ipath) = + let parent = + match zp with + | Zpr.ZTop -> None + | Zpr.ZWhile (_, p) -> Some p + | Zpr.ZIfThen (e, p, _) -> Some p + | Zpr.ZIfElse (e, _, p) -> Some p + | Zpr.ZMatch (e, p, _) -> Some p in + match parent with + | None -> pvs + | Some ((_, tl), p) -> pvtail env (EcPV.PV.union pvs (EcPV.is_read env tl)) p + (* -------------------------------------------------------------------- *) let t_change_stmt (side : side option) - (pos : EcMatching.Position.codepos_range) + (pos : EcMatching.Position.codepos_range) + ((me, bindings) : memenv * ovariable list) (s : stmt) (tc : tcenv1) = let env = FApi.tc1_env tc in - let me, stmt = EcLowPhlGoal.tc1_get_stmt side tc in + let goal = (FApi.tc1_goal tc) in + let post = match goal.f_node with + | FhoareS { hs_po } -> hs_po + | FbdHoareS { bhs_po } -> bhs_po + | FeHoareS { ehs_po } -> ehs_po + | FequivS { es_po } -> es_po + | _ -> assert false + in + let _, stmt = EcLowPhlGoal.tc1_get_stmt side tc in + + let env = EcEnv.Memory.push_active_ts me me env in (* FIXME *) + + let zpr, epos = Zpr.zipper_of_cpos_range env pos stmt in + let stmt, epilog = match zpr.z_tail with + | [] -> raise Zpr.InvalidCPos + | i::tl -> let s, tl = Zpr.split_at_cpos1 env epos (EcAst.stmt tl) in + (i::s), tl + in - let (zpr, _), (stmt, epilog) = EcMatching.Zipper.zipper_and_split_of_cpos_range env pos stmt in + let keep = pvtail env (EcPV.is_read env epilog) zpr.z_path in + let keep = EcPV.PV.union keep (EcPV.PV.fv env (EcMemory.memory me) post) in let pvs = EcPV.is_write env (stmt @ s.s_node) in - let pvs, globs = EcPV.PV.elements pvs in + let _pvs, globs = EcPV.PV.elements pvs in + + let pvs, _ = EcPV.PV.elements (EcPV.PV.inter keep pvs) in - let pre_pvs, pre_globs = EcPV.PV.elements @@ EcPV.PV.inter - (EcPV.is_read env stmt) + let pre_pvs = EcPV.PV.inter + (EcPV.is_read env stmt) (EcPV.is_read env s.s_node) in + (* FIXME: Check | Do we need this? *) +(* + let pre_pvs = EcPV.PV.union pre_pvs ( + pvtail env (EcPV.is_read env epilog) zpr.z_path + ) in +*) + + (* Do we need this? *) +(* let pre_pvs = EcPV.PV.union pre_pvs (EcPV.PV.fv env (EcMemory.memory me) post) in *) + + let pre_pvs, pre_globs = EcPV.PV.elements pre_pvs in + let mleft = EcIdent.create "&1" in (* FIXME: PR: is this how we want to do this? *) let mright = EcIdent.create "&2" in @@ -221,20 +268,26 @@ let t_change_stmt let stmt = EcMatching.Zipper.zip { zpr with z_tail = s.s_node @ epilog } in - let goal2 = - EcLowPhlGoal.hl_set_stmt - side (FApi.tc1_goal tc) - stmt in + let goal2 = match side, goal.f_node with + | None, FhoareS hs -> f_hoareS (snd me) (hs_pr hs) stmt (hs_po hs) + | None, FbdHoareS bhs -> f_bdHoareS (snd me) (bhs_pr bhs) stmt (bhs_po bhs) (bhs.bhs_cmp) (bhs_bd bhs) + | None, FeHoareS ehs -> f_eHoareS (snd me) (ehs_pr ehs) stmt (ehs_po ehs) + | Some `Left, FequivS es -> f_equivS (snd me) (snd es.es_mr) (es_pr es) stmt (es.es_sr) (es_po es) + | Some `Right, FequivS es -> f_equivS (snd es.es_ml) (snd me) (es_pr es) (es.es_sl) stmt (es_po es) + | _ -> assert false + in FApi.xmutate1 tc `ProcChangeStmt [goal1; goal2] (* -------------------------------------------------------------------- *) let process_change_stmt (side : side option) + (binds : ptybindings option) (pos : pcodepos_range) (s : pstmt) (tc : tcenv1) = + let hyps = FApi.tc1_hyps tc in let env = FApi.tc1_env tc in begin match side, (FApi.tc1_goal tc).f_node with @@ -255,14 +308,46 @@ let process_change_stmt let me, _ = EcLowPhlGoal.tc1_get_stmt side tc in - let pos = + let pos = let env = EcEnv.Memory.push_active_ss me env in - EcTyping.trans_codepos_range ~memory:(fst me) env pos + EcTyping.trans_codepos_range ~memory:(fst me) env pos in - let s = match side with +(* + let s = match side with | Some side -> EcProofTyping.tc1_process_prhl_stmt tc side s | None -> EcProofTyping.tc1_process_Xhl_stmt tc s in +*) + + let bindings = + binds + |> Option.default [] + |> List.map (fun (xs, ty) -> List.map (fun x -> (x, ty)) xs) + |> List.flatten + |> List.map (fun (x, ty) -> + let ue = EcUnify.UniEnv.create (Some (EcEnv.LDecl.tohyps hyps).h_tvar) in + let ty = EcTyping.transty EcTyping.tp_tydecl env ue ty in + assert (EcUnify.UniEnv.closed ue); + let ty = + let subst = EcCoreSubst.Tuni.subst (EcUnify.UniEnv.close ue) in + EcCoreSubst.ty_subst subst ty in + let x = Option.map EcLocation.unloc (EcLocation.unloc x) in + let vr = EcAst.{ ov_name = x; ov_type = ty; } in + vr + ) + in + let me, bindings = EcMemory.bindall_fresh bindings me in + + let env = EcEnv.Memory.push_active_ss me env in + let s = + let ue = EcProofTyping.unienv_of_hyps hyps in + let s = EcTyping.transstmt env ue s in + + assert (EcUnify.UniEnv.closed ue); + + let sb = EcCoreSubst.Tuni.subst (EcUnify.UniEnv.close ue) in + EcCoreSubst.s_subst sb s + in - t_change_stmt side pos s tc + t_change_stmt side pos (me, bindings) s tc diff --git a/src/phl/ecPhlRewrite.mli b/src/phl/ecPhlRewrite.mli index 9640d3a1fc..b25f49a895 100644 --- a/src/phl/ecPhlRewrite.mli +++ b/src/phl/ecPhlRewrite.mli @@ -5,7 +5,8 @@ open EcCoreGoal.FApi (* -------------------------------------------------------------------- *) val process_change : side option -> pcodepos -> pexpr -> backward +(* -------------------------------------------------------------------- *) val process_rewrite_rw : side option -> pcodepos -> ppterm -> backward val process_rewrite_simpl : side option -> pcodepos -> backward val process_rewrite : side option -> pcodepos -> prrewrite -> backward -val process_change_stmt : side option -> pcodepos_range -> pstmt -> backward +val process_change_stmt : side option -> ptybindings option -> pcodepos_range -> pstmt -> backward diff --git a/src/phl/ecPhlRwEquiv.ml b/src/phl/ecPhlRwEquiv.ml index 777b93c165..35dc5dec4d 100644 --- a/src/phl/ecPhlRwEquiv.ml +++ b/src/phl/ecPhlRwEquiv.ml @@ -1,6 +1,8 @@ +(* -------------------------------------------------------------------- *) open EcUtils open EcLocation open EcParsetree +open EcAst open EcFol open EcModules open EcPath @@ -11,7 +13,7 @@ open EcCoreGoal.FApi open EcLowGoal open EcLowPhlGoal -(*---------------------------------------------------------------------------------------*) +(* -------------------------------------------------------------------- *) type rwe_error = | RWE_InvalidFunction of xpath * xpath | RWE_InvalidPosition @@ -20,7 +22,7 @@ exception RwEquivError of rwe_error let rwe_error e = raise (RwEquivError e) -(*---------------------------------------------------------------------------------------*) +(* -------------------------------------------------------------------- *) (* `rewrite equiv` - replace a call to a procedure with an equivalent call, using an equiv @@ -34,7 +36,15 @@ let rwe_error e = raise (RwEquivError e) and return value. *) (* FIXME: What is a good interface for this tactic? *) -let t_rewrite_equiv side dir cp (equiv : equivF) equiv_pt rargslv tc = +let t_rewrite_equiv + (side : side) + (dir : [`LtoR | `RtoL]) + (cp : EcMatching.Position.codepos1) + (equiv : equivF) + (equiv_pt : proofterm) + (rargslv : (expr list * lvalue option) option) + (tc : tcenv1) += let env = tc1_env tc in let goal = tc1_as_equivS tc in @@ -56,7 +66,6 @@ let t_rewrite_equiv side dir cp (equiv : equivF) equiv_pt rargslv tc = (* Extract the call statement and surrounding code *) let prefix, (llv, func, largs), suffix = - let cp = EcLowPhlGoal.tc1_process_codepos1 tc (Some side, cp) in let p, i, s = s_split_i env cp code in if not (is_call i) then rwe_error RWE_InvalidPosition; @@ -80,7 +89,8 @@ let t_rewrite_equiv side dir cp (equiv : equivF) equiv_pt rargslv tc = t_onselect p (t_seqs [ - EcPhlEqobs.process_eqobs_in none {sim_pos = some (cp, cp); sim_hint = ([], none); sim_eqs = none}; + EcPhlEqobs.t_eqobs_in + None EcPhlEqobs.{ empty_sim_info with sim_pos = Some (cp, cp) }; (match side, dir with | `Left, `LtoR -> t_id | `Left, `RtoL -> EcPhlSym.t_equiv_sym @@ -96,7 +106,7 @@ let t_rewrite_equiv side dir cp (equiv : equivF) equiv_pt rargslv tc = ]) tc -(*---------------------------------------------------------------------------------------*) +(* -------------------------------------------------------------------- *) (* Proccess a user call to rewrite equiv *) let process_rewrite_equiv info tc = @@ -151,6 +161,8 @@ let process_rewrite_equiv info tc = end in + let cp = EcTyping.trans_codepos1 env cp in + (* Offload to the tactic *) try t_rewrite_equiv side dir cp equiv eqv_pt rargslv tc diff --git a/src/phl/ecPhlRwEquiv.mli b/src/phl/ecPhlRwEquiv.mli index eee53c6091..0504c28b7e 100644 --- a/src/phl/ecPhlRwEquiv.mli +++ b/src/phl/ecPhlRwEquiv.mli @@ -1,12 +1,15 @@ +(* -------------------------------------------------------------------- *) open EcCoreGoal.FApi +open EcAst open EcParsetree open EcCoreGoal -open EcAst +open EcMatching.Position +(* -------------------------------------------------------------------- *) val t_rewrite_equiv : side -> [`LtoR | `RtoL ] -> - pcodepos1 -> + codepos1 -> equivF -> proofterm -> (expr list * lvalue option) option -> diff --git a/src/phl/ecPhlRwPrgm.ml b/src/phl/ecPhlRwPrgm.ml index c249539240..92b06f9670 100644 --- a/src/phl/ecPhlRwPrgm.ml +++ b/src/phl/ecPhlRwPrgm.ml @@ -1,12 +1,83 @@ (* -------------------------------------------------------------------- *) +open EcUtils open EcParsetree open EcCoreGoal open EcLowPhlGoal -open EcAst (* -------------------------------------------------------------------- *) type change_t = pcodepos * ptybindings option * int * pstmt +(* -------------------------------------------------------------------- *) +let process_change ((cpos, bindings, i, s) : change_t) (tc : tcenv1) = + let hyps = FApi.tc1_hyps tc in + let env = EcEnv.LDecl.toenv hyps in + let hs = EcLowPhlGoal.tc1_as_hoareS tc in + let cpos = EcLowPhlGoal.tc1_process_codepos tc (None, cpos) in + + let mem, _ = + let bindings = + bindings + |> Option.value ~default:[] + |> List.map (fun (xs, ty) -> List.map (fun x -> (x, ty)) xs) + |> List.flatten in + List.fold_left_map (fun mem (x, ty) -> + let ue = EcUnify.UniEnv.create (Some (EcEnv.LDecl.tohyps hyps).h_tvar) in + let ty = EcTyping.transty EcTyping.tp_tydecl env ue ty in + assert (EcUnify.UniEnv.closed ue); + let ty = + let subst = EcCoreSubst.Tuni.subst (EcUnify.UniEnv.close ue) in + EcCoreSubst.ty_subst subst ty in + let x = Option.map EcLocation.unloc (EcLocation.unloc x) in + let vr = EcAst.{ ov_name = x; ov_type = ty; } in + let (mem, _) = EcMemory.bind_fresh vr mem in + (mem, (EcTypes.pv_loc (oget x), ty)) (* FIXME *) + ) hs.hs_m bindings in + + let env = EcEnv.Memory.push_active_ss mem env in + + let s = + let ue = EcProofTyping.unienv_of_hyps (FApi.tc1_hyps tc) in + let s = EcTyping.transstmt env ue s in + + assert (EcUnify.UniEnv.closed ue); (* FIXME *) + + let sb = EcCoreSubst.Tuni.subst (EcUnify.UniEnv.close ue) in + EcCoreSubst.s_subst sb s in + + let zp = Zpr.zipper_of_cpos env cpos hs.hs_s in + + let rec pvtail (pvs : EcPV.PV.t) (zp : Zpr.ipath) = + let parent = + match zp with + | Zpr.ZTop -> None + | Zpr.ZWhile (_, p) -> Some p + | Zpr.ZIfThen (e, p, _) -> Some p + | Zpr.ZIfElse (e, _, p) -> Some p + | Zpr.ZMatch (e, p, _) -> Some p in + match parent with + | None -> pvs + | Some ((_, tl), p) -> pvtail (EcPV.PV.union pvs (EcPV.is_read env tl)) p + in + + let zp = + let target, tl = List.split_at i zp.z_tail in + + let keep = pvtail (EcPV.is_read env tl) zp.z_path in + let keep = EcPV.PV.union keep (EcPV.PV.fv env (EcMemory.memory mem) (EcAst.hs_po hs).inv) in + + begin + try + if not (EcCircuits.instrs_equiv (FApi.tc1_hyps tc) ~keep mem target s.s_node) then + tc_error !!tc "statements are not circuit-equivalent" + with EcCircuits.CircError s -> + tc_error !!tc "circuit-equivalence checker error: %s" (Lazy.force s) + end; + { zp with z_tail = s.s_node @ tl } in + + let hs = { hs with hs_s = Zpr.zip zp; hs_m = mem; } in + + FApi.xmutate1 tc `BChange EcAst.[EcFol.f_hoareS (hs.hs_m |> snd) (hs_pr hs) (hs.hs_s) (hs_po hs)] + (* -------------------------------------------------------------------- *) type idassign_t = pcodepos * pqsymbol @@ -23,10 +94,13 @@ let process_idassign ((cpos, pv) : idassign_t) (tc : tcenv1) = let s = Zpr.zipper_of_cpos env cpos hs.hs_s in let s = { s with z_tail = sasgn :: s.z_tail } in { hs with hs_s = Zpr.zip s } in - FApi.xmutate1 tc `IdAssign [EcFol.f_hoareS (snd hs.hs_m) (hs_pr hs) (hs.hs_s) (hs_po hs)] + FApi.xmutate1 tc `IdAssign EcAst.[EcFol.f_hoareS (hs.hs_m |> snd) (hs_pr hs) (hs.hs_s) (hs_po hs)] (* -------------------------------------------------------------------- *) let process_rw_prgm (mode : rwprgm) (tc : tcenv1) = match mode with | `IdAssign (cpos, pv) -> process_idassign (cpos, pv) tc + | `Change (cpos, bindings, i, s) -> + process_change (cpos, bindings, i, s) tc + diff --git a/tests/abstract_bind.ec b/tests/abstract_bind.ec new file mode 100644 index 0000000000..fb04b44def --- /dev/null +++ b/tests/abstract_bind.ec @@ -0,0 +1,70 @@ +require import AllCore List Int IntDiv CoreMap Real Number Bool. + +require import QFABV. + +abstract theory Test. +type t. + +op size : int. + +axiom size_gt0 : 0 < size. + +op add : t -> t -> t. + +op w2bits : t -> bool list. + +op bits2w : bool list -> t. + +op touint : t -> int. + +op tosint : t -> int. + +op ofint : int -> t. + +(* +axiom t_tolistP: forall (bv : t), bits2w (w2bits bv) = bv. +axiom t_oflistP: forall (xs : bool list), + size xs = Test.size => w2bits (bits2w xs) = xs. +axiom t_touintP: forall (bv : t), + Test.touint bv = BitEncoding.BS2Int.bs2int (w2bits bv). +axiom t_tosintP: forall (bv : t), + Test.size = 1 \/ + let v = BitEncoding.BS2Int.bs2int (w2bits bv) in + if msb bv then Test.tosint bv = v - 2 ^ Test.size + else Test.tosint bv = v. +axiom t_ofintP: forall (i : int), + Test.ofint i = bits2w (BitEncoding.BS2Int.int2bs Test.size i). +*) + +bind bitstring w2bits bits2w touint tosint ofint t size. + +realize gt0_size. by apply size_gt0. qed. + +realize tolistP by admit. + +realize touintP by admit. + +realize size_tolist by admit. + +realize oflistP by admit. + +realize tosintP by admit. + +realize ofintP by admit. + +bind op t add "add". + +realize bvaddP by admit. + +end Test. + +clone import Test as CTest + with type t <- bool, + op size <- 1, + op add <- (^^). + +print CTest. + +lemma xor2_false (b: bool) : b ^^ b = CTest.ofint 0. + +bdep solve. qed. diff --git a/tests/circuit_test.ec b/tests/circuit_test.ec new file mode 100644 index 0000000000..8552d8f83e --- /dev/null +++ b/tests/circuit_test.ec @@ -0,0 +1,176 @@ +require import AllCore List QFABV IntDiv. + + +theory FakeWord. +type W. +op size : int. + +op to_bits : W -> bool list. +op from_bits : bool list -> W. +op of_int : int -> W. +op to_uint : W -> int. +op to_sint : W -> int. + +bind bitstring + to_bits + from_bits + to_uint + to_sint + of_int + W + size. + +realize gt0_size by admit. +realize tolistP by admit. +realize oflistP by admit. +realize touintP by admit. +realize tosintP by admit. +realize ofintP by admit. +realize size_tolist by admit. + + + +op bool2bits (b : bool) : bool list = [b]. +op bits2bool (b: bool list) : bool = List.nth false b 0. + +op i2b : int -> bool. +op b2si (b: bool) = 0. + +bind bitstring bool2bits bits2bool b2i b2si i2b bool 1. +realize size_tolist by auto. +realize tolistP by auto. + +realize oflistP by rewrite /bool2bits /bits2bool;smt(size_eq1). +realize ofintP by admit. +realize touintP by admit. +realize tosintP by move => bv => //. +realize gt0_size by done. + +op (+^) : W -> W -> W. + +bind op W (+^) "xor". +realize bvxorP by admit. + +end FakeWord. + +type W8. + +clone include FakeWord with + op size <- 8, + type W <- W8. + +module M = { + proc test (a: W8, b: W8) = { + var c : W8; + c <- a +^ b; + return c; + } +}. + +op "_.[_]" : W8 -> int -> bool. + +bind op [W8 & bool] "_.[_]" "get". +realize le_size by auto. +realize eq1_size by auto. +realize bvgetP by admit. + +lemma W8_ext (a: W8) : List.all (fun i => a.[i] = a.[i]) (iota_ 0 8). +proof. +extens : circuit. +qed. + + + +lemma W8_xor_ext (a_ b_ : W8) : hoare[M.test : a_ = a /\ b_ = b ==> res = a_ +^ b_]. +proof. +proc. +(* extens [a] : (wp; skip; smt()). *) +(* FIXME : while debugging fhash *) admit. +qed. + + +lemma W8_xor_simp (a_ b_ : W8) : hoare[M.test : a_ = a /\ b_ = b ==> res = a_ +^ b_]. +proof. +proc. +(* circuit simplify; trivial. *) admit. +qed. + + +lemma W8_xor_ext2 (a_ b_ : W8) : hoare[M.test : a_ = a /\ b_ = b ==> res = a_ +^ b_]. +proof. +proc. +admit. +(* extens [a] : circuit. *) +qed. + +lemma W8_xor_ext_simp (a_ b_ : W8) : hoare[M.test : a_ = a /\ b_ = b ==> res = a_ +^ b_]. +proof. +proc. +(* extens [a] : by circuit simplify; trivial. (* FIXME: without by does not work *) *) admit. +qed. + + +(* +lemma xor_0 (a_ b_ : W8) : hoare[M.test : a_ = a /\ b_ = b /\ a_ = b_ ==> res = of_int 0]. +proof. + proc. + proc change 1 : { c <- b +^ a; }. + wp. skip. move => &h1 &h2. + have : a{h1} = a_ by admit. + have : b{h1} = b_ by admit. + move => A B [] C D. + have : a{h2} = a_ by smt(). + have : b{h2} = b_ by smt(). + (* move : A B C D. (* Comment or uncomment this line for different modes of working *) *) + bdep solve. +bdep solve. +qed. +*) + + +lemma xor_com (a_ b_ : W8) : hoare[M.test : a_ = a /\ b_ = b /\ a_ = b_ ==> res = b_ +^ a_]. +proof. + proc. + proc change 1 : [ d : W8 ] { d <- of_int 0; d <- a +^ d; c <- d +^ b; }. + circuit. + circuit. +qed. + +theory Array8. +type 'a t. + +op tolist : 'a t -> 'a list. +op oflist : 'a list -> 'a t. +op "_.[_]" : 'a t -> int -> 'a. +op "_.[_<-_]" : 'a t -> int -> 'a -> 'a t. + +end Array8. + +bind array Array8."_.[_]" Array8."_.[_<-_]" Array8.tolist Array8.oflist Array8.t 8. +realize gt0_size by auto. +realize tolistP by admit. +realize eqP by admit. +realize get_setP by admit. +realize get_out by admit. + + +op init_8_8 (f: int -> W8) : W8 Array8.t. + +bind op [W8 & Array8.t] init_8_8 "ainit". +realize bvainitP by admit. + +print Array8."_.[_]". + +op get : W8 Array8.t -> int -> W8 = Array8."_.[_]". + +lemma init_test (_aw: W8 Array8.t) : + init_8_8 (fun i => get _aw ((i * -1) %% 8)) = + init_8_8 (fun i => + get (init_8_8 + (get (init_8_8 (fun k => + get (init_8_8 (fun (l: int) => + get _aw ((l*5)%%8))) ((k * 3) %% 8))))) i ). +proof. +circuit. +qed. + diff --git a/tests/ext_test.ec b/tests/ext_test.ec new file mode 100644 index 0000000000..b6e679925b --- /dev/null +++ b/tests/ext_test.ec @@ -0,0 +1,13 @@ +require import AllCore Int List. + +print List.Iota.iota_. +print List.all. +print List.Iota. + +lemma random : List.all (fun i => i = i) + (List.Iota.iota_ 0 10). + proof. + + extens trivial. + qed. + diff --git a/tests/procchange.ec b/tests/procchange.ec index 6863802f09..cdc2924952 100644 --- a/tests/procchange.ec +++ b/tests/procchange.ec @@ -14,8 +14,7 @@ theory ProcChangeAssignEquiv. lemma L : equiv[M.f ~ M.f: true ==> true]. proof. proc. - proc change {1} [1..3] : { x <- 3; }. - + proc change {1} [1..3] : [y : int] { y <- 3; x <- y; }. wp. skip. smt(). abort. end ProcChangeAssignEquiv. @@ -93,9 +92,13 @@ theory ProcChangeWhileEquiv. x <- x + 1 + 0; } }. + (* proc rewrite {1} 1 /=. *) + admit. (* FIXME *) + (* proc rewrite {1} 1 /=. proc rewrite {2} 1.1 /=. sim. + *) abort. end ProcChangeWhileEquiv. diff --git a/theories/algebra/StdBigop.ec b/theories/algebra/StdBigop.ec index c3c9821ddf..9e246deef4 100644 --- a/theories/algebra/StdBigop.ec +++ b/theories/algebra/StdBigop.ec @@ -86,6 +86,13 @@ lemma big_constz (P : 'a -> bool) x s: BIA.big P (fun i => x) s = x * (count P s). proof. by rewrite BIA.sumr_const -IntID.intmulz. qed. +lemma sumz_nseq (n v : int) : 0 <= n => sumz (nseq n v) = n * v. +proof. +move=> ge0_n; rewrite sumzE (_ : n = size (iota_ 0 n)). +- by rewrite size_iota lez_maxr. +by rewrite -map_nseq BIA.big_map /(\o) /= big_constz count_predT mulzC. +qed. + lemma bigi_constz x (n m:int): n <= m => BIA.bigi predT (fun i => x) n m = x * (m - n). diff --git a/theories/datatypes/List.ec b/theories/datatypes/List.ec index 7d8d0c2aa0..480c72ecdb 100644 --- a/theories/datatypes/List.ec +++ b/theories/datatypes/List.ec @@ -1890,6 +1890,13 @@ lemma map_comp (f1 : 'b -> 'c) (f2 : 'a -> 'b) s: map (f1 \o f2) s = map f1 (map f2 s). proof. by elim: s => //= x s ->. qed. +lemma map_nseq ['a 'b] (x : 'b) (s : 'a list) : + map (fun _ => x) s = nseq (size s) x. +proof. +elim: s => /= [|s ih]; first by rewrite nseq0. +by rewrite addzC nseqS 1:size_ge0 ih. +qed. + lemma map_id (s : 'a list): map idfun s = s. proof. by elim: s => //= x s ->. qed. diff --git a/theories/datatypes/QFABV.ec b/theories/datatypes/QFABV.ec new file mode 100644 index 0000000000..cd904a4185 --- /dev/null +++ b/theories/datatypes/QFABV.ec @@ -0,0 +1,552 @@ +(* -------------------------------------------------------------------- *) +require import AllCore List Int IntDiv BitEncoding. +(* - *) import BS2Int. + +(* ==================================================================== *) +abstract theory BV. + op size : int. + + axiom [bydone] gt0_size : 0 < size. + + type bv. + + op tolist : bv -> bool list. + op oflist : bool list -> bv. + + op touint : bv -> int. + op tosint : bv -> int. + op ofint : int -> bv. + + op get (b: bv) (n: int) : bool = + List.nth false (tolist b) n. + + op msb (b: bv) : bool = + List.nth false (tolist b) (size - 1). + + axiom size_tolist (bv : bv): List.size (tolist bv) = size. + + axiom tolistP (bv : bv) : oflist (tolist bv) = bv. + axiom oflistP (xs : bool list) : size xs = size => tolist (oflist xs) = xs. + + axiom touintP (bv : bv) : + touint bv = bs2int (tolist bv). + + axiom tosintP (bv : bv) : + (size = 1) \/ + let v = bs2int (tolist bv) in + if (msb bv) then + tosint bv = v - 2^size + else + tosint bv = v. + + axiom ofintP (i : int) : + ofint i = oflist (int2bs size i). +end BV. + +(* ==================================================================== *) +(* FIXME: Missing of_list axiomatization *) +abstract theory A. + op size : int. + + axiom [bydone] gt0_size : 0 < size. + + type 'a t. + + op get ['a] : 'a t -> int -> 'a. + + op set ['a] : 'a t -> int -> 'a -> 'a t. + + op to_list ['a] : 'a t -> 'a list. + + axiom tolistP ['a] (a : 'a t) : + to_list a = mkseq (fun i => get a i) size. + + axiom eqP ['a] (a1 a2 : 'a t) : + (forall i, 0 <= i < size => get a1 i = get a2 i) + <=> (a1 = a2). + + axiom get_setP ['a] (a : 'a t) (i j : int) (v : 'a) : + 0 <= i < size + => 0 <= j < size + => get (set a j v) i = if i = j then v else get a i. + + axiom get_out ['a] (a1 a2 : 'a t) (i : int) : + !(0 <= i < size) => get a1 i = get a2 i. +end A. + +(* ==================================================================== *) +theory BVOperators. + (* ------------------------------------------------------------------ *) + abstract theory BVAdd. + clone import BV. + + op bvadd : bv -> bv -> bv. + + axiom bvaddP (bv1 bv2 : bv) : + touint (bvadd bv1 bv2) = (touint bv1 + touint bv2) %% 2^BV.size. + end BVAdd. + + (* ------------------------------------------------------------------ *) + abstract theory BVSub. + clone import BV. + + op bvsub : bv -> bv -> bv. + + axiom bvsubP (bv1 bv2 : bv) : + touint (bvsub bv1 bv2) = (touint bv1 - touint bv2) %% 2^BV.size. + end BVSub. + + abstract theory BVOpp. + clone import BV. + + op bvopp : bv -> bv. + + axiom bvoppP (bv : bv) : + tosint (bvopp bv) = -(tosint bv). + end BVOpp. + + (* ------------------------------------------------------------------ *) + abstract theory BVMul. + clone import BV. + + op bvmul : bv -> bv -> bv. + + axiom bvmulP (bv1 bv2 : bv) : + touint (bvmul bv1 bv2) = (touint bv1 * touint bv2) %% 2^BV.size. + end BVMul. + + (* ------------------------------------------------------------------ *) + abstract theory BVUDiv. + clone import BV. + + op bvudiv : bv -> bv -> bv. + + axiom bvudivP (bv1 bv2 : bv) : touint bv2 <> 0 => + touint (bvudiv bv1 bv2) = touint bv1 %/ touint bv2. + end BVUDiv. + + (* ------------------------------------------------------------------ *) + abstract theory BVURem. + clone import BV. + + op bvurem : bv -> bv -> bv. + + axiom bvuremP (bv1 bv2 : bv) : + touint (bvurem bv1 bv2) = touint bv1 %% touint bv2. + end BVURem. + + (* ------------------------------------------------------------------ *) + abstract theory BVSHL. + clone import BV. + + op bvshl : bv -> bv -> bv. + + axiom bvshlP (bv1 bv2 : bv) : touint (bvshl bv1 bv2) = + (touint bv1 * 2 ^ (touint bv2)) %% (2 ^ BV.size). + end BVSHL. + + (* ------------------------------------------------------------------ *) + abstract theory BVSHR. + clone import BV. + + op bvshr : bv -> bv -> bv. + + axiom bvshrP (bv1 bv2 : bv) : touint (bvshr bv1 bv2) = + touint bv1 %/ 2 ^ (touint bv2). + end BVSHR. + + (* ------------------------------------------------------------------ *) + abstract theory BVASHR. + clone import BV. + + op bvashr : bv -> bv -> bv. + + axiom bvashrP (bv1 bv2 : bv) : tosint (bvashr bv1 bv2) = + tosint bv1 %/ 2 ^ (touint bv2). + end BVASHR. + + (* ------------------------------------------------------------------ *) + abstract theory BVSHLS. + clone import BV as BV1. + clone import BV as BV2. + + op bvshls : BV1.bv -> BV2.bv -> BV1.bv. + + axiom bvshlsP (bv1 : BV1.bv) (bv2 : BV2.bv) : touint (bvshls bv1 bv2) = + (touint bv1 * 2 ^ (touint bv2)) %% (2 ^ BV1.size). + end BVSHLS. + + (* ------------------------------------------------------------------ *) + abstract theory BVSHRS. + clone import BV as BV1. + clone import BV as BV2. + + op bvshrs : BV1.bv -> BV2.bv -> BV1.bv. + + axiom bvshrsP (bv1 : BV1.bv) (bv2 : BV2.bv) : touint (bvshrs bv1 bv2) = + touint bv1 %/ 2 ^ (touint bv2). + end BVSHRS. + + (* ------------------------------------------------------------------ *) + abstract theory BVASHRS. + clone import BV as BV1. + clone import BV as BV2. + + op bvashrs : BV1.bv -> BV2.bv -> BV1.bv. + + axiom bvashrsP (bv1 : BV1.bv) (bv2 : BV2.bv) : tosint (bvashrs bv1 bv2) = + tosint bv1 %/ 2 ^ (touint bv2). + end BVASHRS. + + (* ------------------------------------------------------------------ *) + abstract theory BVROL. + clone import BV. + + op bvrol : bv -> bv -> bv. + + axiom bvrolP (bv1 bv2 : bv) (i: int) : + 0 <= i < BV.size => + List.nth false (tolist (bvrol bv1 bv2)) i = + List.nth false (tolist bv1) ((i-touint bv2)%%BV.size). + + end BVROL. + + (* ------------------------------------------------------------------ *) + abstract theory BVROR. + clone import BV. + + op bvror : bv -> bv -> bv. + + axiom bvrorP (bv1 bv2 : bv) (i: int): + 0 <= i < BV.size => + List.nth false (tolist (bvror bv1 bv2)) i = + List.nth false (tolist bv1) ((i+touint bv2)%%BV.size). + + end BVROR. + + (* ------------------------------------------------------------------ *) + abstract theory BVAnd. + clone import BV. + + op bvand : bv -> bv -> bv. + + axiom bvandP (bv1 bv2 : bv) : tolist (bvand bv1 bv2) = + map (fun (b : _ * _) => b.`1 /\ b.`2) (zip (tolist bv1) (tolist bv2)). + end BVAnd. + + (* ------------------------------------------------------------------ *) + abstract theory BVOr. + clone import BV. + + op bvor : bv -> bv -> bv. + + axiom bvorP (bv1 bv2 : bv) : tolist (bvor bv1 bv2) = + map (fun (b : _ * _) => b.`1 \/ b.`2) (zip (tolist bv1) (tolist bv2)). + end BVOr. + + (* ------------------------------------------------------------------ *) + abstract theory BVXor. + clone import BV. + + op bvxor: bv -> bv -> bv. + + axiom bvxorP (bv1 bv2 : bv) : tolist (bvxor bv1 bv2) = + map (fun (b : _ * _) => Bool.(^^) b.`1 b.`2)%Bool (zip (tolist bv1) (tolist bv2)). + end BVXor. + + (* ------------------------------------------------------------------ *) + abstract theory BVNot. + clone import BV. + + op bvnot : bv -> bv. + + axiom bvnotP (bv : bv) : tolist (bvnot bv) = + map (fun b => !b) (tolist bv). + end BVNot. + + (* ------------------------------------------------------------------ *) + abstract theory BVULt. + clone import BV as BV1 with op size <= 1. + clone import BV as BV2. + + op bvult : BV2.bv -> BV2.bv -> BV1.bv. + + axiom bvultP (bv1 bv2 : BV2.bv) : + BV1.touint (bvult bv1 bv2) <> 0 <=> (BV2.touint bv1 < BV2.touint bv2). + end BVULt. + +(* ------------------------------------------------------------------ *) + abstract theory BVSLt. + clone import BV as BV1 with op size <= 1. + clone import BV as BV2. + + op bvslt : BV2.bv -> BV2.bv -> BV1.bv. + + axiom bvsltP (bv1 bv2 : BV2.bv) : + BV1.touint (bvslt bv1 bv2) <> 0 <=> (BV2.tosint bv1 < BV2.tosint bv2). + end BVSLt. + + + (* ------------------------------------------------------------------ *) + abstract theory BVULe. + clone import BV as BV1 with op size <= 1. + clone import BV as BV2. + + op bvule : BV2.bv -> BV2.bv -> BV1.bv. + + axiom bvuleP (bv1 bv2 : BV2.bv) : + BV1.touint (bvule bv1 bv2) <> 0 <=> (BV2.touint bv1 <= BV2.touint bv2). + end BVULe. + +(* ------------------------------------------------------------------ *) + abstract theory BVSLe. + clone import BV as BV1 with op size <= 1. + clone import BV as BV2. + + op bvsle : BV2.bv -> BV2.bv -> BV1.bv. + + axiom bvsleP (bv1 bv2 : BV2.bv) : + BV1.touint (bvsle bv1 bv2) <> 0 <=> (BV2.tosint bv1 <= BV2.tosint bv2). + end BVSLe. + + + (* ------------------------------------------------------------------ *) + abstract theory BVZExtend. + clone BV as BV1. + clone BV as BV2. + + axiom [bydone] le_size : BV1.size <= BV2.size. + + op bvzextend : BV1.bv -> BV2.bv. + + axiom bvzextendP (bv : BV1.bv) : + BV1.touint bv = BV2.touint (bvzextend bv). + end BVZExtend. + +(* ------------------------------------------------------------------ *) + abstract theory BVSExtend. + clone BV as BV1. + clone BV as BV2. + + axiom [bydone] le_size : BV1.size <= BV2.size. + + op bvsextend : BV1.bv -> BV2.bv. + + axiom bvsextendP (bv : BV1.bv) : + BV1.tosint bv = BV2.tosint (bvsextend bv). + end BVSExtend. + + (* ------------------------------------------------------------------ *) + abstract theory BVTruncate. + clone BV as BV1. + clone BV as BV2. + + axiom [bydone] le_size : BV2.size <= BV1.size. + + op bvtruncate : BV1.bv -> BV2.bv. + + axiom bvtruncateP (bv : BV1.bv) : + take BV2.size (BV1.tolist bv) = BV2.tolist (bvtruncate bv). + end BVTruncate. + + (* ------------------------------------------------------------------ *) + abstract theory BVExtract. + clone BV as BV1. + clone BV as BV2. + + axiom [bydone] le_size : BV2.size <= BV1.size. + + op bvextract : BV1.bv -> int -> BV2.bv. + + axiom bvextractP (bv : BV1.bv) (base : int) : 0 <= base => base + BV2.size <= BV1.size => + take BV2.size (drop base (BV1.tolist bv)) = BV2.tolist (bvextract bv base). + end BVExtract. + +print List.mkseq. + +(* ------------------------------------------------------------------ *) + abstract theory BVInsert. + clone BV as BV1. + clone BV as BV2. + + axiom [bydone] le_size : BV2.size <= BV1.size. + + op bvinsert : BV1.bv -> int -> BV2.bv -> BV1.bv. + + axiom bvinsertP (bv : BV1.bv) (base : int) (bvins: BV2.bv) : 0 <= base => base + BV2.size <= BV1.size => + let orig = BV1.tolist bv in + let new = BV2.tolist bvins in + List.mkseq (fun i => if i < base || base + BV2.size <= i + then List.nth witness orig i + else List.nth witness new (i - base)) + BV1.size + = BV1.tolist (bvinsert bv base bvins). + end BVInsert. + +(* ------------------------------------------------------------------ *) + abstract theory BVGet. + clone BV as BV1. + clone BV as BV2. + + axiom [bydone] le_size : BV2.size <= BV1.size. + axiom [bydone] eq1_size : BV2.size = 1. + + op bvget : BV1.bv -> int -> BV2.bv. + + axiom bvgetP (bv : BV1.bv) (idx: int) : + List.nth false (BV2.tolist (bvget bv idx)) 0 = List.nth false (BV1.tolist bv) idx. + end BVGet. + + (* ------------------------------------------------------------------ *) + abstract theory BVASliceGet. + clone BV as BV1. + clone BV as BV2. + clone A. + + axiom [bydone] le_size : BV2.size <= BV1.size * A.size. + + op bvasliceget : (BV1.bv A.t) -> int -> BV2.bv. + + (* We need the definition of target semantic to allow + a rewrite without conditions, but the binding just + needs to be correct for valid offsets *) + axiom bvaslicegetP (arr : BV1.bv A.t) (offset : int) : + 0 <= offset <= BV1.size * A.size - BV2.size => + let base = List.flatten (List.map BV1.tolist (A.to_list arr)) in + let ret = bvasliceget arr offset in + forall i, 0 <= i < BV2.size => + nth false (BV2.tolist ret) i = nth false (take BV2.size (List.drop offset base)) i. + end BVASliceGet. + + (* ------------------------------------------------------------------ *) + abstract theory BVASliceSet. + clone BV as BV1. + clone BV as BV2. + clone A. + + axiom [bydone] le_size : BV2.size <= BV1.size * A.size. + + op bvasliceset : (BV1.bv A.t) -> int -> (BV2.bv) -> BV1.bv A.t. + + (* We need the definition of target semantic to allow + a rewrite without conditions, but the binding just + needs to be correct for valid offsets *) + axiom bvaslicesetP (arr : BV1.bv A.t) (offset : int) (bv: BV2.bv): + 0 <= offset <= BV1.size * A.size - BV2.size => + let input_arr = List.flatten (List.map (BV1.tolist) (A.to_list arr)) in + let input_bv = BV2.tolist bv in + let output_arr = List.flatten (List.map BV1.tolist (A.to_list (bvasliceset arr offset bv))) in + forall i, 0 <= i < BV1.size * A.size => + List.nth false output_arr i = + if offset <= i < offset + BV2.size then + List.nth false input_bv (i - offset) + else + List.nth false input_arr i. + end BVASliceSet. + + (* ------------------------------------------------------------------ *) + abstract theory BVConcat. + clone BV as BV1. + clone BV as BV2. + clone BV as BV3. + + axiom [bydone] eq_size : BV1.size + BV2.size = BV3.size. + + op bvconcat : BV1.bv -> BV2.bv -> BV3.bv. + + axiom bvconcatP (bv1 : BV1.bv) (bv2 : BV2.bv) : + BV3.tolist (bvconcat bv1 bv2) = BV1.tolist bv1 ++ BV2.tolist bv2. + end BVConcat. + + (* ------------------------------------------------------------------ *) + abstract theory BVInit. + clone BV as BV1. + clone BV as BV2. + + axiom [bydone] size_1 : BV1.size = 1. + + op bvinit : (int -> BV1.bv) -> BV2.bv. + + axiom bvinitP (f : int -> BV1.bv) : + BV2.tolist (bvinit f) = List.flatten (List.mkseq (fun i => BV1.tolist (f i)) BV2.size). + end BVInit. + + (* ------------------------------------------------------------------ *) + abstract theory BVAInit. + clone BV. + clone A. + + op bvainit : (int -> BV.bv) -> BV.bv A.t. + + axiom bvainitP (f : int -> BV.bv) : + A.to_list (bvainit f) = List.mkseq (fun i => (f i)) A.size. + end BVAInit. + + (* ------------------------------------------------------------------ *) + abstract theory BVMap. + clone BV as BV1. + clone BV as BV2. + clone A. + + op map (f: BV1.bv -> BV2.bv) (abv: BV1.bv A.t) : BV2.bv A.t. + + axiom mapP (f: BV1.bv -> BV2.bv) (abv: BV1.bv A.t) : + A.to_list (map f abv) = List.map f (A.to_list abv). + end BVMap. + + (* ------------------------------------------------------------------ *) + abstract theory BVA2B. + clone BV as BV1. + clone BV as BV2. + clone A. + + axiom [bydone] size_ok : A.size * BV2.size = BV1.size. + + op bva2b : BV2.bv A.t -> BV1.bv. + + axiom a2bP (bva : BV2.bv A.t) : + flatten (map BV2.tolist (A.to_list bva)) = BV1.tolist (bva2b bva). + end BVA2B. + + (* ------------------------------------------------------------------ *) + abstract theory BVB2A. + clone BV as BV1. + clone BV as BV2. + clone A. + + axiom [bydone] size_ok : A.size * BV2.size = BV1.size. + + op bvb2a : BV1.bv -> BV2.bv A.t. + + axiom b2aP (bva : BV1.bv) : + BV1.tolist bva = flatten (map BV2.tolist (A.to_list (bvb2a bva))). + end BVB2A. + + (* ------------------------------------------------------------------ *) + abstract theory A2B2A. (* choubidoubidou *) + clone BV as BV1. + clone BV as BV2. + clone import A. + + axiom [bydone] size_ok : A.size * BV2.size = BV1.size. + + clone import BVA2B with + theory BV1 <- BV1, + theory BV2 <- BV2, + theory A <- A + proof size_ok by exact/size_ok. + + clone import BVB2A with + theory BV1 <- BV1, + theory BV2 <- BV2, + theory A <- A + proof size_ok by exact/size_ok. + + lemma a2bK : cancel bva2b bvb2a. + proof. admitted. + + lemma b2aK : cancel bva2b bvb2a. + proof. admitted. + end A2B2A. +end BVOperators. + diff --git a/theories/dune b/theories/dune index ef2fab8389..ec4b5757c2 100644 --- a/theories/dune +++ b/theories/dune @@ -1,4 +1,3 @@ (install (section (site (easycrypt theories))) (files (glob_files_rec *.{ec,eca}))) -