diff --git a/.gitignore b/.gitignore index 5022b5e..f4175d3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .vscode /target /examples +.*.sw* diff --git a/CHANGELOG.md b/CHANGELOG.md index de9a510..2b558da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,47 +1,57 @@ # Release Notes +## v0.8.0 (2023-01-10) + +### Features + +- Added initial support for Circom 2.1.2, which includes tags, tuples, and + anonymous components. +- Added templates to the `bn128-specific-circuits` analysis pass. + +### Bug fixes + +- Rewrote the `unconstrained-less-than` analysis pass to better capture the + underlying issue. +- Fixed an issue where the cyclomatic complexity calculation could underflow + in some cases in the `overly-complex-function-or-template` analysis pass. ## v0.7.2 (2022-12-01) ### Features - - Added a URL to the issue description for each output. - +- Added a URL to the issue description for each output. ### Bug Fixes - - Rewrote description of the unconstrained less-than analysis pass, as the +- Rewrote description of the unconstrained less-than analysis pass, as the previous description was too broad. - - Fixed grammar in the under-constrained signal warning message. - +- Fixed grammar in the under-constrained signal warning message. ## v0.7.0 (2022-11-29) - ### Features - - New analysis pass (`unconstrained-less-than`) that detects uses of the +- New analysis pass (`unconstrained-less-than`) that detects uses of the Circomlib `LessThan` template where the input signals are not constrained to be less than the bit size passed to `LessThan`. - - New analysis pass (`unconstrained-division`) that detects signal assignments +- New analysis pass (`unconstrained-division`) that detects signal assignments containing division, where the divisor is not constrained to be non-zero. - - New analysis pass (`bn128-specific-circuits`) that detects uses of Circomlib +- New analysis pass (`bn128-specific-circuits`) that detects uses of Circomlib templates with hard-coded BN128-specific constants together with a custom curve like BLS12-381 or Goldilocks. - - New analysis pass (`under-constrained-signal`) that detects intermediate +- New analysis pass (`under-constrained-signal`) that detects intermediate signals which do not occur in at least two separate constraints. - - Rule name is now included in Sarif output. (The rule name is now also +- Rule name is now included in Sarif output. (The rule name is now also displayed by the VSCode Sarif extension.) - - Improved parsing error messages. - +- Improved parsing error messages. ### Bug Fixes - - Fixed an issue during value propagation where values would be propagated to +- Fixed an issue during value propagation where values would be propagated to arrays by mistake. - - Fixed an issue in the `nonstrict-binary-conversion` analysis pass where +- Fixed an issue in the `nonstrict-binary-conversion` analysis pass where some instantiations of `Num2Bits` and `Bits2Num` would not be detected. - - Fixed an issue where the maximum degree of switch expressions were evaluated +- Fixed an issue where the maximum degree of switch expressions were evaluated incorrectly. - - Previous versions could take a very long time to complete value and degree +- Previous versions could take a very long time to complete value and degree propagation. These analyses are now time boxed and will exit if the analysis takes more than 10 seconds to complete. diff --git a/Cargo.lock b/Cargo.lock index b03e2fb..4fd4c1e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -162,7 +162,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "circomspect" -version = "0.7.2" +version = "0.8.0" dependencies = [ "anyhow", "atty", @@ -178,7 +178,7 @@ dependencies = [ [[package]] name = "circomspect-circom-algebra" -version = "2.0.1" +version = "2.0.2" dependencies = [ "num-bigint-dig 0.8.2", "num-traits", @@ -186,7 +186,7 @@ dependencies = [ [[package]] name = "circomspect-parser" -version = "2.0.11" +version = "2.1.2" dependencies = [ "circomspect-program-structure", "lalrpop 0.18.1", @@ -204,7 +204,7 @@ dependencies = [ [[package]] name = "circomspect-program-analysis" -version = "0.7.2" +version = "0.8.0" dependencies = [ "anyhow", "circomspect-parser", @@ -217,7 +217,7 @@ dependencies = [ [[package]] name = "circomspect-program-structure" -version = "2.0.11" +version = "2.1.2" dependencies = [ "anyhow", "atty", @@ -240,7 +240,7 @@ dependencies = [ [[package]] name = "circomspect-program-structure-tests" -version = "0.6.1" +version = "0.8.0" dependencies = [ "circomspect-parser", "circomspect-program-structure", diff --git a/Cargo.toml b/Cargo.toml index 2f202fc..70be6ee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,5 +4,5 @@ members = [ "parser", "program_analysis", "program_structure", - "program_structure_tests" + "program_structure_tests", ] diff --git a/circom_algebra/Cargo.toml b/circom_algebra/Cargo.toml index 6d8caa0..94054cd 100644 --- a/circom_algebra/Cargo.toml +++ b/circom_algebra/Cargo.toml @@ -1,7 +1,8 @@ [package] name = "circomspect-circom-algebra" -version = "2.0.1" -edition = "2018" +version = "2.0.2" +edition = "2021" +rust-version = "1.65" license = "LGPL-3.0-only" authors = ["hermeGarcia "] description = "Support crate for the Circomspect static analyzer" diff --git a/cli/Cargo.toml b/cli/Cargo.toml index f9d9fb7..84aee8b 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -1,13 +1,14 @@ [package] name = "circomspect" -version = "0.7.2" +version = "0.8.0" edition = "2021" +rust-version = "1.65" license = "LGPL-3.0-only" authors = ["Trail of Bits"] readme = "../README.md" description = "A static analyzer and linter for the Circom zero-knowledge DSL" -repository = "https://github.com/trailofbits/circomspect" keywords = ["cryptography", "static-analysis", "zero-knowledge", "circom"] +repository = "https://github.com/trailofbits/circomspect" [dependencies] anyhow = "1.0" @@ -15,9 +16,9 @@ atty = "0.2" # Stay on Clap version 3 until version 4 supports coloured help output. clap = { version = "3.2", features = ["derive"] } log = "0.4" -parser = { package = "circomspect-parser", version = "2.0.11", path = "../parser" } +parser = { package = "circomspect-parser", version = "2.1.2", path = "../parser" } pretty_env_logger = "0.4" -program_analysis = { package = "circomspect-program-analysis", version = "0.7.1", path = "../program_analysis" } -program_structure = { package = "circomspect-program-structure", version = "2.0.11", path = "../program_structure" } +program_analysis = { package = "circomspect-program-analysis", version = "0.8.0", path = "../program_analysis" } +program_structure = { package = "circomspect-program-structure", version = "2.1.2", path = "../program_structure" } serde_json = "1.0" termcolor = "1.1" diff --git a/cli/src/config.rs b/cli/src/config.rs deleted file mode 100644 index f735274..0000000 --- a/cli/src/config.rs +++ /dev/null @@ -1,5 +0,0 @@ -use program_structure::ast::Version; - -pub(crate) const COMPILER_VERSION: Version = (2, 1, 2); -pub(crate) const DEFAULT_LEVEL: &str = "WARNING"; -pub(crate) const DEFAULT_CURVE: &str = "BN128"; diff --git a/cli/src/main.rs b/cli/src/main.rs index 6a67eb0..f522f72 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -1,11 +1,9 @@ -mod config; -mod analysis_runner; - use std::path::PathBuf; use std::process::ExitCode; -use clap::{CommandFactory, Parser}; +use clap::{ArgAction, CommandFactory, Parser}; -use analysis_runner::AnalysisRunner; +use program_analysis::config; +use program_analysis::analysis_runner::AnalysisRunner; use program_structure::constants::Curve; use program_structure::report::Report; @@ -19,6 +17,10 @@ struct Cli { #[clap(name = "INPUT")] input_files: Vec, + /// Analyze included files recursively + #[clap(short = 'i', long = "follow-includes", action = ArgAction::SetTrue)] + follow_includes: bool, + /// Output level (INFO, WARNING, or ERROR) #[clap(short = 'l', long = "level", name = "LEVEL", default_value = config::DEFAULT_LEVEL)] output_level: MessageCategory, @@ -67,8 +69,11 @@ fn main() -> ExitCode { let mut stdout_writer = CachedStdoutWriter::new(options.verbose) .add_filter(move |report: &Report| filter_by_id(report, &allow_list)) .add_filter(move |report: &Report| filter_by_level(report, &options.output_level)); - let mut runner = AnalysisRunner::new(&options.curve); - runner.with_files(&options.input_files, &mut stdout_writer); + let mut runner = AnalysisRunner::new(options.curve).with_files( + &options.input_files, + options.follow_includes, + &mut stdout_writer, + ); runner.analyze_functions(&mut stdout_writer); runner.analyze_templates(&mut stdout_writer); diff --git a/doc/analysis_passes.md b/doc/analysis_passes.md index 2af7f09..28be701 100644 --- a/doc/analysis_passes.md +++ b/doc/analysis_passes.md @@ -36,7 +36,6 @@ An assigned value which does not contribute either directly or indirectly to a c Here, `lout` no longer influences the generated circuit, which is detected by Circomspect. - ### Shadowing variable A shadowing variable declaration is a declaration of a variable with the same name as a previously declared variable. This does not have to be a problem, but if a variable declared in an outer scope is shadowed by mistake, this could change the semantics of the program which would be an issue. @@ -57,7 +56,6 @@ For example, consider this function which is supposed to compute the number of b Since a new variable `r` is declared in the while-statement body, the outer variable is never updated and the return value is always 0. - ### Signal assignment Signals should typically be assigned using the constraint assignment operator `<==`. This ensures that the circuit and witness generation stay in sync. If `<--` is used it is up to the developer to ensure that the signal is properly constrained. Circomspect will try to detect if the right-hand side of the assignment is a quadratic expression. If it is, the signal assignment can be rewritten using the constraint assignment operator `<==`. @@ -66,17 +64,14 @@ However, sometimes it is not possible to express the assignment using a quadrati The Tornado Cash codebase was originally affected by an issue of this type. For details see the Tornado Cash disclosure [here](https://tornado-cash.medium.com/tornado-cash-got-hacked-by-us-b1e012a3c9a8). - ### Under-constrained signal Under-constrained signals are one of the most common issues in zero-knowledge circuits. Circomspect will flag intermediate signals that only occur in a single constraint. Since intermediate signals are not available outside the template, this typically indicates an issue with the implementation. - ### Constant branching condition If a branching statement condition always evaluates to either `true` or `false`, this means that the branch is either always taken, or never taken. This typically indicates a mistake in the code which should be fixed. - ### Non-strict binary conversion Using `Num2Bits` and `Bits2Num` from @@ -88,10 +83,9 @@ uniquely determined by the input. For example, suppose that we create a component `n2b` given by `Num2Bits(254)` and set the input to `1`. Now, both the binary representation of `1` _and_ the representation of `p + 1` (where `p` is the order of the underlying finite field) will satisfy the circuit over BN128, since both are 254-bit numbers. If you cannot restrict the input size below the prime size you should use the strict versions `Num2Bits_strict` and `Bits2Num_strict` to convert to and from binary representation. Circomspect will generate a warning if it cannot prove (using constant propagation) that the input size passed to `Num2Bits` or `Bits2Num` is less than the size of the prime in bits. - ### Unconstrained less-than -The Circomlib `LessThan` template takes an input size as argument. If the individual input signals are not constrained to the input size (for example using the Circomlib `Num2Bits` circuit), it is possible to find inputs `a` and `b` such that `a > b`, but `LessThan` still evaluates to true when given `a` and `b` as inputs. +The Circomlib `LessThan` template takes an input size as argument. If the individual input signals are not constrained to be non-negative (for example using the Circomlib `Num2Bits` circuit), it is possible to find inputs `a` and `b` such that `a > b`, but `LessThan` still evaluates to true when given `a` and `b` as inputs. For example, consider the following template which takes a single input signal and attempts to constrain it to be less than two. @@ -99,33 +93,24 @@ and attempts to constrain it to be less than two. ```cpp template LessThanTwo() { signal input in; - signal output out; component lt = LessThan(8); lt.in[0] <== in; lt.in[1] <== 2; - out <== lt.out; + lt.out === 1; } ``` -Suppose that we define the private input `in` as `p - 254`, where `p` is the prime order of the field. This would result in the constraints - -```cpp - lt.in[0] <== p - 254; - lt.in[1] <== 2; -``` - -Since `p` is at least 64 bits, `p - 254` is not less than two (at least not when viewed as an unsigned integer), so we would perhaps expect `LessThanTwo` to return zero here. However, looking at [the implementation](https://github.com/iden3/circomlib/blob/cff5ab6288b55ef23602221694a6a38a0239dcc0/circuits/comparators.circom#L89-L99) of `LessThan`, we see that `lt.out` is given by +Suppose that we define the private input `in` as `p - 254`, where `p` is the prime order of the field. Clearly, `p - 254` is not less than two (at least not when viewed as an unsigned integer), so we would perhaps expect `LessThanTwo` to fail. However, looking at [the implementation](https://github.com/iden3/circomlib/blob/cff5ab6288b55ef23602221694a6a38a0239dcc0/circuits/comparators.circom#L89-L99) of `LessThan`, we see that `lt.out` is given by ```cpp - 1 - n2b.out[8] = 1 - (bit 8 of (p - 254) + 256 - 2) = 1 - 0 = 1. + 1 - n2b.out[8] = 1 - bit 8 of (p - 254 + (1 << 8) - 2) = 1 - 0 = 1. ``` -It follows that `p - 254` satisfies `LessThanTwo()`, which is perhaps not what we expected. Note that, `p - 254` is equal to -254 which _is_ less than two, so there is nothing wrong with the Circomlib `LessThan` circuit. This may just be unexpected behavior if we're thinking of field elements as unsigned integers. - -Circomspect will check if the inputs to `LessThan` are constrained to the input size using `Num2Bits`. If it cannot prove that both inputs are constrained in this way, a warning is generated. +It follows that `p - 254` satisfies `LessThanTwo()`, which is probably not what we expected. Note that, `p - 254` is equal to -254 which _is_ less than two, so there is nothing wrong with the Circomlib `LessThan` circuit. This may just be unexpected behavior if we're thinking of field elements as unsigned integers. +Circomspect will check if the inputs to `LessThan` are constrained to be strictly less than `log(p) - 1` bits using `Num2Bits`. This guarantees that both inputs are non-negative, which avoids this issue. If it cannot prove that both inputs are constrained in this way, a warning is generated. ### Unconstrained division @@ -140,46 +125,55 @@ This forces `c` to be equal to `a / b` during witness generation, and checks tha Circomspect will identify signal assignments on the form `c <-- a / b` and ensure that the expression `b` is constrained to be non-zero using the Circomlib `IsZero` template. If no such constraint is found, a warning is emitted. - ### BN128 specific circuit Circom defaults to using the BN128 scalar field (a 254-bit prime field), but it also supports BSL12-381 (which has a 255-bit scalar field) and Goldilocks (with a 64-bit scalar field). However, since there are no constants denoting either the prime or the prime size in bits available in the Circom language, some Circomlib templates like `Sign` (which returns the sign of the input signal), and `AliasCheck` (used by the strict versions of `Num2Bits` and `Bits2Num`), hardcode either the BN128 prime size or some other constant related to BN128. Using these circuits with a custom prime may thus lead to unexpected results and should be avoided. -Circomlib templates that may be problematic when used together with curves other than BN128 include the following circuit definitions. - - | Template | Circomlib Source File | - | ------------------------- | -------------------------------- | - | `Sign` | circuits/sign.circom | - | `AliasCheck` | circuits/aliascheck.circom | - | `CompConstant` | circuits/compconstant.circom | - | `Num2Bits_strict` | circuits/bitify.circom | - | `Bits2Num_strict` | circuits/bitify.circom | - | `Bits2Point_Strict` | circuits/bitify.circom | - | `Point2Bits_Strict` | circuits/bitify.circom | - | `SMTVerifier` | circuits/smt/smtverifier.circom | - | `SMTProcessor` | circuits/smt/smtprocessor.circom | - | `EdDSAVerifier` | circuits/eddsa.circom | - | `EdDSAPoseidonVerifier` | circuits/eddsaposeidon.circom | - | `EdDSAMiMCSpongeVerifier` | circuits/eddsamimcsponge.circom | - +Circomlib templates that may be problematic when used together with curves other than BN128 include the following circuit definitions. (An `x` means that the template should not be used together with the corresponding curve.) + +| Template | Goldilocks (64 bits) | BLS12-381 (255 bits) | +| :------------------------ | :------------------: | :------------------: | +| `AliasCheck` | x | x | +| `BabyPbk` | x | | +| `Bits2Num_strict` | x | x | +| `Num2Bits_strict` | x | x | +| `CompConstant` | x | x | +| `EdDSAVerifier` | x | x | +| `EdDSAMiMCVerifier` | x | x | +| `EdDSAMiMCSpongeVerifier` | x | x | +| `EdDSAPoseidonVerifier` | x | x | +| `EscalarMulAny` | x | | +| `MiMC7` | x | | +| `MultiMiMC7` | x | | +| `MiMCFeistel` | x | | +| `MiMCSponge` | x | | +| `Pedersen` | x | | +| `Bits2Point_strict` | x | x | +| `Point2Bits_strict` | x | x | +| `PoseidonEx` | x | | +| `Poseidon` | x | | +| `Sign` | x | x | +| `SMTHash1` | x | | +| `SMTHash2` | x | | +| `SMTProcessor` | x | x | +| `SMTProcessorLevel` | x | | +| `SMTVerifier` | x | x | +| `SMTVerifierLevel` | x | | ### Overly complex function or template As functions and templates grow in complexity they become more difficult to review and maintain. This typically indicates that the code should be refactored into smaller, more easily understandable, components. Circomspect uses cyclomatic complexity to estimate the complexity of each function and template, and will generate a warning if the code is considered too complex. Circomspect will also generate a warning if a function or template takes too many arguments, as this also impacts the readability of the code. - ### Bitwise complement Circom supports taking the 256-bit complement `~x` of a field element `x`. Since the result is reduced modulo `p`, it will typically not satisfy the expected relations `(~x)ᵢ == ~(xᵢ)` for each bit `i`, which could lead to surprising results. - ### Field element arithmetic Circom supports a large number of arithmetic expressions. Since arithmetic expressions can overflow or underflow in Circom it is worth paying extra attention to field arithmetic to ensure that elements are constrained to the correct range. - ### Field element comparison Field elements are normalized to the interval `(-p/2, p/2]` before they are compared, by first reducing them modulo `p` and then mapping them to the correct interval by subtracting `p` from the value `x`, if `x` is greater than `p/2`. In particular, this means that `p/2 + 1 < 0 < p/2 - 1`. This can be surprising if you are used to thinking of elements in `GF(p)` as unsigned integers. diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..c008ddc --- /dev/null +++ b/flake.lock @@ -0,0 +1,111 @@ +{ + "nodes": { + "flake-compat": { + "flake": false, + "locked": { + "lastModified": 1668681692, + "narHash": "sha256-Ht91NGdewz8IQLtWZ9LCeNXMSXHUss+9COoqu6JLmXU=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "009399224d5e398d03b22badca40a37ac85412a1", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "locked": { + "lastModified": 1667395993, + "narHash": "sha256-nuEHfE/LcWyuSWnS8t12N1wc105Qtau+/OdUAjtQ0rA=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "5aed5285a952e0b949eb3ba02c12fa4fcfef535f", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_2": { + "locked": { + "lastModified": 1659877975, + "narHash": "sha256-zllb8aq3YO3h8B/U0/J1WBgAL8EX5yWf5pMj3G0NAmc=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "c0e246b9b83f637f4681389ecabcb2681b4f3af0", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1672262501, + "narHash": "sha256-ZNXqX9lwYo1tOFAqrVtKTLcJ2QMKCr3WuIvpN8emp7I=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "e182da8622a354d44c39b3d7a542dc12cd7baa5f", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_2": { + "locked": { + "lastModified": 1665296151, + "narHash": "sha256-uOB0oxqxN9K7XGF1hcnY+PQnlQJ+3bP2vCn/+Ru/bbc=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "14ccaaedd95a488dd7ae142757884d8e125b3363", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixpkgs-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-compat": "flake-compat", + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs", + "rust-overlay": "rust-overlay" + } + }, + "rust-overlay": { + "inputs": { + "flake-utils": "flake-utils_2", + "nixpkgs": "nixpkgs_2" + }, + "locked": { + "lastModified": 1672367043, + "narHash": "sha256-4/40kfJysfDEfSpXJ3inuMetn40czz5Mh73SjxsKTX0=", + "owner": "oxalica", + "repo": "rust-overlay", + "rev": "e6b2214363f5e18576a3b2ca0e0483d8f42fe531", + "type": "github" + }, + "original": { + "owner": "oxalica", + "repo": "rust-overlay", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..680ba03 --- /dev/null +++ b/flake.nix @@ -0,0 +1,37 @@ +{ + description = "A devShell example"; + + inputs.nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable"; + + inputs.flake-utils.url = "github:numtide/flake-utils"; + + inputs.flake-compat.url = "github:edolstra/flake-compat"; + inputs.flake-compat.flake = false; + + inputs.rust-overlay.url = "github:oxalica/rust-overlay"; + + outputs = { self, nixpkgs, flake-utils, flake-compat, rust-overlay, + ... }: + flake-utils.lib.eachDefaultSystem (system: + let + overlays = [ (import rust-overlay) ]; + pkgs = import nixpkgs { inherit system overlays; }; + stableToolchain = pkgs.rust-bin.stable."1.66.0".minimal.override { + extensions = [ "rustfmt" "clippy" ]; + }; + in with pkgs; + { + devShell = pkgs.mkShell { + buildInputs = with pkgs; + [ + stableToolchain + ]; + + RUST_BACKTRACE = 1; + RUST_LOG = "info"; + }; + } + ); + + +} diff --git a/parser/Cargo.toml b/parser/Cargo.toml index 5bdc357..edc9a7f 100644 --- a/parser/Cargo.toml +++ b/parser/Cargo.toml @@ -1,7 +1,8 @@ [package] name = "circomspect-parser" -version = "2.0.11" -edition = "2018" +version = "2.1.2" +edition = "2021" +rust-version = "1.65" build = "build.rs" license = "LGPL-3.0-only" description = "Support crate for the Circomspect static analyzer" @@ -18,7 +19,7 @@ num-bigint-dig = "0.6" num-traits = "0.2" [dependencies] -program_structure = { package = "circomspect-program-structure", version = "2.0.11", path = "../program_structure" } +program_structure = { package = "circomspect-program-structure", version = "2.1.2", path = "../program_structure" } lalrpop = { version = "0.19", features = ["lexer"] } lalrpop-util = "0.19" log = "0.4" @@ -30,4 +31,4 @@ serde = "1.0" serde_derive = "1.0" [dev-dependencies] -program_structure = { package = "circomspect-program-structure", version = "2.0.11", path = "../program_structure" } +program_structure = { package = "circomspect-program-structure", version = "2.1.2", path = "../program_structure" } diff --git a/parser/src/errors.rs b/parser/src/errors.rs index a4eadcf..93f53ea 100644 --- a/parser/src/errors.rs +++ b/parser/src/errors.rs @@ -1,4 +1,4 @@ -use program_structure::abstract_syntax_tree::ast::Version; +use program_structure::ast::{Meta, Version}; use program_structure::report_code::ReportCode; use program_structure::report::Report; use program_structure::file_definition::{FileID, FileLocation}; @@ -103,6 +103,64 @@ impl NoCompilerVersionWarning { } } +pub struct AnonymousComponentError { + pub meta: Option, + pub message: String, + pub primary: Option, +} + +impl AnonymousComponentError { + pub fn new(meta: Option<&Meta>, message: &str, primary: Option<&str>) -> Self { + AnonymousComponentError { + meta: meta.cloned(), + message: message.to_string(), + primary: primary.map(ToString::to_string), + } + } + + pub fn into_report(self) -> Report { + let mut report = Report::error(self.message, ReportCode::AnonymousComponentError); + if let Some(meta) = self.meta { + let primary = self.primary.unwrap_or_else(|| "The problem occurs here.".to_string()); + report.add_primary(meta.file_location(), meta.get_file_id(), primary); + } + report + } + + pub fn boxed_report(meta: &Meta, message: &str) -> Box { + Box::new(Self::new(Some(meta), message, None).into_report()) + } +} + +pub struct TupleError { + pub meta: Option, + pub message: String, + pub primary: Option, +} + +impl TupleError { + pub fn new(meta: Option<&Meta>, message: &str, primary: Option<&str>) -> Self { + TupleError { + meta: meta.cloned(), + message: message.to_string(), + primary: primary.map(ToString::to_string), + } + } + + pub fn into_report(self) -> Report { + let mut report = Report::error(self.message, ReportCode::TupleError); + if let Some(meta) = self.meta { + let primary = self.primary.unwrap_or_else(|| "The problem occurs here.".to_string()); + report.add_primary(meta.file_location(), meta.get_file_id(), primary); + } + report + } + + pub fn boxed_report(meta: &Meta, message: &str) -> Box { + Box::new(Self::new(Some(meta), message, None).into_report()) + } +} + fn version_string(version: &Version) -> String { format!("{}.{}.{}", version.0, version.1, version.2) } diff --git a/parser/src/lang.lalrpop b/parser/src/lang.lalrpop index 83b75e6..1bee105 100644 --- a/parser/src/lang.lalrpop +++ b/parser/src/lang.lalrpop @@ -2,7 +2,7 @@ use num_bigint::BigInt; use program_structure::statement_builders::*; use program_structure::expression_builders::*; use program_structure::ast::*; -use program_structure::ast_shortcuts::{self,Symbol}; +use program_structure::ast_shortcuts::{self, Symbol, TupleInit}; use std::str::FromStr; grammar; @@ -93,9 +93,9 @@ pub ParseDefinition : Definition = { // VariableDefinitions // ==================================================================== -ParseElementType : SignalElementType = { - "FieldElement" => SignalElementType::FieldElement, - "Binary" => SignalElementType::Binary, +// To generate the list of tags associated to a signal +ParseTagsList : Vec = { + "{" "}" => id, }; ParseSignalType: SignalType = { @@ -104,17 +104,17 @@ ParseSignalType: SignalType = { }; SignalHeader : VariableType = { - "signal" )?> + "signal" => { - let e = match element_type { - None => SignalElementType::FieldElement, - Some(t) => t, - }; let s = match signal_type { None => SignalType::Intermediate, Some(st) => st, }; - VariableType::Signal(s,e) + let t = match tags_list { + None => Vec::new(), + Some(tl) => tl, + }; + VariableType::Signal(s, t) } }; @@ -124,6 +124,17 @@ SignalHeader : VariableType = { // A Initialization is either just the name of a variable or // the name followed by a expression that initialices the variable. +TupleInitialization : TupleInit = { + "<==" => TupleInit { + tuple_init : (AssignOp::AssignConstraintSignal, rhe) + }, + "<--" => TupleInit { + tuple_init : (AssignOp::AssignSignal, rhe) + }, + "=" => TupleInit { + tuple_init : (AssignOp::AssignVar, rhe) + }, +} SimpleSymbol : Symbol = { @@ -174,95 +185,122 @@ SignalSymbol : Symbol = { // A declaration is the definition of a type followed by the initialization ParseDeclaration : Statement = { - + "var" "(" ",")*> ")" => { + let mut symbols = symbols; + let meta = Meta::new(s, e); + let xtype = VariableType::Var; + symbols.push(symbol); + ast_shortcuts::split_declaration_into_single_nodes_and_multi_substitution(meta, xtype, symbols, init) + }, + "(" ",")*> ")" => { + let mut symbols = symbols; + let meta = Meta::new(s, e); + symbols.push(symbol); + ast_shortcuts::split_declaration_into_single_nodes_and_multi_substitution(meta, xtype, symbols, init) + }, + "component" "(" ",")*> ")" => { + let mut symbols = symbols; + let meta = Meta::new(s, e); + let xtype = VariableType::Component; + symbols.push(symbol); + ast_shortcuts::split_declaration_into_single_nodes_and_multi_substitution(meta, xtype, symbols, init) + }, "var" ",")*> => { let mut symbols = symbols; - let meta = Meta::new(s,e); + let meta = Meta::new(s, e); let xtype = VariableType::Var; symbols.push(symbol); - ast_shortcuts::split_declaration_into_single_nodes(meta,xtype,symbols,AssignOp::AssignVar) + ast_shortcuts::split_declaration_into_single_nodes(meta, xtype, symbols, AssignOp::AssignVar) }, "component" ",")*> => { let mut symbols = symbols; - let meta = Meta::new(s,e); + let meta = Meta::new(s, e); let xtype = VariableType::Component; symbols.push(symbol); - ast_shortcuts::split_declaration_into_single_nodes(meta,xtype,symbols,AssignOp::AssignVar) + ast_shortcuts::split_declaration_into_single_nodes(meta, xtype, symbols, AssignOp::AssignVar) }, - ",")*> - => { + ",")*> => { let mut symbols = symbols; - let meta = Meta::new(s,e); + let meta = Meta::new(s, e); symbols.push(symbol); ast_shortcuts::split_declaration_into_single_nodes(meta,xtype,symbols,AssignOp::AssignConstraintSignal) }, - ",")*> - => { + ",")*> => { let mut symbols = symbols; - let meta = Meta::new(s,e); + let meta = Meta::new(s, e); symbols.push(symbol); - ast_shortcuts::split_declaration_into_single_nodes(meta,xtype,symbols,AssignOp::AssignSignal) + ast_shortcuts::split_declaration_into_single_nodes(meta, xtype, symbols, AssignOp::AssignSignal) }, }; + ParseSubstitution : Statement = { - - => {let (name,access) = variable; - build_substitution(Meta::new(s,e),name,access,op,rhe) - }, - "-->" - => {let (name,access) = variable; - build_substitution(Meta::new(s,e),name,access,AssignOp::AssignSignal,lhe) - }, - "==>" - => {let (name,access) = variable; - build_substitution(Meta::new(s,e),name,access,AssignOp::AssignConstraintSignal,lhe) - }, - "\\=" - => ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::IntDiv,Meta::new(s,e),variable,rhe), + => { + if let Expression::Variable {meta, name, access} = variable { + build_substitution(Meta::new(s, e), name, access, ops, rhe) + } else { + build_multi_substitution(Meta::new(s, e), variable, ops, rhe) + } + }, + "-->" => { + if let Expression::Variable {meta, name, access} = variable { + build_substitution(Meta::new(s, e), name, access, AssignOp::AssignSignal, lhe) + } else { + build_multi_substitution(Meta::new(s, e), variable, AssignOp::AssignSignal, lhe) + } + }, + "==>" => { + if let Expression::Variable {meta, name, access} = variable { + build_substitution(Meta::new(s, e), name, access, AssignOp::AssignConstraintSignal, lhe) + } else{ + build_multi_substitution(Meta::new(s, e), variable, AssignOp::AssignConstraintSignal, lhe) + } + }, + "\\=" => + ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::IntDiv, Meta::new(s, e), variable, rhe), - "**=" - => ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::Pow,Meta::new(s,e),variable,rhe), + "**=" => + ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::Pow, Meta::new(s, e), variable, rhe), - "+=" - => ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::Add,Meta::new(s,e),variable,rhe), + "+=" => + ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::Add, Meta::new(s, e), variable, rhe), - "-=" - => ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::Sub,Meta::new(s,e),variable,rhe), + "-=" => + ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::Sub, Meta::new(s, e), variable, rhe), - "*=" - => ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::Mul,Meta::new(s,e),variable,rhe), + "*=" => + ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::Mul, Meta::new(s, e), variable, rhe), - "/=" - => ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::Div,Meta::new(s,e),variable,rhe), + "/=" => + ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::Div, Meta::new(s, e), variable, rhe), - "%=" - => ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::Mod,Meta::new(s,e),variable,rhe), + "%=" => + ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::Mod, Meta::new(s, e), variable, rhe), - "<<=" - => ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::ShiftL,Meta::new(s,e),variable,rhe), + "<<=" => + ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::ShiftL, Meta::new(s, e), variable, rhe), - ">>=" - => ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::ShiftR,Meta::new(s,e),variable,rhe), + ">>=" => + ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::ShiftR, Meta::new(s, e), variable, rhe), - "&=" - => ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::BitAnd,Meta::new(s,e),variable,rhe), + "&=" => + ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::BitAnd, Meta::new(s, e), variable, rhe), - "|=" - => ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::BitOr,Meta::new(s,e),variable,rhe), + "|=" => + ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::BitOr, Meta::new(s, e), variable, rhe), - "^=" - => ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::BitXor,Meta::new(s,e),variable,rhe), + "^=" => + ast_shortcuts::assign_with_op_shortcut(ExpressionInfixOpcode::BitXor, Meta::new(s, e), variable, rhe), - "++" - => ast_shortcuts::plusplus(Meta::new(s,e),variable), + "++" => + ast_shortcuts::plusplus(Meta::new(s,e),variable), - "--" - => ast_shortcuts::subsub(Meta::new(s,e),variable), + "--" => + ast_shortcuts::subsub(Meta::new(s, e), variable), }; ParseBlock : Statement = { "{" "}" - => build_block(Meta::new(s,e),stmts), + => build_block(Meta::new(s, e), stmts), }; pub ParseStatement : Statement = { @@ -279,44 +317,49 @@ ParseStatement0 : Statement = { }; ParseStmt0NB : Statement = { - "if" "(" ")" - => build_conditional_block(Meta::new(s,e),cond,if_case,None), + "if" "(" ")" => + build_conditional_block(Meta::new(s, e), cond, if_case, None), - "if" "(" ")" - => build_conditional_block(Meta::new(s,e),cond,if_case,None), + "if" "(" ")" => + build_conditional_block(Meta::new(s, e), cond, if_case, None), - "if" "(" ")" > - => build_conditional_block(Meta::new(s,e),cond,if_case,Some(else_case)), + "if" "(" ")" > => + build_conditional_block(Meta::new(s, e), cond, if_case, Some(else_case)), }; ParseStatement1 : Statement = { - "if" "(" ")" > - => build_conditional_block(Meta::new(s,e),cond,if_case,Some(else_case)), + "if" "(" ")" > => + build_conditional_block(Meta::new(s, e), cond, if_case, Some(else_case)), + ParseStatement2 }; + ParseStatement2 : Statement = { - "for" "(" ";" ";" ")" - => ast_shortcuts::for_into_while(Meta::new(s,e),init,cond,step,body), + "for" "(" ";" ";" ")" => + ast_shortcuts::for_into_while(Meta::new(s, e), init, cond, step, body), - "for" "(" ";" ";" ")" - => ast_shortcuts::for_into_while(Meta::new(s,e),init,cond,step,body), + "for" "(" ";" ";" ")" => + ast_shortcuts::for_into_while(Meta::new(s, e), init, cond, step, body), - "while" "(" ")" - => build_while_block(Meta::new(s,e),cond,stmt), + "while" "(" ")" => + build_while_block(Meta::new(s, e), cond, stmt), - "return" ";" - => build_return(Meta::new(s,e),value), + "return" ";" => + build_return(Meta::new(s, e), value), - ";" - => subs, + ";" => + subs, - "===" ";" - => build_constraint_equality(Meta::new(s,e),lhe,rhe), + "===" ";" => + build_constraint_equality(Meta::new(s, e), lhe, rhe), ParseStatementLog, - "assert" "(" ")" ";" - => build_assert(Meta::new(s,e),arg), + "assert" "(" ")" ";" => + build_assert(Meta::new(s,e),arg), + + ";" => + build_anonymous_component_statement(Meta::new(s, e), lhe), ParseBlock }; @@ -347,41 +390,62 @@ ParseVarAccess : Access = { => build_array_access(arr_dec), => build_component_access(component_acc), }; + ParseArrayAcc: Expression = { "[""]" => dim }; + ParseComponentAcc: String = { "." => id, }; + ParseVariable : (String,Vec) = { - => (name,access), + => (name, access), }; + // ==================================================================== // Expression // ==================================================================== Listable: Vec = { - ",")*> - => { + ",")*> => { let mut e = e; e.push(tail); e }, }; -ParseString : LogArgument = { - - => { - build_log_string(e) +ListableWithInputNames : (Vec,Option>) = { + < ParseExpression> ",")*> + => { + let (mut operators_names, mut signals) = unzip_3(e); + signals.push(signal); + match operators_names.len() { + 0 => (signals, Option::None), + _ => { operators_names.push((op,name)); (signals, Option::Some(operators_names)) + } + } + } +}; + +ListableAnon : (Vec,Option>) = { + => { + (l, Option::None) }, + + => + l, +}; + +ParseString : LogArgument = { + => + build_log_string(e), }; ParseLogExp: LogArgument = { - - => { - build_log_expression(e) - } + => + build_log_expression(e), } ParseLogArgument : LogArgument = { @@ -390,24 +454,33 @@ ParseLogArgument : LogArgument = { }; LogListable: Vec = { - ",")*> - => { + ",")*> => { let mut e = e; e.push(tail); e }, }; +TwoElemsListable: Vec = { + "," )*> + => { + let mut rest = rest; + let mut new_v = vec![head, head1]; + new_v.append(&mut rest); + new_v + }, +}; + InfixOpTier : Expression = { - > - => build_infix(Meta::new(s,e),lhe,infix_op,rhe), + > => + build_infix(Meta::new(s, e), lhe, infix_op, rhe), NextTier }; PrefixOpTier: Expression = { - - => build_prefix(Meta::new(s,e),prefix_op,rhe), + => + build_prefix(Meta::new(s, e), prefix_op, rhe), NextTier }; @@ -434,74 +507,90 @@ Expression14: Expression = { // ops: e ? a : i Expression13 : Expression = { "?" ":" - => build_inline_switch_op(Meta::new(s,e),cond,if_true,if_false), + => build_inline_switch_op(Meta::new(s, e), cond, if_true, if_false), }; // ops: || -Expression12 = InfixOpTier; +Expression12 = InfixOpTier; // ops: && -Expression11 = InfixOpTier; +Expression11 = InfixOpTier; // ops: == != < > <= >= -Expression10 = InfixOpTier; +Expression10 = InfixOpTier; // ops: | -Expression9 = InfixOpTier; +Expression9 = InfixOpTier; // ops: ^ -Expression8 = InfixOpTier; +Expression8 = InfixOpTier; // ops: & -Expression7 = InfixOpTier; +Expression7 = InfixOpTier; // ops: << >> -Expression6 = InfixOpTier; +Expression6 = InfixOpTier; // ops: + - -Expression5 = InfixOpTier; +Expression5 = InfixOpTier; // ops: * / \\ % -Expression4 = InfixOpTier; +Expression4 = InfixOpTier; // ops: ** -Expression3 = InfixOpTier; +Expression3 = InfixOpTier; // ops: Unary - ! ~ -Expression2 = PrefixOpTier; +Expression2 = PrefixOpTier; -// function call, array inline +// function call, array inline, anonymous component call Expression1: Expression = { - "(" ")" - => match args { - None => build_call(Meta::new(s,e),id,Vec::new()), - Some(a) => build_call(Meta::new(s,e),id,a), + "(" ")" "(" ")" => { + let params = match args { + None => Vec::new(), + Some(a) => a + }; + let (signals, names) = match args2 { + None => (Vec::new(),Option::None), + Some(a) => a + }; + build_anonymous_component(Meta::new(s, e), id, params, signals, names, false) + }, + + "(" ")" => match args { + None => build_call(Meta::new(s, e), id, Vec::new()), + Some(a) => build_call(Meta::new(s, e), id, a), }, - "[" "]" - => build_array_in_line(Meta::new(s,e),values), + "[" "]" => + build_array_in_line(Meta::new(s, e), values), + + "(" ")" => + build_tuple(Meta::new(s,e), values), Expression0, }; // Literal, parentheses Expression0: Expression = { - - => { - let (name,access) = variable; - build_variable(Meta::new(s,e),name,access) + => { + let (name, access) = variable; + build_variable(Meta::new(s, e), name, access) }, - - => build_number(Meta::new(s,e),value), + "_" => + build_variable(Meta::new(s, e), "_".to_string(), Vec::new()), - - => build_number(Meta::new(s,e),value), + + => + build_number(Meta::new(s, e), value), + + => + build_number(Meta::new(s, e), value), "(" ")" }; - // ==================================================================== // Terminals // ==================================================================== diff --git a/parser/src/lib.rs b/parser/src/lib.rs index 2f548a3..f52020c 100644 --- a/parser/src/lib.rs +++ b/parser/src/lib.rs @@ -2,6 +2,7 @@ extern crate num_bigint_dig as num_bigint; extern crate num_traits; extern crate serde; extern crate serde_derive; + #[macro_use] extern crate lalrpop_util; @@ -13,6 +14,11 @@ use log::debug; mod errors; mod include_logic; mod parser_logic; +mod syntax_sugar_traits; +mod syntax_sugar_remover; + +pub use parser_logic::parse_definition; + use include_logic::FileStack; use program_structure::ast::{Version, AST}; use program_structure::report::{Report, ReportCollection}; @@ -30,14 +36,24 @@ pub enum ParseResult { Library(Box, ReportCollection), } -pub fn parse_files(file_paths: &[PathBuf], compiler_version: &Version) -> ParseResult { +pub fn parse_files( + file_paths: &[PathBuf], + follow_includes: bool, + compiler_version: &Version, +) -> ParseResult { let mut reports = ReportCollection::new(); let mut file_stack = FileStack::new(file_paths, &mut reports); let mut file_library = FileLibrary::new(); let mut definitions = HashMap::new(); let mut main_components = Vec::new(); while let Some(file_path) = FileStack::take_next(&mut file_stack) { - match parse_file(&file_path, &mut file_stack, &mut file_library, compiler_version) { + match parse_file( + &file_path, + &mut file_stack, + &mut file_library, + follow_includes, + compiler_version, + ) { Ok((file_id, program, mut warnings)) => { if let Some(main_component) = program.main_component { main_components.push((file_id, main_component, program.custom_gates)); @@ -50,7 +66,8 @@ pub fn parse_files(file_paths: &[PathBuf], compiler_version: &Version) -> ParseR } } } - match &main_components[..] { + // Create a parse result. + let mut result = match &main_components[..] { [(main_id, main_component, custom_gates)] => { // TODO: This calls FillMeta::fill a second time. match ProgramArchive::new( @@ -78,13 +95,50 @@ pub fn parse_files(file_paths: &[PathBuf], compiler_version: &Version) -> ParseR let template_library = TemplateLibrary::new(definitions, file_library); ParseResult::Library(Box::new(template_library), reports) } + }; + // Remove anonymous components and tuples. + // + // TODO: This could be moved to the lifting phase. + match &mut result { + ParseResult::Program(program_archive, reports) => { + if program_archive.main_expression().is_anonymous_component() { + reports.push( + errors::AnonymousComponentError::new( + Some(program_archive.main_expression().meta()), + "The main component cannot contain an anonymous call.", + Some("Main component defined here."), + ) + .into_report(), + ); + } + let (new_templates, new_functions) = syntax_sugar_remover::remove_syntactic_sugar( + &program_archive.templates, + &program_archive.functions, + &program_archive.file_library, + reports, + ); + program_archive.templates = new_templates; + program_archive.functions = new_functions; + } + ParseResult::Library(template_library, reports) => { + let (new_templates, new_functions) = syntax_sugar_remover::remove_syntactic_sugar( + &template_library.templates, + &template_library.functions, + &template_library.file_library, + reports, + ); + template_library.templates = new_templates; + template_library.functions = new_functions; + } } + result } fn parse_file( file_path: &PathBuf, file_stack: &mut FileStack, file_library: &mut FileLibrary, + follow_includes: bool, compiler_version: &Version, ) -> Result<(FileID, AST, ReportCollection), Box> { let mut reports = ReportCollection::new(); @@ -95,15 +149,17 @@ fn parse_file( debug!("parsing file `{}`", file_path.display()); let program = parser_logic::parse_file(&file_content, file_id)?; - for include in &program.includes { - if let Err(report) = FileStack::add_include(file_stack, include) { - reports.push(*report); - } - } match check_compiler_version(file_path, program.compiler_version, compiler_version) { Ok(warnings) => reports.extend(warnings), Err(error) => reports.push(*error), } + if follow_includes { + for include in &program.includes { + if let Err(report) = file_stack.add_include(include) { + reports.push(*report); + } + } + } Ok((file_id, program, reports)) } @@ -147,16 +203,6 @@ fn check_compiler_version( } } -/// Parse a single (function or template) definition for testing purposes. -use program_structure::ast::Definition; - -pub fn parse_definition(src: &str) -> Option { - match parser_logic::parse_string(src) { - Some(AST { mut definitions, .. }) if definitions.len() == 1 => definitions.pop(), - _ => None, - } -} - #[cfg(test)] mod tests { use std::path::PathBuf; diff --git a/parser/src/parser_logic.rs b/parser/src/parser_logic.rs index a4849e6..6f1ce4f 100644 --- a/parser/src/parser_logic.rs +++ b/parser/src/parser_logic.rs @@ -116,6 +116,16 @@ pub fn parse_string(src: &str) -> Option { lang::ParseAstParser::new().parse(&src).ok() } +/// Parse a single (function or template) definition for testing purposes. +use program_structure::ast::Definition; + +pub fn parse_definition(src: &str) -> Option { + match parse_string(src) { + Some(AST { mut definitions, .. }) if definitions.len() == 1 => definitions.pop(), + _ => None, + } +} + #[must_use] fn format_expected(tokens: &[String]) -> String { if tokens.is_empty() { diff --git a/parser/src/syntax_sugar_remover.rs b/parser/src/syntax_sugar_remover.rs new file mode 100644 index 0000000..f88c7d6 --- /dev/null +++ b/parser/src/syntax_sugar_remover.rs @@ -0,0 +1,1020 @@ +use program_structure::ast::*; +use program_structure::statement_builders::{build_block, build_substitution}; +use program_structure::report::{Report, ReportCollection}; +use program_structure::expression_builders::{build_call, build_tuple, build_parallel_op}; +use program_structure::file_definition::FileLibrary; +use program_structure::statement_builders::{ + build_declaration, build_log_call, build_assert, build_return, build_constraint_equality, + build_initialization_block, +}; +use program_structure::template_data::TemplateData; +use program_structure::function_data::FunctionData; +use std::collections::HashMap; +use num_bigint::BigInt; + +use crate::errors::{AnonymousComponentError, TupleError}; +use crate::syntax_sugar_traits::ContainsExpression; + +/// This functions desugars all anonymous components and tuples. +#[must_use] +pub(crate) fn remove_syntactic_sugar( + templates: &HashMap, + functions: &HashMap, + file_library: &FileLibrary, + reports: &mut ReportCollection, +) -> (HashMap, HashMap) { + // Remove anonymous components and tuples from templates. + let mut new_templates = HashMap::new(); + for (name, template) in templates { + let body = template.get_body().clone(); + let (new_body, declarations) = + match remove_anonymous_from_statement(templates, file_library, body, &None) { + Ok(result) => result, + Err(report) => { + // If we encounter an error we simply report the error and continue. + // This means that the template is dropped and no more analysis is + // performed on it. + // + // TODO: If we want to do inter-procedural analysis we need to track + // removed templates. + reports.push(*report); + continue; + } + }; + if let Statement::Block { meta, mut stmts } = new_body { + let (component_decs, variable_decs, mut substitutions) = + separate_declarations_in_comp_var_subs(declarations); + let mut init_block = vec![ + build_initialization_block(meta.clone(), VariableType::Var, variable_decs), + build_initialization_block(meta.clone(), VariableType::Component, component_decs), + ]; + init_block.append(&mut substitutions); + init_block.append(&mut stmts); + let new_body_with_inits = build_block(meta, init_block); + let new_body = match remove_tuples_from_statement(new_body_with_inits) { + Ok(result) => result, + Err(report) => { + // If we encounter an error we simply report the error and continue. + // This means that the template is dropped and no more analysis is + // performed on it. + // + // TODO: If we want to do inter-procedural analysis we need to track + // removed templates. + reports.push(*report); + continue; + } + }; + let mut new_template = template.clone(); + *new_template.get_mut_body() = new_body; + new_templates.insert(name.clone(), new_template); + } else { + unreachable!() + } + } + + // Drop any functions containing anonymous components or tuples. + let mut new_functions = HashMap::new(); + for (name, function) in functions { + let body = function.get_body(); + if body.contains_tuple(Some(reports)) { + continue; + } + if body.contains_anonymous_component(Some(reports)) { + continue; + } + new_functions.insert(name.clone(), function.clone()); + } + (new_templates, new_functions) +} + +fn remove_anonymous_from_statement( + templates: &HashMap, + file_library: &FileLibrary, + stmt: Statement, + var_access: &Option, +) -> Result<(Statement, Vec), Box> { + match stmt { + Statement::MultiSubstitution { meta, lhe, op, rhe } => { + if lhe.contains_anonymous_component(None) { + return Err(AnonymousComponentError::boxed_report( + lhe.meta(), + "An anonymous component cannot occur as the left-hand side of an assignment", + )); + } else { + let (mut stmts, declarations, new_rhe) = + remove_anonymous_from_expression(templates, file_library, rhe, var_access)?; + let subs = + Statement::MultiSubstitution { meta: meta.clone(), lhe, op, rhe: new_rhe }; + let mut substs = Vec::new(); + if stmts.is_empty() { + Ok((subs, declarations)) + } else { + substs.append(&mut stmts); + substs.push(subs); + Ok((Statement::Block { meta, stmts: substs }, declarations)) + } + } + } + Statement::IfThenElse { meta, cond, if_case, else_case } => { + if cond.contains_anonymous_component(None) { + return Err(AnonymousComponentError::boxed_report( + cond.meta(), + "Anonymous components cannot be used inside conditions.", + )); + } else { + let (new_if_case, mut declarations) = + remove_anonymous_from_statement(templates, file_library, *if_case, var_access)?; + match else_case { + Some(else_case) => { + let (new_else_case, mut new_declarations) = + remove_anonymous_from_statement( + templates, + file_library, + *else_case, + var_access, + )?; + declarations.append(&mut new_declarations); + Ok(( + Statement::IfThenElse { + meta, + cond, + if_case: Box::new(new_if_case), + else_case: Some(Box::new(new_else_case)), + }, + declarations, + )) + } + None => Ok(( + Statement::IfThenElse { + meta, + cond, + if_case: Box::new(new_if_case), + else_case: None, + }, + declarations, + )), + } + } + } + Statement::While { meta, cond, stmt } => { + if cond.contains_anonymous_component(None) { + return Err(AnonymousComponentError::boxed_report( + cond.meta(), + "Anonymous components cannot be used inside conditions.", + )); + } else { + let id_var_while = "anon_var_".to_string() + + &file_library.get_line(meta.start, meta.get_file_id()).unwrap().to_string() + + "_" + + &meta.start.to_string(); + let var_access = Expression::Variable { + meta: meta.clone(), + name: id_var_while.clone(), + access: Vec::new(), + }; + let mut declarations = vec![]; + let (new_stmt, mut new_declarations) = remove_anonymous_from_statement( + templates, + file_library, + *stmt, + &Some(var_access.clone()), + )?; + let boxed_stmt = if !new_declarations.is_empty() { + declarations.push(build_declaration( + meta.clone(), + VariableType::Var, + id_var_while.clone(), + Vec::new(), + )); + declarations.push(build_substitution( + meta.clone(), + id_var_while.clone(), + vec![], + AssignOp::AssignVar, + Expression::Number(meta.clone(), BigInt::from(0)), + )); + declarations.append(&mut new_declarations); + let next_access = Expression::InfixOp { + meta: meta.clone(), + infix_op: ExpressionInfixOpcode::Add, + lhe: Box::new(var_access), + rhe: Box::new(Expression::Number(meta.clone(), BigInt::from(1))), + }; + let subs_access = Statement::Substitution { + meta: meta.clone(), + var: id_var_while, + access: Vec::new(), + op: AssignOp::AssignVar, + rhe: next_access, + }; + + let new_block = + Statement::Block { meta: meta.clone(), stmts: vec![new_stmt, subs_access] }; + Box::new(new_block) + } else { + Box::new(new_stmt) + }; + + Ok((Statement::While { meta, cond, stmt: boxed_stmt }, declarations)) + } + } + Statement::LogCall { meta, args } => { + for arg in &args { + if let program_structure::ast::LogArgument::LogExp(exp) = arg { + if exp.contains_anonymous_component(None) { + return Err(AnonymousComponentError::boxed_report( + &meta, + "An anonymous component cannot be used inside a log statement.", + )); + } + } + } + Ok((build_log_call(meta, args), Vec::new())) + } + Statement::Assert { meta, arg } => Ok((build_assert(meta, arg), Vec::new())), + Statement::Return { meta, value: arg } => { + if arg.contains_anonymous_component(None) { + Err(AnonymousComponentError::boxed_report( + &meta, + "An anonymous component cannot be used as a return value.", + )) + } else { + Ok((build_return(meta, arg), Vec::new())) + } + } + Statement::ConstraintEquality { meta, lhe, rhe } => { + if lhe.contains_anonymous_component(None) || rhe.contains_anonymous_component(None) { + Err(AnonymousComponentError::boxed_report( + &meta, + "Anonymous components cannot be used together with the constraint equality operator `===`.", + )) + } else { + Ok((build_constraint_equality(meta, lhe, rhe), Vec::new())) + } + } + Statement::Declaration { meta, xtype, name, dimensions, .. } => { + for exp in dimensions.clone() { + if exp.contains_anonymous_component(None) { + return Err(AnonymousComponentError::boxed_report( + exp.meta(), + "An anonymous component cannot be used to define the dimensions of an array.", + )); + } + } + Ok((build_declaration(meta, xtype, name, dimensions), Vec::new())) + } + Statement::InitializationBlock { meta, xtype, initializations } => { + let mut new_inits = Vec::new(); + let mut declarations = Vec::new(); + for stmt in initializations { + let (stmt_ok, mut declaration) = + remove_anonymous_from_statement(templates, file_library, stmt, var_access)?; + new_inits.push(stmt_ok); + declarations.append(&mut declaration) + } + Ok(( + Statement::InitializationBlock { meta, xtype, initializations: new_inits }, + declarations, + )) + } + Statement::Block { meta, stmts } => { + let mut new_stmts = Vec::new(); + let mut declarations = Vec::new(); + for stmt in stmts { + let (stmt_ok, mut declaration) = + remove_anonymous_from_statement(templates, file_library, stmt, var_access)?; + new_stmts.push(stmt_ok); + declarations.append(&mut declaration); + } + Ok((Statement::Block { meta, stmts: new_stmts }, declarations)) + } + Statement::Substitution { meta, var, op, rhe, access } => { + let (mut stmts, declarations, new_rhe) = + remove_anonymous_from_expression(templates, file_library, rhe, var_access)?; + let subs = + Statement::Substitution { meta: meta.clone(), var, access, op, rhe: new_rhe }; + let mut substs = Vec::new(); + if stmts.is_empty() { + Ok((subs, declarations)) + } else { + substs.append(&mut stmts); + substs.push(subs); + Ok((Statement::Block { meta, stmts: substs }, declarations)) + } + } + } +} + +// returns a block with the substitutions, the declarations and finally the output expression +fn remove_anonymous_from_expression( + templates: &HashMap, + file_library: &FileLibrary, + expr: Expression, + var_access: &Option, // in case the call is inside a loop, variable used to control the access +) -> Result<(Vec, Vec, Expression), Box> { + use Expression::*; + match expr.clone() { + ArrayInLine { values, .. } => { + for value in values { + if value.contains_anonymous_component(None) { + return Err(AnonymousComponentError::boxed_report( + value.meta(), + "An anonymous component cannot be used to define the dimensions of an array.", + )); + } + } + Ok((Vec::new(), Vec::new(), expr)) + } + Number(_, _) => Ok((Vec::new(), Vec::new(), expr)), + Variable { meta, .. } => { + if expr.contains_anonymous_component(None) { + return Err(AnonymousComponentError::boxed_report( + &meta, + "An anonymous component cannot be used to access an array.", + )); + } + Ok((Vec::new(), Vec::new(), expr)) + } + InfixOp { meta, lhe, rhe, .. } => { + if lhe.contains_anonymous_component(None) || rhe.contains_anonymous_component(None) { + return Err(AnonymousComponentError::boxed_report( + &meta, + "Anonymous components cannot be used in arithmetic or boolean expressions.", + )); + } + Ok((Vec::new(), Vec::new(), expr)) + } + PrefixOp { meta, rhe, .. } => { + if rhe.contains_anonymous_component(None) { + return Err(AnonymousComponentError::boxed_report( + &meta, + "Anonymous components cannot be used in arithmetic or boolean expressions.", + )); + } + Ok((Vec::new(), Vec::new(), expr)) + } + InlineSwitchOp { meta, cond, if_true, if_false } => { + if cond.contains_anonymous_component(None) + || if_true.contains_anonymous_component(None) + || if_false.contains_anonymous_component(None) + { + return Err(AnonymousComponentError::boxed_report( + &meta, + "An anonymous component cannot be used inside an inline switch expression.", + )); + } + Ok((Vec::new(), Vec::new(), expr)) + } + Call { meta, args, .. } => { + for value in args { + if value.contains_anonymous_component(None) { + return Err(AnonymousComponentError::boxed_report( + &meta, + "An anonymous component cannot be used as an argument to a template call.", + )); + } + } + Ok((Vec::new(), Vec::new(), expr)) + } + AnonymousComponent { meta, id, params, signals, names, is_parallel } => { + let template = templates.get(&id); + let mut declarations = Vec::new(); + if template.is_none() { + return Err(Box::new( + AnonymousComponentError::new( + Some(&meta), + &format!("The template `{id}` does not exist."), + Some(&format!("Unknown template `{id}` instantiated here.")), + ) + .into_report(), + )); + } + let mut i = 0; + let mut seq_substs = Vec::new(); + let id_anon_temp = id.to_string() + + "_" + + &file_library.get_line(meta.start, meta.get_file_id()).unwrap().to_string() + + "_" + + &meta.start.to_string(); + if var_access.is_none() { + declarations.push(build_declaration( + meta.clone(), + VariableType::Component, + id_anon_temp.clone(), + Vec::new(), + )); + } else { + declarations.push(build_declaration( + meta.clone(), + VariableType::AnonymousComponent, + id_anon_temp.clone(), + vec![var_access.as_ref().unwrap().clone()], + )); + } + let call = build_call(meta.clone(), id, params); + if call.contains_anonymous_component(None) { + return Err(AnonymousComponentError::boxed_report( + &meta, + "An anonymous component cannot be used as a argument to a template call.", + )); + } + + let exp_with_call = + if is_parallel { build_parallel_op(meta.clone(), call) } else { call }; + let access = if var_access.is_none() { + Vec::new() + } else { + vec![build_array_access(var_access.as_ref().unwrap().clone())] + }; + let sub = build_substitution( + meta.clone(), + id_anon_temp.clone(), + access, + AssignOp::AssignVar, + exp_with_call, + ); + seq_substs.push(sub); + let inputs = template.unwrap().get_declaration_inputs(); + let mut new_signals = Vec::new(); + let mut new_operators = Vec::new(); + if let Some(m) = names { + let (operators, names): (Vec, Vec) = m.iter().cloned().unzip(); + for inp in inputs { + if !names.contains(&inp.0) { + return Err(AnonymousComponentError::boxed_report( + &meta, + &format!("The input signal `{}` is not assigned by the anonymous component call.", inp.0), + )); + } else { + let pos = names.iter().position(|r| *r == inp.0).unwrap(); + new_signals.push(signals.get(pos).unwrap().clone()); + new_operators.push(*operators.get(pos).unwrap()); + } + } + } else { + new_signals = signals.clone(); + for _ in 0..signals.len() { + new_operators.push(AssignOp::AssignConstraintSignal); + } + } + if inputs.len() != new_signals.len() || inputs.len() != signals.len() { + return Err(AnonymousComponentError::boxed_report(&meta, "The number of input arguments must be equal to the number of input signals of the template.")); + } + for inp in inputs { + let mut acc = if var_access.is_none() { + Vec::new() + } else { + vec![build_array_access(var_access.as_ref().unwrap().clone())] + }; + acc.push(Access::ComponentAccess(inp.0.clone())); + let (mut stmts, mut new_declarations, new_expr) = remove_anonymous_from_expression( + templates, + file_library, + new_signals.get(i).unwrap().clone(), + var_access, + )?; + if new_expr.contains_anonymous_component(None) { + return Err(AnonymousComponentError::boxed_report( + new_expr.meta(), + "The inputs to an anonymous component cannot contain anonymous components.", + )); + } + seq_substs.append(&mut stmts); + declarations.append(&mut new_declarations); + let subs = Statement::Substitution { + meta: meta.clone(), + var: id_anon_temp.clone(), + access: acc, + op: *new_operators.get(i).unwrap(), + rhe: new_expr, + }; + i += 1; + seq_substs.push(subs); + } + let outputs = template.unwrap().get_declaration_outputs(); + if outputs.len() == 1 { + let output = outputs.get(0).unwrap().0.clone(); + let mut acc = if var_access.is_none() { + Vec::new() + } else { + vec![build_array_access(var_access.as_ref().unwrap().clone())] + }; + + acc.push(Access::ComponentAccess(output)); + let out_exp = + Expression::Variable { meta: meta.clone(), name: id_anon_temp, access: acc }; + Ok((vec![Statement::Block { meta, stmts: seq_substs }], declarations, out_exp)) + } else { + let mut new_values = Vec::new(); + for output in outputs { + let mut acc = if var_access.is_none() { + Vec::new() + } else { + vec![build_array_access(var_access.as_ref().unwrap().clone())] + }; + acc.push(Access::ComponentAccess(output.0.clone())); + let out_exp = Expression::Variable { + meta: meta.clone(), + name: id_anon_temp.clone(), + access: acc, + }; + new_values.push(out_exp); + } + let out_exp = Tuple { meta: meta.clone(), values: new_values }; + Ok((vec![Statement::Block { meta, stmts: seq_substs }], declarations, out_exp)) + } + } + Tuple { meta, values } => { + let mut new_values = Vec::new(); + let mut new_stmts: Vec = Vec::new(); + let mut declarations: Vec = Vec::new(); + for val in values { + let result = + remove_anonymous_from_expression(templates, file_library, val, var_access); + match result { + Ok((mut stm, mut declaration, val2)) => { + new_stmts.append(&mut stm); + new_values.push(val2); + declarations.append(&mut declaration); + } + Err(er) => { + return Err(er); + } + } + } + Ok((new_stmts, declarations, build_tuple(meta, new_values))) + } + ParallelOp { meta, rhe } => { + if !rhe.is_call() + && !rhe.is_anonymous_component() + && rhe.contains_anonymous_component(None) + { + return Err(AnonymousComponentError::boxed_report( + &meta, + "Invalid use of the parallel operator together with an anonymous component.", + )); + } else if rhe.is_call() && rhe.contains_anonymous_component(None) { + return Err(AnonymousComponentError::boxed_report( + &meta, + "An anonymous component cannot be used as a parameter in a template call.", + )); + } else if rhe.is_anonymous_component() { + let rhe2 = rhe.make_anonymous_parallel(); + return remove_anonymous_from_expression(templates, file_library, rhe2, var_access); + } + Ok((Vec::new(), Vec::new(), expr)) + } + } +} + +fn separate_declarations_in_comp_var_subs( + declarations: Vec, +) -> (Vec, Vec, Vec) { + let mut components_dec = Vec::new(); + let mut variables_dec = Vec::new(); + let mut substitutions = Vec::new(); + for dec in declarations { + if let Statement::Declaration { ref xtype, .. } = dec { + if matches!(xtype, VariableType::Component | VariableType::AnonymousComponent) { + components_dec.push(dec); + } else if VariableType::Var.eq(xtype) { + variables_dec.push(dec); + } else { + unreachable!(); + } + } else if let Statement::Substitution { .. } = dec { + substitutions.push(dec); + } else { + unreachable!(); + } + } + (components_dec, variables_dec, substitutions) +} + +fn remove_tuples_from_statement(stmt: Statement) -> Result> { + match stmt { + Statement::MultiSubstitution { meta, lhe, op, rhe } => { + let new_lhe = remove_tuple_from_expression(lhe)?; + let new_rhe = remove_tuple_from_expression(rhe)?; + match (new_lhe, new_rhe) { + ( + Expression::Tuple { values: mut lhe_values, .. }, + Expression::Tuple { values: mut rhe_values, .. }, + ) => { + if lhe_values.len() == rhe_values.len() { + let mut substs = Vec::new(); + while !lhe_values.is_empty() { + let lhe = lhe_values.remove(0); + if let Expression::Variable { meta, name, access } = lhe { + let rhe = rhe_values.remove(0); + if name != "_" { + substs.push(build_substitution( + meta.clone(), + name.clone(), + access.to_vec(), + op, + rhe, + )); + } + } else { + return Err(TupleError::boxed_report(&meta, "The elements of the destination tuple must be either signals or variables.")); + } + } + Ok(build_block(meta, substs)) + } else if !lhe_values.is_empty() { + Err(TupleError::boxed_report( + &meta, + "The two tuples do not have the same length.", + )) + } else { + Err(TupleError::boxed_report( + &meta, + "This expression must be the right-hand side of an assignment.", + )) + } + } + (lhe, rhe) => { + if lhe.is_tuple() || lhe.is_variable() { + return Err(TupleError::boxed_report( + rhe.meta(), + "This expression must be a tuple or an anonymous component.", + )); + } else { + return Err(TupleError::boxed_report( + lhe.meta(), + "This expression must be a tuple, a component, a signal or a variable.", + )); + } + } + } + } + Statement::IfThenElse { meta, cond, if_case, else_case } => { + if cond.contains_tuple(None) { + Err(TupleError::boxed_report(&meta, "Tuples cannot be used in conditions.")) + } else { + let new_if_case = remove_tuples_from_statement(*if_case)?; + match else_case { + Some(else_case) => { + let new_else_case = remove_tuples_from_statement(*else_case)?; + Ok(Statement::IfThenElse { + meta, + cond, + if_case: Box::new(new_if_case), + else_case: Some(Box::new(new_else_case)), + }) + } + None => Ok(Statement::IfThenElse { + meta, + cond, + if_case: Box::new(new_if_case), + else_case: None, + }), + } + } + } + Statement::While { meta, cond, stmt } => { + if cond.contains_tuple(None) { + Err(TupleError::boxed_report(&meta, "Tuples cannot be used in conditions.")) + } else { + let new_stmt = remove_tuples_from_statement(*stmt)?; + Ok(Statement::While { meta, cond, stmt: Box::new(new_stmt) }) + } + } + Statement::LogCall { meta, args } => { + let mut new_args = Vec::new(); + for arg in args { + match arg { + LogArgument::LogStr(str) => { + new_args.push(LogArgument::LogStr(str)); + } + LogArgument::LogExp(exp) => { + let mut sep_args = separate_tuple_for_log_call(vec![exp]); + new_args.append(&mut sep_args); + } + } + } + Ok(build_log_call(meta, new_args)) + } + Statement::Assert { meta, arg } => Ok(build_assert(meta, arg)), + Statement::Return { meta, value } => { + if value.contains_tuple(None) { + Err(TupleError::boxed_report(&meta, "Tuple cannot be used in return values.")) + } else { + Ok(build_return(meta, value)) + } + } + Statement::ConstraintEquality { meta, lhe, rhe } => { + if lhe.contains_tuple(None) || rhe.contains_tuple(None) { + Err(TupleError::boxed_report( + &meta, + "Tuples cannot be used together with the constraint equality operator `===`.", + )) + } else { + Ok(build_constraint_equality(meta, lhe, rhe)) + } + } + Statement::Declaration { meta, xtype, name, dimensions, .. } => { + for expr in &dimensions { + if expr.contains_tuple(None) { + return Err(TupleError::boxed_report( + &meta, + "A tuple cannot be used to define the dimensions of an array.", + )); + } + } + Ok(build_declaration(meta, xtype, name, dimensions)) + } + Statement::InitializationBlock { meta, xtype, initializations } => { + let mut new_inits = Vec::new(); + for stmt in initializations { + let new_stmt = remove_tuples_from_statement(stmt)?; + new_inits.push(new_stmt); + } + Ok(Statement::InitializationBlock { meta, xtype, initializations: new_inits }) + } + Statement::Block { meta, stmts } => { + let mut new_stmts = Vec::new(); + for stmt in stmts { + let new_stmt = remove_tuples_from_statement(stmt)?; + new_stmts.push(new_stmt); + } + Ok(Statement::Block { meta, stmts: new_stmts }) + } + Statement::Substitution { meta, var, op, rhe, access } => { + let new_rhe = remove_tuple_from_expression(rhe)?; + if new_rhe.is_tuple() { + return Err(TupleError::boxed_report( + &meta, + "Left-hand side of the statement is not a tuple.", + )); + } + for access in &access { + if let Access::ArrayAccess(index) = access { + if index.contains_tuple(None) { + return Err(TupleError::boxed_report( + index.meta(), + "A tuple cannot be used to access an array.", + )); + } + } + } + if var != "_" { + Ok(Statement::Substitution { meta, var, access, op, rhe: new_rhe }) + } else { + // Since expressions cannot have side effects, we can ignore this. + Ok(build_block(meta, Vec::new())) + } + } + } +} + +fn separate_tuple_for_log_call(values: Vec) -> Vec { + let mut new_values = Vec::new(); + for value in values { + if let Expression::Tuple { values: values2, .. } = value { + new_values.push(LogArgument::LogStr("(".to_string())); + let mut sep_values = separate_tuple_for_log_call(values2); + new_values.append(&mut sep_values); + new_values.push(LogArgument::LogStr(")".to_string())); + } else { + new_values.push(LogArgument::LogExp(value)); + } + } + new_values +} + +fn remove_tuple_from_expression(expr: Expression) -> Result> { + use Expression::*; + match expr.clone() { + ArrayInLine { meta, values } => { + for value in values { + if value.contains_tuple(None) { + return Err(TupleError::boxed_report( + &meta, + "A tuple cannot be used to define the dimensions of an array.", + )); + } + } + Ok(expr) + } + Number(_, _) => Ok(expr), + Variable { meta, .. } => { + if expr.contains_tuple(None) { + return Err(TupleError::boxed_report( + &meta, + "A tuple cannot be used to access an array.", + )); + } + Ok(expr) + } + InfixOp { meta, lhe, rhe, .. } => { + if lhe.contains_tuple(None) || rhe.contains_tuple(None) { + return Err(TupleError::boxed_report( + &meta, + "Tuples cannot be used in arithmetic or boolean expressions.", + )); + } + Ok(expr) + } + PrefixOp { meta, rhe, .. } => { + if rhe.contains_tuple(None) { + return Err(TupleError::boxed_report( + &meta, + "Tuples cannot be used in arithmetic or boolean expressions.", + )); + } + Ok(expr) + } + InlineSwitchOp { meta, cond, if_true, if_false } => { + if cond.contains_tuple(None) + || if_true.contains_tuple(None) + || if_false.contains_tuple(None) + { + return Err(TupleError::boxed_report( + &meta, + "Tuples cannot be used inside an inline switch expression.", + )); + } + Ok(expr) + } + Call { meta, args, .. } => { + for value in args { + if value.contains_tuple(None) { + return Err(TupleError::boxed_report( + &meta, + "Tuples cannot be used as an argument to a function call.", + )); + } + } + Ok(expr) + } + AnonymousComponent { .. } => { + // This is called after anonymous components have been removed. + unreachable!(); + } + Tuple { meta, values } => { + let mut unfolded_values = Vec::new(); + for value in values { + let new_value = remove_tuple_from_expression(value)?; + if let Tuple { values: mut inner, .. } = new_value { + unfolded_values.append(&mut inner); + } else { + unfolded_values.push(new_value); + } + } + Ok(build_tuple(meta, unfolded_values)) + } + ParallelOp { meta, rhe } => { + if rhe.contains_tuple(None) { + return Err(TupleError::boxed_report( + &meta, + "Tuples cannot be used in parallel operators.", + )); + } + Ok(expr) + } + } +} + +#[cfg(test)] +mod tests { + use crate::parse_definition; + + use super::*; + + #[test] + fn test_desugar_multi_sub() { + let src = [ + r#" + template Anonymous(n) { + signal input a; + signal input b; + signal output c; + signal output d; + signal output e; + + (c, d, e) <== (a + 1, b + 2, c + 3); + } + "#, + r#" + template Test(n) { + signal input a; + signal input b; + signal output c; + signal output d; + + (c, _, d) <== Anonymous(n)(a, b); + } + "#, + ]; + validate_ast(&src, 0); + } + + #[test] + fn test_nested_tuples() { + let src = [r#" + template Test(n) { + signal input a; + signal input b; + signal output c; + signal output d; + signal output e; + + ((c, d), (_)) <== ((a + 1, b + 2), (c + 3)); + } + "#]; + validate_ast(&src, 0); + + // TODO: Invalid, but is currently accepted by the compiler. + let src = [r#" + template Test(n) { + signal input a; + signal input b; + signal output c; + signal output d; + signal output e; + + ((c, d), e) <== (a + 1, (b + 2, c + 3)); + } + "#]; + validate_ast(&src, 0); + + // TODO: Invalid, but is currently accepted by the compiler. + let src = [r#" + template Test(n) { + signal input a; + signal input b; + signal output c; + + (((c))) <== (a + b); + } + "#]; + validate_ast(&src, 0); + } + + #[test] + fn test_invalid_tuples() { + let src = [r#" + template Test(n) { + signal input a; + signal input b; + signal output c; + signal output d; + signal output e; + + ((c, d), e) <== (b + 2, c + 3); + } + "#]; + validate_ast(&src, 1); + } + + fn validate_ast(src: &[&str], errors: usize) { + let mut reports = ReportCollection::new(); + let (templates, file_library) = parse_templates(src); + + // Verify that `remove_syntactic_sugar` is successful. + let (templates, _) = + remove_syntactic_sugar(&templates, &HashMap::new(), &file_library, &mut reports); + assert_eq!(reports.len(), errors); + + // Ensure that no template contains a tuple or an anonymous component. + for template in templates.values() { + assert!(!template.get_body().contains_tuple(None)); + assert!(!template.get_body().contains_anonymous_component(None)); + } + } + + fn parse_templates(src: &[&str]) -> (HashMap, FileLibrary) { + let mut templates = HashMap::new(); + let mut file_library = FileLibrary::new(); + let mut elem_id = 0; + for src in src { + let file_id = file_library.add_file("memory".to_string(), src.to_string()); + let definition = parse_definition(src).unwrap(); + let Definition::Template { + name, + args, + arg_location, + body, + parallel, + is_custom_gate, + .. + } = definition else { + unreachable!(); + }; + let template = TemplateData::new( + name.clone(), + file_id, + body, + args.len(), + args, + arg_location, + &mut elem_id, + parallel, + is_custom_gate, + ); + templates.insert(name, template); + } + (templates, file_library) + } +} diff --git a/parser/src/syntax_sugar_traits.rs b/parser/src/syntax_sugar_traits.rs new file mode 100644 index 0000000..e4f4887 --- /dev/null +++ b/parser/src/syntax_sugar_traits.rs @@ -0,0 +1,202 @@ +use program_structure::ast::*; +use program_structure::report::ReportCollection; + +use crate::errors::TupleError; + +pub(crate) trait ContainsExpression { + /// Returns true if `self` contains `expr` such that `matcher(expr)` + /// evaluates to true. If the callback is not `None` it is invoked on + /// `expr.meta()` for each matching expression. + fn contains_expr( + &self, + matcher: &impl Fn(&Expression) -> bool, + callback: &mut impl FnMut(&Meta), + ) -> bool; + + /// Returns true if the node contains a tuple. If `reports` is not `None`, a + /// report is generated for each occurrence. + fn contains_tuple(&self, reports: Option<&mut ReportCollection>) -> bool { + let matcher = |expr: &Expression| expr.is_tuple(); + if let Some(reports) = reports { + let mut callback = |meta: &Meta| { + let error = TupleError::new( + Some(meta), + "Tuples are not allowed in functions.", + Some("Tuple instantiated here."), + ); + reports.push(error.into_report()); + }; + self.contains_expr(&matcher, &mut callback) + } else { + // We need to pass a dummy callback because rustc isn't smart enough + // to infer the type parameter to `Option` if we use options here. + let mut dummy = |_: &Meta| {}; + self.contains_expr(&matcher, &mut dummy) + } + } + + /// Returns true if the node contains an anonymous component. If `reports` + /// is not `None`, a report is generated for each occurrence. + fn contains_anonymous_component(&self, reports: Option<&mut ReportCollection>) -> bool { + let matcher = |expr: &Expression| expr.is_anonymous_component(); + if let Some(reports) = reports { + let mut callback = |meta: &Meta| { + let error = TupleError::new( + Some(meta), + "Anonymous components are not allowed in functions.", + Some("Anonymous component instantiated here."), + ); + reports.push(error.into_report()); + }; + self.contains_expr(&matcher, &mut callback) + } else { + // We need to pass a dummy callback because rustc isn't smart enough + // to infer the type parameter to `Option` if we use options here. + let mut dummy = |_: &Meta| {}; + self.contains_expr(&matcher, &mut dummy) + } + } +} + +impl ContainsExpression for Expression { + fn contains_expr( + &self, + matcher: &impl Fn(&Expression) -> bool, + callback: &mut impl FnMut(&Meta), + ) -> bool { + use Expression::*; + // Check if the current expression matches and invoke the callback if + // defined. + if matcher(self) { + callback(self.meta()); + return true; + } + let mut result = false; + match &self { + InfixOp { lhe, rhe, .. } => { + result = lhe.contains_expr(matcher, callback) || result; + result = rhe.contains_expr(matcher, callback) || result; + result + } + PrefixOp { rhe, .. } => rhe.contains_expr(matcher, callback), + InlineSwitchOp { cond, if_true, if_false, .. } => { + result = cond.contains_expr(matcher, callback) || result; + result = if_true.contains_expr(matcher, callback) || result; + result = if_false.contains_expr(matcher, callback) || result; + result + } + Call { args, .. } => { + for arg in args { + result = arg.contains_expr(matcher, callback) || result; + } + result + } + ArrayInLine { values, .. } => { + for value in values { + result = value.contains_expr(matcher, callback) || result; + } + result + } + AnonymousComponent { params, signals, .. } => { + for param in params { + result = param.contains_expr(matcher, callback) || result; + } + for signal in signals { + result = signal.contains_expr(matcher, callback) || result; + } + result + } + Variable { access, .. } => { + for access in access { + if let Access::ArrayAccess(index) = access { + result = index.contains_expr(matcher, callback) || result; + } + } + result + } + Number(_, _) => false, + Tuple { values, .. } => { + for value in values { + result = value.contains_expr(matcher, callback) || result; + } + result + } + ParallelOp { rhe, .. } => rhe.contains_expr(matcher, callback), + } + } +} + +impl ContainsExpression for Statement { + fn contains_expr( + &self, + matcher: &impl Fn(&Expression) -> bool, + callback: &mut impl FnMut(&Meta), + ) -> bool { + use LogArgument::*; + use Statement::*; + use Access::*; + let mut result = false; + match self { + IfThenElse { cond, if_case, else_case, .. } => { + result = cond.contains_expr(matcher, callback) || result; + result = if_case.contains_expr(matcher, callback) || result; + if let Some(else_case) = else_case { + result = else_case.contains_expr(matcher, callback) || result; + } + result + } + While { cond, stmt, .. } => { + result = cond.contains_expr(matcher, callback) || result; + result = stmt.contains_expr(matcher, callback) || result; + result + } + Return { value, .. } => value.contains_expr(matcher, callback), + InitializationBlock { initializations, .. } => { + for init in initializations { + result = init.contains_expr(matcher, callback) || result; + } + result + } + Block { stmts, .. } => { + for stmt in stmts { + result = stmt.contains_expr(matcher, callback) || result; + } + result + } + Declaration { dimensions, .. } => { + for size in dimensions { + result = size.contains_expr(matcher, callback) || result; + } + result + } + Substitution { access, rhe, .. } => { + for access in access { + if let ArrayAccess(index) = access { + result = index.contains_expr(matcher, callback) || result; + } + } + result = rhe.contains_expr(matcher, callback) || result; + result + } + MultiSubstitution { lhe, rhe, .. } => { + result = lhe.contains_expr(matcher, callback) || result; + result = rhe.contains_expr(matcher, callback) || result; + result + } + ConstraintEquality { lhe, rhe, .. } => { + result = lhe.contains_expr(matcher, callback) || result; + result = rhe.contains_expr(matcher, callback) || result; + result + } + LogCall { args, .. } => { + for arg in args { + if let LogExp(expr) = arg { + result = expr.contains_expr(matcher, callback) || result; + } + } + result + } + Assert { arg, .. } => arg.contains_expr(matcher, callback), + } + } +} diff --git a/program_analysis/Cargo.toml b/program_analysis/Cargo.toml index 86ae9a6..5819b01 100644 --- a/program_analysis/Cargo.toml +++ b/program_analysis/Cargo.toml @@ -1,7 +1,8 @@ [package] name = "circomspect-program-analysis" -version = "0.7.2" +version = "0.8.0" edition = "2021" +rust-version = "1.65" license = "LGPL-3.0-only" authors = ["Trail of Bits"] description = "Support crate for the Circomspect static analyzer" @@ -13,9 +14,9 @@ log = "0.4" num-bigint-dig = "0.8" num-traits = "0.2" thiserror = "1.0" -parser = { package = "circomspect-parser", version = "2.0.11", path = "../parser" } -program_structure = { package = "circomspect-program-structure", version = "2.0.11", path = "../program_structure" } +parser = { package = "circomspect-parser", version = "2.1.2", path = "../parser" } +program_structure = { package = "circomspect-program-structure", version = "2.1.2", path = "../program_structure" } [dev-dependencies] -parser = { package = "circomspect-parser", version = "2.0.11", path = "../parser" } -program_structure = { package = "circomspect-program-structure", version = "2.0.11", path = "../program_structure" } +parser = { package = "circomspect-parser", version = "2.1.2", path = "../parser" } +program_structure = { package = "circomspect-program-structure", version = "2.1.2", path = "../program_structure" } diff --git a/program_analysis/src/analysis_context.rs b/program_analysis/src/analysis_context.rs index 9d592c6..2525a90 100644 --- a/program_analysis/src/analysis_context.rs +++ b/program_analysis/src/analysis_context.rs @@ -32,8 +32,6 @@ pub enum AnalysisError { /// Context passed to each analysis pass. pub trait AnalysisContext { - type Error; - /// Returns true if the context knows of a function with the given name. /// This method does not compute the CFG of the function which saves time /// compared to `AnalysisContext::function`. @@ -45,15 +43,15 @@ pub trait AnalysisContext { fn is_template(&self, name: &str) -> bool; /// Returns the CFG for the function with the given name. - fn function(&mut self, name: &str) -> Result<&Cfg, Self::Error>; + fn function(&mut self, name: &str) -> Result<&Cfg, AnalysisError>; /// Returns the CFG for the template with the given name. - fn template(&mut self, name: &str) -> Result<&Cfg, Self::Error>; + fn template(&mut self, name: &str) -> Result<&Cfg, AnalysisError>; /// Returns the string corresponding to the given file ID and location. fn underlying_str( &self, file_id: &FileID, file_location: &FileLocation, - ) -> Result; + ) -> Result; } diff --git a/cli/src/analysis_runner.rs b/program_analysis/src/analysis_runner.rs similarity index 81% rename from cli/src/analysis_runner.rs rename to program_analysis/src/analysis_runner.rs index 892acf1..8f9655d 100644 --- a/cli/src/analysis_runner.rs +++ b/program_analysis/src/analysis_runner.rs @@ -1,12 +1,9 @@ -use log::debug; +use log::{debug, trace}; use std::path::PathBuf; use std::collections::HashMap; use parser::ParseResult; -use program_analysis::{ - analysis_context::{AnalysisContext, AnalysisError}, - get_analysis_passes, -}; + use program_structure::{ writers::{LogWriter, ReportWriter}, template_data::TemplateInfo, @@ -17,43 +14,50 @@ use program_structure::{ report::{ReportCollection, Report}, }; -use crate::config; +#[cfg(test)] +use program_structure::template_library::TemplateLibrary; + +use crate::{ + analysis_context::{AnalysisContext, AnalysisError}, + get_analysis_passes, config, +}; type CfgCache = HashMap; type ReportCache = HashMap; +/// A type responsible for caching CFGs and running analysis passes over all +/// functions and templates. +#[derive(Default)] pub struct AnalysisRunner { curve: Curve, file_library: FileLibrary, + /// Template ASTs generated by the parser. template_asts: TemplateInfo, + /// Function ASTs generated by the parser. function_asts: FunctionInfo, + /// Cached template CFGs generated on demand. template_cfgs: CfgCache, + /// Cached function CFGs generated on demand. function_cfgs: CfgCache, + /// Reports created during CFG generation. template_reports: ReportCache, + /// Reports created during CFG generation. function_reports: ReportCache, } impl AnalysisRunner { - pub fn new(curve: &Curve) -> Self { - AnalysisRunner { - curve: curve.clone(), - file_library: FileLibrary::new(), - template_asts: TemplateInfo::new(), - function_asts: FunctionInfo::new(), - template_cfgs: CfgCache::new(), - function_cfgs: CfgCache::new(), - template_reports: ReportCache::new(), - function_reports: ReportCache::new(), - } + pub fn new(curve: Curve) -> Self { + AnalysisRunner { curve, ..Default::default() } } pub fn with_files( - &mut self, + mut self, input_files: &[PathBuf], + follow_includes: bool, writer: &mut (impl LogWriter + ReportWriter), - ) -> &mut Self { + ) -> Self { let (template_asts, function_asts, file_library) = - match parser::parse_files(input_files, &config::COMPILER_VERSION) { + match parser::parse_files(input_files, follow_includes, &config::COMPILER_VERSION) { ParseResult::Program(program, warnings) => { writer.write_reports(&warnings, &program.file_library); (program.templates, program.functions, program.file_library) @@ -70,6 +74,26 @@ impl AnalysisRunner { self } + /// Convenience method used to generate a runner for testing purposes. + #[cfg(test)] + pub fn with_src(mut self, file_contents: &[&str]) -> Self { + use parser::parse_definition; + + let mut library_contents = HashMap::new(); + let mut file_library = FileLibrary::default(); + for (file_index, file_source) in file_contents.iter().enumerate() { + let file_name = format!("file-{file_index}.circom"); + let file_id = file_library.add_file(file_name, file_source.to_string()); + library_contents.insert(file_id, vec![parse_definition(file_source).unwrap()]); + } + let template_library = TemplateLibrary::new(library_contents, file_library.clone()); + self.template_asts = template_library.templates; + self.function_asts = template_library.functions; + self.file_library = template_library.file_library; + + self + } + pub fn file_library(&self) -> &FileLibrary { &self.file_library } @@ -90,7 +114,7 @@ impl AnalysisRunner { // We take ownership of the CFG and any previously generated reports // here to avoid holding multiple mutable and immutable references to // `self`. This may lead to the CFG being regenerated during analysis if - // the template is invoked recursively. + // the template is invoked recursively. If it is then ¯\_(ツ)_/¯. let mut reports = self.take_template_reports(name); if let Ok(cfg) = self.take_template(name) { for analysis_pass in get_analysis_passes() { @@ -116,7 +140,7 @@ impl AnalysisRunner { // We take ownership of the CFG and any previously generated reports // here to avoid holding multiple mutable and immutable references to // `self`. This may lead to the CFG being regenerated during analysis if - // the function is invoked recursively. + // the function is invoked recursively. If it is then ¯\_(ツ)_/¯. let mut reports = self.take_function_reports(name); if let Ok(cfg) = self.take_function(name) { for analysis_pass in get_analysis_passes() { @@ -169,16 +193,19 @@ impl AnalysisRunner { } // Get the AST corresponding to the template. let Some(ast) = self.template_asts.get(name) else { + trace!("failed to lift unknown template `{name}`"); return Err(AnalysisError::UnknownTemplate { name: name.to_string() }) }; // Generate the template CFG from the AST. Cache any reports. let mut reports = ReportCollection::new(); let cfg = generate_cfg(ast, &self.curve, &mut reports).map_err(|report| { reports.push(*report); + trace!("failed to lift template `{name}`"); AnalysisError::FailedToLiftTemplate { name: name.to_string() } })?; self.append_template_reports(name, &mut reports); self.template_cfgs.insert(name.to_string(), cfg); + trace!("successfully lifted template `{name}`"); } Ok(self.template_cfgs.get(name).unwrap()) } @@ -192,44 +219,45 @@ impl AnalysisRunner { } // Get the AST corresponding to the function. let Some(ast) = self.function_asts.get(name) else { + trace!("failed to lift unknown function `{name}`"); return Err(AnalysisError::UnknownFunction { name: name.to_string() }) }; // Generate the function CFG from the AST. Cache any reports. let mut reports = ReportCollection::new(); let cfg = generate_cfg(ast, &self.curve, &mut reports).map_err(|report| { reports.push(*report); + trace!("failed to lift function `{name}`"); AnalysisError::FailedToLiftFunction { name: name.to_string() } })?; self.append_function_reports(name, &mut reports); self.function_cfgs.insert(name.to_string(), cfg); + trace!("successfully lifted function `{name}`"); } Ok(self.function_cfgs.get(name).unwrap()) } - fn take_template(&mut self, name: &str) -> Result { + pub fn take_template(&mut self, name: &str) -> Result { self.cache_template(name)?; // The CFG must be available since caching was successful. Ok(self.template_cfgs.remove(name).unwrap()) } - fn take_function(&mut self, name: &str) -> Result { + pub fn take_function(&mut self, name: &str) -> Result { self.cache_function(name)?; // The CFG must be available since caching was successful. Ok(self.function_cfgs.remove(name).unwrap()) } - fn replace_template(&mut self, name: &str, cfg: Cfg) -> bool { + pub fn replace_template(&mut self, name: &str, cfg: Cfg) -> bool { self.template_cfgs.insert(name.to_string(), cfg).is_some() } - fn replace_function(&mut self, name: &str, cfg: Cfg) -> bool { + pub fn replace_function(&mut self, name: &str, cfg: Cfg) -> bool { self.function_cfgs.insert(name.to_string(), cfg).is_some() } } impl AnalysisContext for AnalysisRunner { - type Error = AnalysisError; - fn is_template(&self, name: &str) -> bool { self.template_asts.get(name).is_some() } @@ -238,11 +266,11 @@ impl AnalysisContext for AnalysisRunner { self.function_asts.get(name).is_some() } - fn template(&mut self, name: &str) -> Result<&Cfg, Self::Error> { + fn template(&mut self, name: &str) -> Result<&Cfg, AnalysisError> { self.cache_template(name) } - fn function(&mut self, name: &str) -> Result<&Cfg, Self::Error> { + fn function(&mut self, name: &str) -> Result<&Cfg, AnalysisError> { self.cache_function(name) } @@ -250,7 +278,7 @@ impl AnalysisContext for AnalysisRunner { &self, file_id: &FileID, file_location: &FileLocation, - ) -> Result { + ) -> Result { let Ok(file) = self.file_library.to_storage().get(*file_id) else { return Err(AnalysisError::UnknownFile { file_id: *file_id }); }; @@ -278,14 +306,13 @@ fn generate_cfg( #[cfg(test)] mod tests { - use parser::parse_definition; - use program_structure::{template_library::TemplateLibrary, intermediate_representation::Statement}; + use program_structure::ir::Statement; use super::*; #[test] fn test_function() { - let mut runner = runner_from_src(&[r#" + let mut runner = AnalysisRunner::new(Curve::Goldilocks).with_src(&[r#" function foo(a) { return a[0] + a[1]; } @@ -314,7 +341,7 @@ mod tests { #[test] fn test_template() { - let mut runner = runner_from_src(&[r#" + let mut runner = AnalysisRunner::new(Curve::Goldilocks).with_src(&[r#" template Foo(n) { signal input a[2]; @@ -346,7 +373,7 @@ mod tests { #[test] fn test_underlying_str() { use Statement::*; - let mut runner = runner_from_src(&[r#" + let mut runner = AnalysisRunner::new(Curve::Goldilocks).with_src(&[r#" template Foo(n) { signal input a[2]; @@ -360,28 +387,11 @@ mod tests { let file_location = stmt.meta().file_location(); let string = runner.underlying_str(&file_id, &file_location).unwrap(); match stmt { + // TODO: Why do some statements include the semi-colon and others don't? Declaration { .. } => assert_eq!(string, "signal input a[2]"), ConstraintEquality { .. } => assert_eq!(string, "a[0] === a[1];"), _ => unreachable!(), } } } - - fn runner_from_src(src: &[&str]) -> AnalysisRunner { - let mut file_content = HashMap::new(); - let mut file_library = FileLibrary::default(); - for (file_index, file_source) in src.iter().enumerate() { - let file_name = format!("{file_index}.circom"); - let file_id = file_library.add_file(file_name, file_source.to_string()); - println!("{file_id}"); - file_content.insert(file_id, vec![parse_definition(file_source).unwrap()]); - } - let template_library = TemplateLibrary::new(file_content, file_library.clone()); - - let mut runner = AnalysisRunner::new(&Curve::Goldilocks); - runner.template_asts = template_library.templates; - runner.function_asts = template_library.functions; - runner.file_library = file_library; - runner - } } diff --git a/program_analysis/src/bn128_specific_circuit.rs b/program_analysis/src/bn128_specific_circuit.rs index 3408e9a..bd1d021 100644 --- a/program_analysis/src/bn128_specific_circuit.rs +++ b/program_analysis/src/bn128_specific_circuit.rs @@ -1,3 +1,5 @@ +use std::collections::HashSet; + use log::debug; use program_structure::cfg::Cfg; @@ -7,19 +9,49 @@ use program_structure::report::{Report, ReportCollection}; use program_structure::report_code::ReportCode; use program_structure::file_definition::{FileLocation, FileID}; -const BN128_SPECIFIC_CIRCUITS: [&str; 12] = [ - "Sign", +const PROBLEMATIC_GOLDILOCK_TEMPLATES: [&str; 26] = [ + "BabyPbk", "AliasCheck", "CompConstant", "Num2Bits_strict", "Bits2Num_strict", + "EdDSAVerifier", + "EdDSAMiMCVerifier", + "EdDSAMiMCSpongeVerifier", + "EdDSAPoseidonVerifier", + "EscalarMulAny", + "MiMC7", + "MultiMiMC7", + "MiMCFeistel", + "MiMCSponge", + "Pedersen", "Bits2Point_Strict", "Point2Bits_Strict", - "SMTVerifier", + "PoseidonEx", + "Poseidon", + "Sign", + "SMTHash1", + "SMTHash2", "SMTProcessor", + "SMTProcessorLevel", + "SMTVerifier", + "SMTVerifierLevel", +]; + +const PROBLEMATIC_BLS12_381_TEMPLATES: [&str; 13] = [ + "AliasCheck", + "CompConstant", + "Num2Bits_strict", + "Bits2Num_strict", "EdDSAVerifier", - "EdDSAPoseidonVerifier", + "EdDSAMiMCVerifier", "EdDSAMiMCSpongeVerifier", + "EdDSAPoseidonVerifier", + "Bits2Point_Strict", + "Point2Bits_Strict", + "SMTVerifier", + "SMTProcessor", + "Sign", ]; pub struct BN128SpecificCircuitWarning { @@ -32,7 +64,7 @@ impl BN128SpecificCircuitWarning { pub fn into_report(self) -> Report { let mut report = Report::warning( format!( - "The `{}` template hard-codes BN128 specific parameters and should not be used with other curves.", + "The `{}` template relies on BN128 specific parameters and should not be used with other curves.", self.template_name ), ReportCode::BN128SpecificCircuit, @@ -48,23 +80,65 @@ impl BN128SpecificCircuitWarning { } } +// This analysis pass identifies Circomlib templates with hard-coded constants +// related to BN128. If these are used together with a different prime, this may +// be an issue. +// +// The following table contains a check for each problematic template-curve pair. +// +// Template Goldilocks (64 bits) BLS12-381 (255 bits) +// ----------------------------------------------------------------- +// AliasCheck x x +// BabyPbk x +// Bits2Num_strict x x +// Num2Bits_strict x x +// CompConstant x x +// EdDSAVerifier x x +// EdDSAMiMCVerifier x x +// EdDSAMiMCSpongeVerifier x x +// EdDSAPoseidonVerifier x x +// EscalarMulAny x +// MiMC7 x +// MultiMiMC7 x +// MiMCFeistel x +// MiMCSponge x +// Pedersen x +// Bits2Point_strict x x +// Point2Bits_strict x x +// PoseidonEx x +// Poseidon x +// Sign x x +// SMTHash1 x +// SMTHash2 x +// SMTProcessor x x +// SMTProcessorLevel x +// SMTVerifier x x +// SMTVerifierLevel x pub fn find_bn128_specific_circuits(cfg: &Cfg) -> ReportCollection { - if cfg.constants().curve() == &Curve::Bn128 { - // Exit early if we're using the default curve. - return ReportCollection::new(); - } + let problematic_templates = match cfg.constants().curve() { + Curve::Goldilocks => HashSet::from(PROBLEMATIC_GOLDILOCK_TEMPLATES), + Curve::Bls12_381 => HashSet::from(PROBLEMATIC_BLS12_381_TEMPLATES), + Curve::Bn128 => { + // Exit early if we're using the default curve. + return ReportCollection::new(); + } + }; debug!("running bn128-specific circuit analysis pass"); let mut reports = ReportCollection::new(); for basic_block in cfg.iter() { for stmt in basic_block.iter() { - visit_statement(stmt, &mut reports); + visit_statement(stmt, &problematic_templates, &mut reports); } } debug!("{} new reports generated", reports.len()); reports } -fn visit_statement(stmt: &Statement, reports: &mut ReportCollection) { +fn visit_statement( + stmt: &Statement, + problematic_templates: &HashSet<&str>, + reports: &mut ReportCollection, +) { use AssignOp::*; use Expression::*; use Statement::*; @@ -78,7 +152,7 @@ fn visit_statement(stmt: &Statement, reports: &mut ReportCollection) { // A component initialization on the form `var = component_name(...)`. if let Call { meta: component_meta, name: component_name, .. } = rhe { - if BN128_SPECIFIC_CIRCUITS.contains(&&component_name[..]) { + if problematic_templates.contains(&&component_name[..]) { reports.push(build_report(component_meta, component_name)); } } @@ -137,7 +211,7 @@ mod tests { let mut reports = ReportCollection::new(); let cfg = parse_definition(src) .unwrap() - .into_cfg(&Curve::Goldilocks, &mut reports) + .into_cfg(&Curve::Bls12_381, &mut reports) .unwrap() .into_ssa() .unwrap(); diff --git a/program_analysis/src/config.rs b/program_analysis/src/config.rs new file mode 100644 index 0000000..9156406 --- /dev/null +++ b/program_analysis/src/config.rs @@ -0,0 +1,5 @@ +use program_structure::ast::Version; + +pub const COMPILER_VERSION: Version = (2, 1, 2); +pub const DEFAULT_LEVEL: &str = "WARNING"; +pub const DEFAULT_CURVE: &str = "BN128"; diff --git a/program_analysis/src/constant_conditional.rs b/program_analysis/src/constant_conditional.rs index 954d753..9336b6f 100644 --- a/program_analysis/src/constant_conditional.rs +++ b/program_analysis/src/constant_conditional.rs @@ -48,8 +48,8 @@ fn visit_statement(stmt: &Statement, reports: &mut ReportCollection) { use Statement::*; use ValueReduction::*; if let IfThenElse { cond, .. } = stmt { - let value = cond.meta().value_knowledge().get_reduces_to(); - if let Some(Boolean { value }) = value { + let value = cond.meta().value_knowledge(); + if let Boolean(Some(value)) = dbg!(value) { reports.push(build_report(cond.meta(), *value)); } } diff --git a/program_analysis/src/constraint_analysis.rs b/program_analysis/src/constraint_analysis.rs index 77743d3..8f81f2f 100644 --- a/program_analysis/src/constraint_analysis.rs +++ b/program_analysis/src/constraint_analysis.rs @@ -159,9 +159,9 @@ mod tests { } "#; let sources = [ - VariableName::from_name("in"), - VariableName::from_name("out"), - VariableName::from_name("tmp"), + VariableName::from_string("in"), + VariableName::from_string("out"), + VariableName::from_string("tmp"), ]; let sinks = [2, 1, 1]; validate_constraints(src, &sources, &sinks); @@ -178,9 +178,9 @@ mod tests { } "#; let sources = [ - VariableName::from_name("in"), - VariableName::from_name("out"), - VariableName::from_name("tmp"), + VariableName::from_string("in"), + VariableName::from_string("out"), + VariableName::from_string("tmp"), ]; let sinks = [2, 1, 1]; validate_constraints(src, &sources, &sinks); diff --git a/program_analysis/src/definition_complexity.rs b/program_analysis/src/definition_complexity.rs index 387981b..6f2ec89 100644 --- a/program_analysis/src/definition_complexity.rs +++ b/program_analysis/src/definition_complexity.rs @@ -88,3 +88,40 @@ pub fn run_complexity_analysis(cfg: &Cfg) -> ReportCollection { } reports } + +#[cfg(test)] +mod tests { + use parser::parse_definition; + use program_structure::{report::ReportCollection, constants::Curve, cfg::IntoCfg}; + + use crate::definition_complexity::run_complexity_analysis; + + #[test] + fn test_small_template() { + let src = r#" + template Example () { + signal input a; + signal output b; + a <== b; + } + "#; + validate_reports(src, 0); + } + + fn validate_reports(src: &str, expected_len: usize) { + // Build CFG. + let mut reports = ReportCollection::new(); + let cfg = parse_definition(src) + .unwrap() + .into_cfg(&Curve::default(), &mut reports) + .unwrap() + .into_ssa() + .unwrap(); + assert!(reports.is_empty()); + + // Generate report collection. + let reports = run_complexity_analysis(&cfg); + + assert_eq!(reports.len(), expected_len); + } +} diff --git a/program_analysis/src/lib.rs b/program_analysis/src/lib.rs index 020be2f..19bb005 100644 --- a/program_analysis/src/lib.rs +++ b/program_analysis/src/lib.rs @@ -1,4 +1,5 @@ -use analysis_context::{AnalysisContext, AnalysisError}; +use analysis_context::AnalysisContext; + use program_structure::cfg::Cfg; use program_structure::report::ReportCollection; @@ -7,8 +8,10 @@ extern crate num_bigint_dig as num_bigint; pub mod constraint_analysis; pub mod taint_analysis; pub mod analysis_context; +pub mod analysis_runner; +pub mod config; -// Analysis passes. +// Intra-process analysis passes. mod bitwise_complement; mod bn128_specific_circuit; mod constant_conditional; @@ -16,18 +19,21 @@ mod definition_complexity; mod field_arithmetic; mod field_comparisons; mod nonstrict_binary_conversion; +mod non_boolean_condition; mod under_constrained_signals; mod unconstrained_less_than; mod unconstrained_division; mod side_effect_analysis; mod signal_assignments; +// Inter-process analysis passes. +mod unused_output_signal; + /// An analysis pass is a function which takes an analysis context and a CFG and /// returns a set of reports. -type AnalysisPass<'a> = - dyn Fn(&mut dyn AnalysisContext, &'a Cfg) -> ReportCollection + 'a; +type AnalysisPass = dyn Fn(&mut dyn AnalysisContext, &Cfg) -> ReportCollection; -pub fn get_analysis_passes<'a>() -> Vec>> { +pub fn get_analysis_passes() -> Vec> { vec![ // Intra-process analysis passes. Box::new(|_, cfg| bitwise_complement::find_bitwise_complement(cfg)), @@ -42,5 +48,8 @@ pub fn get_analysis_passes<'a>() -> Vec>> { Box::new(|_, cfg| constant_conditional::find_constant_conditional_statement(cfg)), Box::new(|_, cfg| under_constrained_signals::find_under_constrained_signals(cfg)), Box::new(|_, cfg| nonstrict_binary_conversion::find_nonstrict_binary_conversion(cfg)), + Box::new(|_, cfg| non_boolean_condition::find_non_boolean_conditional(cfg)), + // Inter-process analysis passes. + Box::new(unused_output_signal::find_unused_output_signals), ] } diff --git a/program_analysis/src/non_boolean_condition.rs b/program_analysis/src/non_boolean_condition.rs new file mode 100644 index 0000000..a516cc5 --- /dev/null +++ b/program_analysis/src/non_boolean_condition.rs @@ -0,0 +1,222 @@ +#![deny(warnings)] +use log::debug; + +use program_structure::cfg::Cfg; +use program_structure::report_code::ReportCode; +use program_structure::report::{Report, ReportCollection}; +use program_structure::file_definition::{FileID, FileLocation}; +use program_structure::ir::value_meta::ValueReduction; +use program_structure::ir::*; + +pub struct NonBooleanConditionWarning { + value: ValueReduction, + file_id: Option, + file_location: FileLocation, +} + +impl NonBooleanConditionWarning { + pub fn into_report(self) -> Report { + let mut report = Report::warning( + "Value used in boolean position may not be boolean.".to_string(), + ReportCode::ConstantBranchCondition, + ); + if let Some(file_id) = self.file_id { + let msg = match self.value { + ValueReduction::FieldElement(v) => format!( + "This value is a field element{}.", + if let Some(v) = v { format!(" equal to {v}") } else { "".to_string() } + ), + + _ => "This value may or may not be a boolean".to_string(), + }; + + report.add_primary(self.file_location, file_id, msg); + } + report + } +} + +/// This analysis pass uses constant propagation to determine cases where +/// the expression in a condition may not be a Boolean. +pub fn find_non_boolean_conditional(cfg: &Cfg) -> ReportCollection { + debug!("running non-boolean conditional analysis pass"); + let mut reports = ReportCollection::new(); + for basic_block in cfg.iter() { + for stmt in basic_block.iter() { + visit_statement(stmt, &mut reports); + } + } + debug!("{} new reports generated", reports.len()); + reports +} + +fn expect_boolean(e: &Expression, reports: &mut ReportCollection) { + let value = e.meta().value_knowledge(); + if !matches!(value, ValueReduction::Boolean(_)) { + reports.push(build_report(e.meta(), value.clone())); + } +} + +fn visit_statement(stmt: &Statement, reports: &mut ReportCollection) { + use Statement::*; + match stmt { + IfThenElse { cond, .. } => { + visit_expression(cond, reports); + expect_boolean(cond, reports); + } + + Declaration { dimensions, .. } => { + for d in dimensions { + visit_expression(d, reports); + } + } + + Return { value, .. } => visit_expression(value, reports), + + Substitution { rhe, .. } => visit_expression(rhe, reports), + ConstraintEquality { lhe, rhe, .. } => { + visit_expression(lhe, reports); + visit_expression(rhe, reports); + } + LogCall { args, .. } => { + for arg in args { + if let LogArgument::Expr(e) = arg { + visit_expression(e, reports); + } + } + } + + Assert { arg, .. } => visit_expression(arg, reports), + } +} + +fn visit_expression(e: &Expression, reports: &mut ReportCollection) { + use Expression::*; + match e { + InfixOp { meta: _, lhe, infix_op, rhe } => { + visit_expression(lhe, reports); + visit_expression(rhe, reports); + + use ExpressionInfixOpcode::*; + match infix_op { + BoolOr | BoolAnd => { + expect_boolean(lhe, reports); + expect_boolean(rhe, reports); + } + _ => {} + } + } + + PrefixOp { meta: _, prefix_op, rhe } => { + visit_expression(rhe, reports); + if let ExpressionPrefixOpcode::BoolNot = prefix_op { + expect_boolean(rhe, reports); + } + } + + SwitchOp { meta: _, cond, if_true, if_false } => { + visit_expression(if_true, reports); + visit_expression(if_false, reports); + visit_expression(cond, reports); + expect_boolean(cond, reports); + } + + Call { args, .. } => { + for a in args { + visit_expression(a, reports); + } + } + + InlineArray { values, .. } => { + for v in values { + visit_expression(v, reports); + } + } + + Access { access, .. } => { + for a in access { + if let AccessType::ArrayAccess(e) = a { + visit_expression(e, reports); + } + } + } + + Update { access, rhe, .. } => { + for a in access { + if let AccessType::ArrayAccess(e) = a { + visit_expression(e, reports); + } + } + visit_expression(rhe, reports); + } + + Phi { .. } => {} + Variable { .. } => {} + Number { .. } => {} + } +} + +fn build_report(meta: &Meta, value: ValueReduction) -> Report { + NonBooleanConditionWarning { + value, + file_id: meta.file_id(), + file_location: meta.file_location(), + } + .into_report() +} + +#[cfg(test)] +mod tests { + use parser::parse_definition; + use program_structure::{cfg::IntoCfg, constants::Curve}; + + use super::*; + + #[test] + fn test_non_boolean_conditional() { + let src = r#" + function f(x) { + var a = 1; + var b = (2 * a * a + 1) << 2; + var c = (3 * b / b - 2) >> 1; + if (c >> 4 || x) { + a += x; + b += x * a; + } + return a + b; + } + "#; + validate_reports(src, 2); + + let src = r#" + function f(x) { + var a = 1; + var b = (2 * a * a + 1) << 2; + var c = (3 * b / x - 2) >> 1; + if (c > 4) { + a += x; + b += x * a; + } + return a + b; + } + "#; + validate_reports(src, 0); + } + + fn validate_reports(src: &str, expected_len: usize) { + // Build CFG. + let mut reports = ReportCollection::new(); + let cfg = parse_definition(src) + .unwrap() + .into_cfg(&Curve::default(), &mut reports) + .unwrap() + .into_ssa() + .unwrap(); + assert!(reports.is_empty()); + + // Generate report collection. + let reports = find_non_boolean_conditional(&cfg); + + assert_eq!(reports.len(), expected_len); + } +} diff --git a/program_analysis/src/nonstrict_binary_conversion.rs b/program_analysis/src/nonstrict_binary_conversion.rs index 44ed8e4..7578943 100644 --- a/program_analysis/src/nonstrict_binary_conversion.rs +++ b/program_analysis/src/nonstrict_binary_conversion.rs @@ -103,8 +103,8 @@ fn visit_statement(stmt: &Statement, prime_size: &BigInt, reports: &mut ReportCo let arg = &args[0]; // If the input size is known to be less than the prime size, this // initialization is safe. - if let Some(FieldElement { value }) = arg.value() { - if value < prime_size { + if let FieldElement(Some(value)) = arg.value() { + if &value < prime_size { return; } } @@ -115,8 +115,8 @@ fn visit_statement(stmt: &Statement, prime_size: &BigInt, reports: &mut ReportCo let arg = &args[0]; // If the input size is known to be less than the prime size, this // initialization is safe. - if let Some(FieldElement { value }) = arg.value() { - if value < prime_size { + if let FieldElement(Some(value)) = arg.value() { + if &value < prime_size { return; } } diff --git a/program_analysis/src/side_effect_analysis.rs b/program_analysis/src/side_effect_analysis.rs index 932893e..8447e15 100644 --- a/program_analysis/src/side_effect_analysis.rs +++ b/program_analysis/src/side_effect_analysis.rs @@ -255,7 +255,7 @@ pub fn run_side_effect_analysis(cfg: &Cfg) -> ReportCollection { .declarations() .iter() .filter_map(|(name, declaration)| { - if matches!(declaration.variable_type(), VariableType::Signal(_)) { + if matches!(declaration.variable_type(), VariableType::Signal(_, _)) { Some((name, declaration)) } else { None @@ -267,7 +267,7 @@ pub fn run_side_effect_analysis(cfg: &Cfg) -> ReportCollection { .filter_map(|(name, declaration)| { if matches!( declaration.variable_type(), - VariableType::Signal(SignalType::Input | SignalType::Output) + VariableType::Signal(SignalType::Input | SignalType::Output, _) ) { Some(*name) } else { @@ -324,9 +324,15 @@ pub fn run_side_effect_analysis(cfg: &Cfg) -> ReportCollection { let mut reported_vars = HashSet::new(); // Generate a report for any variable that does not taint a sink. + // // TODO: The call to TaintAnalysis::taints_any chokes on CFGs containing // large (65536 element) arrays. for source in taint_analysis.definitions() { + // Circom 2.1.2 introduces `_` for ignored variables in tuple + // assignments. We respect this convention here as well. + if source.to_string() == "_" { + continue; + } if !variables_read.contains(source.name()) { // If the variable is unread, the corresponding value is unused. if cfg.parameters().contains(source.name()) { @@ -346,9 +352,15 @@ pub fn run_side_effect_analysis(cfg: &Cfg) -> ReportCollection { } } // Generate reports for unused or unconstrained signals. + // // TODO: The call to TaintAnalysis::taints_any chokes on CFGs containing // large (65536 element) arrays. for (source, declaration) in signal_decls { + // Circom 2.1.2 introduces `_` for ignored variables in tuple + // assignments. We respect this convention here as well. + if source.to_string() == "_" { + continue; + } // Don't generate multiple reports for the same variable. if reported_vars.contains(&source.to_string()) { continue; diff --git a/program_analysis/src/taint_analysis.rs b/program_analysis/src/taint_analysis.rs index 97d9c98..f523763 100644 --- a/program_analysis/src/taint_analysis.rs +++ b/program_analysis/src/taint_analysis.rs @@ -138,7 +138,7 @@ pub fn run_taint_analysis(cfg: &Cfg) -> TaintAnalysis { IfThenElse { cond, .. } => { // A variable which occurs in a non-constant condition taints all // variables assigned in the if-statement body. - if cond.value().is_some() { + if cond.value().is_constant() { continue; } let true_branch = cfg.get_true_branch(basic_block); @@ -243,7 +243,7 @@ mod tests { let taint_analysis = run_taint_analysis(&cfg); for (source, expected_sinks) in taint_map { - let source = VariableName::from_name(source).with_version(0); + let source = VariableName::from_string(source).with_version(0); let sinks = taint_analysis .multi_step_taint(&source) .iter() diff --git a/program_analysis/src/unconstrained_division.rs b/program_analysis/src/unconstrained_division.rs index 9280a9f..57fa9ff 100644 --- a/program_analysis/src/unconstrained_division.rs +++ b/program_analysis/src/unconstrained_division.rs @@ -73,11 +73,11 @@ impl Component { fn output(&self) -> Option { use ValueReduction::*; - let value = self.output.as_ref().and_then(|output| output.value()); + let value = self.output.as_ref().map(|output| output.value()); match value { - Some(FieldElement { value }) => Some(!value.is_zero()), - Some(Boolean { value }) => Some(*value), - None => None, + Some(FieldElement(Some(value))) => Some(!value.is_zero()), + Some(Boolean(Some(value))) => Some(value), + _ => None, } } } diff --git a/program_analysis/src/unconstrained_less_than.rs b/program_analysis/src/unconstrained_less_than.rs index f5edc08..4401a4b 100644 --- a/program_analysis/src/unconstrained_less_than.rs +++ b/program_analysis/src/unconstrained_less_than.rs @@ -3,39 +3,41 @@ use std::fmt; use log::{debug, trace}; +use num_bigint::BigInt; use program_structure::cfg::Cfg; +use program_structure::ir::value_meta::{ValueMeta, ValueReduction}; use program_structure::report_code::ReportCode; use program_structure::report::{Report, ReportCollection}; -use program_structure::file_definition::{FileID, FileLocation}; use program_structure::ir::*; pub struct UnconstrainedLessThanWarning { - input_size: Expression, - file_id: Option, - primary_location: FileLocation, - secondary_location: FileLocation, + value: Expression, + bit_sizes: Vec<(Meta, Expression)>, } - impl UnconstrainedLessThanWarning { + fn primary_meta(&self) -> &Meta { + self.value.meta() + } + pub fn into_report(self) -> Report { let mut report = Report::warning( - "Inputs to `LessThan` should typically be constrained to the input size".to_string(), + "Inputs to `LessThan` need to be constrained to ensure that they are non-negative" + .to_string(), ReportCode::UnconstrainedLessThan, ); - if let Some(file_id) = self.file_id { + if let Some(file_id) = self.primary_meta().file_id { report.add_primary( - self.primary_location, + self.primary_meta().file_location(), file_id, - format!( - "This input to `LessThan` should be constrained to `{}` bits.", - self.input_size - ), - ); - report.add_secondary( - self.secondary_location, - file_id, - Some("Circomlib template `LessThan` instantiated here.".to_string()), + format!("`{}` needs to be constrained to ensure that it is <= p/2.", self.value), ); + for (meta, size) in self.bit_sizes { + report.add_secondary( + meta.file_location(), + file_id, + Some(format!("`{}` is constrained to `{}` bits here.", self.value, size)), + ); + } } report } @@ -54,109 +56,64 @@ impl VariableAccess { } } -/// Tracks component instantiations `var = T(...)` where `T` is either `LessThan` -/// or `Num2Bits`. +/// Tracks component instantiations `var = T(...)` where then template `T` is +/// either `LessThan` or `Num2Bits`. enum Component { - LessThan { meta: Box, required_size: Box }, - Num2Bits { enforced_size: Box }, + LessThan, + Num2Bits { bit_size: Box }, } impl Component { - fn less_than(meta: &Meta, required_size: &Expression) -> Self { - Self::LessThan { - meta: Box::new(meta.clone()), - required_size: Box::new(required_size.clone()), - } + fn less_than() -> Self { + Self::LessThan } - fn num_2_bits(enforced_size: &Expression) -> Self { - Self::Num2Bits { enforced_size: Box::new(enforced_size.clone()) } + fn num_2_bits(bit_size: &Expression) -> Self { + Self::Num2Bits { bit_size: Box::new(bit_size.clone()) } } } /// Tracks component input signal initializations on the form `T.in <== input` /// where `T` is either `LessThan` or `Num2Bits`. enum ComponentInput { - LessThan { - component_meta: Box, - input_meta: Box, - value: Box, - required_size: Box, - }, - Num2Bits { - value: Box, - enforced_size: Box, - }, + LessThan { value: Box }, + Num2Bits { value: Box, bit_size: Box }, } impl ComponentInput { - fn less_than( - component_meta: &Meta, - input_meta: &Meta, - value: &Expression, - required_size: &Expression, - ) -> Self { - Self::LessThan { - component_meta: Box::new(component_meta.clone()), - input_meta: Box::new(input_meta.clone()), - value: Box::new(value.clone()), - required_size: Box::new(required_size.clone()), - } - } - - fn num_2_bits(value: &Expression, enforced_size: &Expression) -> Self { - Self::Num2Bits { - value: Box::new(value.clone()), - enforced_size: Box::new(enforced_size.clone()), - } - } -} - -// The signal input at `signal_meta` for the component defined at -// `component_meta` must be at most `size` bits. -struct SizeEntry { - pub component_meta: Meta, - pub input_meta: Meta, - pub required_size: Expression, -} - -impl SizeEntry { - pub fn new(component_meta: &Meta, input_meta: &Meta, required_size: &Expression) -> Self { - SizeEntry { - component_meta: component_meta.clone(), - input_meta: input_meta.clone(), - required_size: required_size.clone(), - } + fn less_than(value: &Expression) -> Self { + Self::LessThan { value: Box::new(value.clone()) } } -} -impl fmt::Debug for SizeEntry { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{:?}", self.required_size) + fn num_2_bits(value: &Expression, bit_size: &Expression) -> Self { + Self::Num2Bits { value: Box::new(value.clone()), bit_size: Box::new(bit_size.clone()) } } } -/// Size constraints for a single component input. -#[derive(Debug, Default)] -struct SizeConstraints { - /// Size constraint required by `LessThan`. - pub required: Vec, - /// Size constraint enforced by `Num2Bits`. - pub enforced: Vec, +/// Tracks constraints for a single input to `LessThan`. +#[derive(Default)] +struct ConstraintData { + /// Input to `LessThan`. + pub less_than: Vec, + /// Input to `Num2Bits`. + pub num_2_bits: Vec, + /// Size constraints enforced by `Num2Bits`. + pub bit_sizes: Vec, } /// The `LessThan` template from Circomlib does not constrain the individual -/// inputs to the input size `n`. If the input size can be more than `n` bits, -/// it is possible to find field elements `a` and `b` such that +/// inputs to the input size `n` bits, or to be positive. If the inputs are +/// allowed to be greater than p/2 it is possible to find field elements `a` and +/// `b` such that /// -/// 1. `a > b`, +/// 1. `a > b` either as unsigned integers, or as signed elements in GF(p), /// 2. lt = LessThan(n), /// 3. lt.in[0] = a, /// 4. lt.in[1] = b, and /// 5. lt.out = 1 /// /// This analysis pass looks for instantiations of `LessThan` where the inputs -/// are not constrained to `n` bits using `Num2Bits`. +/// are not constrained to be <= p/2 using `Num2Bits`. pub fn find_unconstrained_less_than(cfg: &Cfg) -> ReportCollection { debug!("running unconstrained less-than analysis pass"); let mut components = HashMap::new(); @@ -171,35 +128,45 @@ pub fn find_unconstrained_less_than(cfg: &Cfg) -> ReportCollection { update_inputs(stmt, &components, &mut inputs); } } - let mut constraints = HashMap::::new(); + let mut constraints = HashMap::::new(); for input in inputs { match input { - ComponentInput::LessThan { component_meta, input_meta, value, required_size } => { - constraints.entry(*value.clone()).or_default().required.push(SizeEntry::new( - &component_meta, - &input_meta, - &required_size, - )); + ComponentInput::LessThan { value } => { + let entry = constraints.entry(*value.clone()).or_default(); + entry.less_than.push(value.meta().clone()); } - ComponentInput::Num2Bits { value, enforced_size, .. } => { - constraints.entry(*value.clone()).or_default().enforced.push(*enforced_size); + ComponentInput::Num2Bits { value, bit_size, .. } => { + let entry = constraints.entry(*value.clone()).or_default(); + entry.num_2_bits.push(value.meta().clone()); + entry.bit_sizes.push(*bit_size.clone()); } } } // Generate a report for each input to `LessThan` where the input size is - // not constrained to the `LessThan` bit size using `Num2Bits`. + // not constrained to be positive using `Num2Bits`. let mut reports = ReportCollection::new(); - for sizes in constraints.values() { - for required in &sizes.required { - if !sizes.enforced.contains(&required.required_size) { - reports.push(build_report( - &required.component_meta, - &required.input_meta, - &required.required_size, - )) + let max_value = BigInt::from(cfg.constants().prime_size() - 1); + for (value, data) in constraints { + // Check if the the value is used as input for `LessThan`. + if data.less_than.is_empty() { + continue; + } + // Check if the value is constrained to be positive. + let mut is_positive = false; + for bit_size in &data.bit_sizes { + if let ValueReduction::FieldElement(Some(ref value)) = bit_size.value() { + if value < &max_value { + is_positive = true; + break; + } } } + if is_positive { + continue; + } + // We failed to prove that the input is positive. Generate a report. + reports.push(build_report(&value, &data)); } debug!("{} new reports generated", reports.len()); reports @@ -228,7 +195,7 @@ fn update_components(stmt: &Statement, components: &mut HashMap Report { +fn build_report(value: &Expression, data: &ConstraintData) -> Report { UnconstrainedLessThanWarning { - input_size: size.clone(), - file_id: component_meta.file_id, - primary_location: input_meta.file_location(), - secondary_location: component_meta.file_location(), + value: value.clone(), + bit_sizes: data.num_2_bits.iter().cloned().zip(data.bit_sizes.iter().cloned()).collect(), } .into_report() } @@ -355,7 +318,7 @@ mod tests { ok <== lt.out; } "#; - validate_reports(src, 1); + validate_reports(src, 2); let src = r#" template Test(n) { @@ -367,7 +330,7 @@ mod tests { component n2b[2]; n2b[0] = Num2Bits(n); n2b[0].in <== small; - n2b[1] = Num2Bits(n); + n2b[1] = Num2Bits(32); n2b[1].in <== large; // Check that small < large. @@ -378,7 +341,7 @@ mod tests { ok <== lt.out; } "#; - validate_reports(src, 0); + validate_reports(src, 1); let src = r#" template Test(n) { @@ -393,9 +356,9 @@ mod tests { // Constrain inputs to n bits. component n2b[2]; - n2b[0] = Num2Bits(n); + n2b[0] = Num2Bits(32); n2b[0].in <== small; - n2b[1] = Num2Bits(n); + n2b[1] = Num2Bits(64); n2b[1].in <== large; ok <== lt.out; diff --git a/program_analysis/src/under_constrained_signals.rs b/program_analysis/src/under_constrained_signals.rs index 73bb8eb..d68b225 100644 --- a/program_analysis/src/under_constrained_signals.rs +++ b/program_analysis/src/under_constrained_signals.rs @@ -97,7 +97,8 @@ pub fn find_under_constrained_signals(cfg: &Cfg) -> ReportCollection { let mut constraint_locations = cfg .variables() .filter_map(|name| { - if matches!(cfg.get_type(name), Some(VariableType::Signal(SignalType::Intermediate))) { + if matches!(cfg.get_type(name), Some(VariableType::Signal(SignalType::Intermediate, _))) + { Some((name.clone(), Vec::new())) } else { None diff --git a/program_analysis/src/unused_output_signal.rs b/program_analysis/src/unused_output_signal.rs new file mode 100644 index 0000000..e16d0e7 --- /dev/null +++ b/program_analysis/src/unused_output_signal.rs @@ -0,0 +1,627 @@ +use log::debug; +use std::collections::HashSet; + +use program_structure::{ + ir::*, + ir::value_meta::ValueMeta, + report_code::ReportCode, + cfg::{Cfg, DefinitionType}, + report::{Report, ReportCollection}, + file_definition::{FileID, FileLocation}, +}; + +use crate::analysis_context::AnalysisContext; + +// Known templates that are commonly instantiated without accessing the +// corresponding output signals. +const ALLOW_LIST: [&str; 1] = ["Num2Bits"]; + +struct UnusedOutputSignalWarning { + // Location of template instantiation. + file_id: Option, + file_location: FileLocation, + // The currently analyzed template. + current_template: String, + // The instantiated template with an unused output signal. + component_template: String, + // The name of the unused output signal. + signal_name: String, +} + +impl UnusedOutputSignalWarning { + pub fn into_report(self) -> Report { + let mut report = Report::warning( + format!( + "The output signal `{}` defined by the template `{}` is not constrained in `{}`.", + self.signal_name, self.component_template, self.current_template + ), + ReportCode::UnusedOutputSignal, + ); + if let Some(file_id) = self.file_id { + report.add_primary( + self.file_location, + file_id, + format!("The template `{}` is instantiated here.", self.component_template), + ); + } + report + } +} + +#[derive(Clone, Debug)] +struct VariableAccess { + pub var: VariableName, + pub access: Vec, +} + +impl VariableAccess { + fn new(var: &VariableName, access: &[AccessType]) -> Self { + // We disregard the version to make sure accesses are not order dependent. + VariableAccess { var: var.without_version(), access: access.to_vec() } + } +} + +/// A reflexive and symmetric relation capturing partial information about +/// equality. +trait MaybeEqual { + fn maybe_equal(&self, other: &Self) -> bool; +} + +/// This is a reflexive and symmetric (but not transitive!) relation +/// identifying all array accesses where the indices are not explicitly known +/// to be different (e.g. from constant propagation). The relation is not +/// transitive since `v[0] == v[i]` and `v[i] == v[1]`, but `v[0] != v[1]`. +/// +/// Since `maybe_equal` is not transitive we cannot use it to define +/// `PartialEq` for `VariableAccess`. This also means that we cannot use hash +/// sets or hash maps to track variable accesses using this as our equality +/// relation. +impl MaybeEqual for VariableAccess { + fn maybe_equal(&self, other: &VariableAccess) -> bool { + use AccessType::*; + if self.var.name() != other.var.name() { + return false; + } + if self.access.len() != other.access.len() { + return false; + } + for (self_access, other_access) in self.access.iter().zip(other.access.iter()) { + match (self_access, other_access) { + (ArrayAccess(_), ComponentAccess(_)) => { + return false; + } + (ComponentAccess(_), ArrayAccess(_)) => { + return false; + } + (ComponentAccess(self_name), ComponentAccess(other_name)) + if self_name != other_name => + { + return false; + } + (ArrayAccess(self_index), ArrayAccess(other_index)) => { + use value_meta::ValueReduction::*; + match (self_index.value(), other_index.value()) { + (FieldElement(Some(self_value)), FieldElement(Some(other_value))) + if self_value != other_value => + { + return false; + } + (Boolean(Some(self_value)), Boolean(Some(other_value))) + if self_value != other_value => + { + return false; + } + // Identify all other array accesses. + _ => {} + } + } + // Identify all array accesses. + _ => {} + } + } + true + } +} + +/// A relation capturing partial information about containment. +trait MaybeContains { + fn maybe_contains(&self, element: &T) -> bool; +} + +impl MaybeContains for Vec +where + T: MaybeEqual, +{ + fn maybe_contains(&self, element: &T) -> bool { + self.iter().any(|item| item.maybe_equal(element)) + } +} + +struct ComponentData { + pub meta: Meta, + pub var_name: VariableName, + pub var_access: Vec, + pub template_name: String, +} + +impl ComponentData { + pub fn new( + meta: &Meta, + var_name: &VariableName, + var_access: &[AccessType], + template_name: &str, + ) -> Self { + ComponentData { + meta: meta.clone(), + var_name: var_name.clone(), + var_access: var_access.to_vec(), + template_name: template_name.to_string(), + } + } +} + +struct SignalData { + pub meta: Meta, + pub template_name: String, + pub signal_name: String, + pub signal_access: VariableAccess, +} + +impl SignalData { + pub fn new( + meta: &Meta, + template_name: &str, + signal_name: &str, + signal_access: VariableAccess, + ) -> SignalData { + SignalData { + meta: meta.clone(), + template_name: template_name.to_string(), + signal_name: signal_name.to_string(), + signal_access, + } + } +} + +pub fn find_unused_output_signals( + context: &mut dyn AnalysisContext, + current_cfg: &Cfg, +) -> ReportCollection { + // Exit early if the given CFG represents a function. + if matches!(current_cfg.definition_type(), DefinitionType::Function) { + return ReportCollection::new(); + } + debug!("running unused output signal analysis pass"); + let allow_list = HashSet::from(ALLOW_LIST); + + // Collect all instantiated components. + let mut components = Vec::new(); + let mut accesses = Vec::new(); + for basic_block in current_cfg.iter() { + for stmt in basic_block.iter() { + visit_statement(stmt, current_cfg, &mut components, &mut accesses); + } + } + let mut output_signals = Vec::new(); + for component in components { + // Ignore templates on the allow list. + if allow_list.contains(&component.template_name[..]) { + continue; + } + if let Ok(component_cfg) = context.template(&component.template_name) { + for output_signal in component_cfg.output_signals() { + if let Some(declaration) = component_cfg.get_declaration(output_signal) { + // The signal access pattern is given by the component + // access pattern, followed by the output signal name, + // followed by an array access corresponding to each + // dimension entry for the signal. + // + // E.g., for the component `c[i]` with an output signal + // `out` which is a double array, we get `c[i].out[j][k]`. + // Since we identify array accesses we simply use `i` for + // each array access corresponding to the dimensions of the + // signal. + let mut var_access = component.var_access.clone(); + var_access.push(AccessType::ComponentAccess(output_signal.name().to_string())); + for _ in declaration.dimensions() { + let meta = Meta::new(&(0..0), &None); + let index = + Expression::Variable { meta, name: VariableName::from_string("i") }; + var_access.push(AccessType::ArrayAccess(Box::new(index))); + } + let signal_access = VariableAccess::new(&component.var_name, &var_access); + output_signals.push(SignalData::new( + &component.meta, + &component.template_name, + output_signal.name(), + signal_access, + )); + } + } + } + } + let mut reports = ReportCollection::new(); + for output_signal in output_signals { + if !maybe_accesses(&accesses, &output_signal.signal_access) { + reports.push(build_report( + &output_signal.meta, + current_cfg.name(), + &output_signal.template_name, + &output_signal.signal_name, + )) + } + } + + debug!("{} new reports generated", reports.len()); + reports +} + +// Check if there is an access to a prefix of the output signal access which +// contains the output signal name. E.g. for the output signal `n2b[1].out[0]` +// it is enough that the list of all variable accesses `maybe_contains` the +// prefix `n2b[1].out`. This is to catch instances where the template passes the +// output signal as input to a function. +fn maybe_accesses(accesses: &Vec, signal_access: &VariableAccess) -> bool { + use AccessType::*; + let mut signal_access = signal_access.clone(); + while !accesses.maybe_contains(&signal_access) { + if let Some(ComponentAccess(_)) = signal_access.access.last() { + // The output signal name is the last component access in the access + // array. If it is not included in the access, the output signal is + // not accessed by the template. + return false; + } else { + signal_access.access.pop(); + } + } + true +} + +fn visit_statement( + stmt: &Statement, + cfg: &Cfg, + components: &mut Vec, + accesses: &mut Vec, +) { + use Statement::*; + use Expression::*; + use VariableType::*; + // Collect all instantiated components. + if let Substitution { var: var_name, rhe, .. } = stmt { + let (var_access, rhe) = if let Update { access, rhe, .. } = rhe { + (access.clone(), *rhe.clone()) + } else { + (Vec::new(), rhe.clone()) + }; + if let (Some(Component), Call { meta, name: template_name, .. }) = + (cfg.get_type(var_name), rhe) + { + components.push(ComponentData::new(&meta, var_name, &var_access, &template_name)); + } + } + // Collect all variable accesses. + match stmt { + Substitution { rhe, .. } => visit_expression(rhe, accesses), + ConstraintEquality { lhe, rhe, .. } => { + visit_expression(lhe, accesses); + visit_expression(rhe, accesses); + } + Declaration { .. } => { /* We ignore dimensions in declarations. */ } + IfThenElse { .. } => { /* We ignore if-statement conditions. */ } + Return { .. } => { /* We ignore return statements. */ } + LogCall { .. } => { /* We ignore log statements. */ } + Assert { .. } => { /* We ignore asserts. */ } + } +} + +fn visit_expression(expr: &Expression, accesses: &mut Vec) { + use Expression::*; + match expr { + PrefixOp { rhe, .. } => { + visit_expression(rhe, accesses); + } + InfixOp { lhe, rhe, .. } => { + visit_expression(lhe, accesses); + visit_expression(rhe, accesses); + } + SwitchOp { cond, if_true, if_false, .. } => { + visit_expression(cond, accesses); + visit_expression(if_true, accesses); + visit_expression(if_false, accesses); + } + Call { args, .. } => { + for arg in args { + visit_expression(arg, accesses); + } + } + InlineArray { values, .. } => { + for value in values { + visit_expression(value, accesses); + } + } + Access { var, access, .. } => { + accesses.push(VariableAccess::new(var, access)); + } + Update { rhe, .. } => { + // We ignore accesses in assignments. + visit_expression(rhe, accesses); + } + Variable { .. } | Number(_, _) | Phi { .. } => (), + } +} + +fn build_report( + meta: &Meta, + current_template: &str, + component_template: &str, + signal_name: &str, +) -> Report { + UnusedOutputSignalWarning { + file_id: meta.file_id(), + file_location: meta.file_location(), + current_template: current_template.to_string(), + component_template: component_template.to_string(), + signal_name: signal_name.to_string(), + } + .into_report() +} + +#[cfg(test)] +mod tests { + use num_bigint_dig::BigInt; + use program_structure::{ + constants::Curve, + intermediate_representation::{ + VariableName, AccessType, Expression, Meta, value_meta::ValueReduction, + }, + }; + + use crate::{ + analysis_runner::AnalysisRunner, + unused_output_signal::{MaybeEqual, MaybeContains, maybe_accesses}, + }; + + use super::{find_unused_output_signals, VariableAccess}; + + #[test] + fn test_maybe_equal() { + use AccessType::*; + use Expression::*; + use ValueReduction::*; + + let var = VariableName::from_string("var"); + let meta = Meta::new(&(0..0), &None); + let mut zero = Box::new(Number(meta.clone(), BigInt::from(0))); + let mut one = Box::new(Number(meta.clone(), BigInt::from(1))); + let i = Box::new(Variable { meta, name: VariableName::from_string("i") }); + + // Set the value of `zero` and `one` explicitly. + let _ = zero + .meta_mut() + .value_knowledge_mut() + .set_reduces_to(FieldElement { value: BigInt::from(0) }); + let _ = one + .meta_mut() + .value_knowledge_mut() + .set_reduces_to(FieldElement { value: BigInt::from(1) }); + + // `var[0].out` + let first_access = VariableAccess::new( + &var.with_version(1), + &[ArrayAccess(zero.clone()), ComponentAccess("out".to_string())], + ); + // `var[i].out` + let second_access = VariableAccess::new( + &var.with_version(2), + &[ArrayAccess(i.clone()), ComponentAccess("out".to_string())], + ); + // `var[1].out` + let third_access = VariableAccess::new( + &var.with_version(3), + &[ArrayAccess(one), ComponentAccess("out".to_string())], + ); + // `var[i].out[0]` + let fourth_access = VariableAccess::new( + &var.with_version(4), + &[ArrayAccess(i), ComponentAccess("out".to_string()), ArrayAccess(zero)], + ); + + // The first and second accesses should be identified. + assert!(first_access.maybe_equal(&second_access)); + // The first and third accesses should not be identified. + assert!(!first_access.maybe_equal(&third_access)); + + let accesses = vec![first_access]; + + // The first and second accesses should be identified. + assert!(accesses.maybe_contains(&second_access)); + // The first and third accesses should not be identified. + assert!(!accesses.maybe_contains(&third_access)); + + // The fourth access is not equal to the first, but a prefix is. + assert!(!accesses.maybe_contains(&fourth_access)); + assert!(maybe_accesses(&accesses, &fourth_access)); + } + + #[test] + fn test_maybe_accesses() {} + + #[test] + fn test_unused_output_signal() { + // The output signal `out` in `Test` is not accessed, for any of the two + // instantiated components. + let src = [ + r#" + template Test() { + signal input in; + signal output out; + + out <== 2 * in + 1; + } + "#, + r#" + template Main() { + signal input in[2]; + + component test[2]; + test[0] = Test(); + test[1] = Test(); + test[0].in <== in[0]; + test[1].in <== in[1]; + } + "#, + ]; + validate_reports("Main", &src, 2); + + // `Num2Bits` is on the allow list and should not produce a report. + let src = [ + r#" + template Num2Bits(n) { + signal input in; + signal output out[n]; + + for (var i = 0; i < n; i++) { + out[i] <== in; + } + } + "#, + r#" + template Main() { + signal input in; + + component n2b = Num2Bits(); + n2b.in <== in[0]; + + in[1] === in[0] + 1; + } + "#, + ]; + validate_reports("Main", &src, 0); + + // If the template is not known we should not produce a report. + let src = [r#" + template Main() { + signal input in[2]; + + component test[2]; + test[0] = Test(); + test[1] = Test(); + test[0].in <== in[0]; + test[1].in <== in[1]; + } + "#]; + validate_reports("Main", &src, 0); + + // Should generate a warning for `test[1]` but not for `test[0]`. + let src = [ + r#" + template Test() { + signal input in; + signal output out; + + out <== 2 * in + 1; + } + "#, + r#" + template Main() { + signal input in[2]; + + component test[2]; + test[0] = Test(); + test[1] = Test(); + test[0].in <== in[0]; + test[1].in <== in[1]; + + test[0].out === 1; + } + "#, + ]; + validate_reports("Main", &src, 1); + + // Should not generate a warning for `test.out`. + let src = [ + r#" + template Test() { + signal input in; + signal output out[2]; + + out[0] <== 2 * in + 1; + out[1] <== 3 * in + 2; + } + "#, + r#" + template Main() { + signal input in; + + component test; + test = Test(); + test.in <== in[0]; + + func(test.out) === 1; + } + "#, + ]; + validate_reports("Main", &src, 0); + + // TODO: Should detect that `test[i].out[1]` is not accessed. + let src = [ + r#" + template Test() { + signal input in; + signal output out[2]; + + out[0] <== 2 * in + 1; + out[1] <== 3 * in + 2; + } + "#, + r#" + template Main() { + signal input in[2]; + + component test[2]; + for (var i = 0; i < 2; i++) { + test[i] = Test(); + test[i].in <== in[i]; + } + for (var i = 0; i < 2; i++) { + test[i].out[0] === 1; + } + } + "#, + ]; + validate_reports("Main", &src, 0); + + // TODO: Should detect that `test[1].out` is not accessed. + let src = [ + r#" + template Test() { + signal input in; + signal output out; + + out <== 2 * in + 1; + } + "#, + r#" + template Main() { + signal input in[2]; + + component test[2]; + for (var i = 0; i < 2; i++) { + test[i] = Test(); + test[i].in = in[i]; + } + + test[0].out === 1; + } + "#, + ]; + validate_reports("Main", &src, 0); + } + + fn validate_reports(name: &str, src: &[&str], expected_len: usize) { + let mut context = AnalysisRunner::new(Curve::Goldilocks).with_src(src); + let cfg = context.take_template(name).unwrap(); + let reports = find_unused_output_signals(&mut context, &cfg); + assert_eq!(reports.len(), expected_len); + } +} diff --git a/program_structure/Cargo.toml b/program_structure/Cargo.toml index f45a6ad..9447fec 100644 --- a/program_structure/Cargo.toml +++ b/program_structure/Cargo.toml @@ -1,7 +1,8 @@ [package] name = "circomspect-program-structure" -version = "2.0.11" -edition = "2018" +version = "2.1.2" +edition = "2021" +rust-version = "1.65" license = "LGPL-3.0-only" description = "Support crate for the Circomspect static analyzer" repository = "https://github.com/trailofbits/circomspect" @@ -13,7 +14,7 @@ authors = [ [dependencies] anyhow = "1.0" atty = "0.2" -circom_algebra = { package = "circomspect-circom-algebra", version = "2.0.1", path = "../circom_algebra" } +circom_algebra = { package = "circomspect-circom-algebra", version = "2.0.2", path = "../circom_algebra" } codespan = "0.11" codespan-reporting = "0.11" log = "0.4" @@ -30,4 +31,4 @@ termcolor = "1.1.3" [dev-dependencies] proptest = "1.0" -circom_algebra = { package = "circomspect-circom-algebra", version = "2.0.1", path = "../circom_algebra" } +circom_algebra = { package = "circomspect-circom-algebra", version = "2.0.2", path = "../circom_algebra" } diff --git a/program_structure/src/abstract_syntax_tree/ast.rs b/program_structure/src/abstract_syntax_tree/ast.rs index c1acf69..514a4bb 100644 --- a/program_structure/src/abstract_syntax_tree/ast.rs +++ b/program_structure/src/abstract_syntax_tree/ast.rs @@ -13,6 +13,7 @@ pub fn build_main_component(public: Vec, call: Expression) -> MainCompon } pub type Version = (usize, usize, usize); +pub type TagList = Vec; #[derive(Clone)] pub struct Include { @@ -196,6 +197,12 @@ pub enum Statement { op: AssignOp, rhe: Expression, }, + MultiSubstitution { + meta: Meta, + lhe: Expression, + op: AssignOp, + rhe: Expression, + }, ConstraintEquality { meta: Meta, lhe: Expression, @@ -229,11 +236,12 @@ pub enum SignalType { Intermediate, } -#[derive(Copy, Clone, PartialEq, Ord, PartialOrd, Eq)] +#[derive(Clone, PartialEq, Ord, PartialOrd, Eq)] pub enum VariableType { Var, - Signal(SignalType, SignalElementType), + Signal(SignalType, TagList), Component, + AnonymousComponent, } #[derive(Clone)] @@ -270,10 +278,29 @@ pub enum Expression { id: String, args: Vec, }, + AnonymousComponent { + meta: Meta, + id: String, + is_parallel: bool, + params: Vec, + signals: Vec, + names: Option>, + }, + // UniformArray is only used internally by Circom for default initialization + // of uninitialized arrays. + // UniformArray { + // meta: Meta, + // value: Box, + // dimension: Box, + // }, ArrayInLine { meta: Meta, values: Vec, }, + Tuple { + meta: Meta, + values: Vec, + }, } #[derive(Clone)] @@ -288,14 +315,14 @@ pub fn build_array_access(expr: Expression) -> Access { Access::ArrayAccess(expr) } -#[derive(Copy, Clone, Eq, PartialEq)] +#[derive(Copy, Clone, Eq, PartialEq, Debug)] pub enum AssignOp { AssignVar, AssignSignal, AssignConstraintSignal, } -#[derive(Copy, Clone, PartialEq, Eq)] +#[derive(Copy, Clone, PartialEq, Eq, Debug)] pub enum ExpressionInfixOpcode { Mul, Div, @@ -346,6 +373,7 @@ pub enum TypeReduction { Variable, Component, Signal, + Tag, } #[derive(Default, Clone)] @@ -359,21 +387,24 @@ impl TypeKnowledge { pub fn set_reduces_to(&mut self, reduces_to: TypeReduction) { self.reduces_to = Option::Some(reduces_to); } - pub fn get_reduces_to(&self) -> TypeReduction { + pub fn reduces_to(&self) -> TypeReduction { if let Option::Some(t) = &self.reduces_to { *t } else { - panic!("reduces_to knowledge is been look at without being initialized"); + panic!("Type knowledge accessed before it is initialized."); } } pub fn is_var(&self) -> bool { - self.get_reduces_to() == TypeReduction::Variable + self.reduces_to() == TypeReduction::Variable } pub fn is_component(&self) -> bool { - self.get_reduces_to() == TypeReduction::Component + self.reduces_to() == TypeReduction::Component } pub fn is_signal(&self) -> bool { - self.get_reduces_to() == TypeReduction::Signal + self.reduces_to() == TypeReduction::Signal + } + pub fn is_tag(&self) -> bool { + self.reduces_to() == TypeReduction::Tag } } @@ -394,25 +425,25 @@ impl MemoryKnowledge { pub fn set_abstract_memory_address(&mut self, value: usize) { self.abstract_memory_address = Option::Some(value); } - pub fn get_concrete_dimensions(&self) -> &[usize] { + pub fn concrete_dimensions(&self) -> &[usize] { if let Option::Some(v) = &self.concrete_dimensions { v } else { - panic!("concrete dimensions was look at without being initialized"); + panic!("Concrete dimensions accessed before it is initialized."); } } - pub fn get_full_length(&self) -> usize { + pub fn full_length(&self) -> usize { if let Option::Some(v) = &self.full_length { *v } else { - panic!("full dimension was look at without being initialized"); + panic!("Full dimension accessed before it is initialized."); } } - pub fn get_abstract_memory_address(&self) -> usize { + pub fn abstract_memory_address(&self) -> usize { if let Option::Some(v) = &self.abstract_memory_address { *v } else { - panic!("abstract memory address was look at without being initialized"); + panic!("Abstract memory address accessed before it is initialized."); } } } diff --git a/program_structure/src/abstract_syntax_tree/ast_shortcuts.rs b/program_structure/src/abstract_syntax_tree/ast_shortcuts.rs index 868fed2..4952767 100644 --- a/program_structure/src/abstract_syntax_tree/ast_shortcuts.rs +++ b/program_structure/src/abstract_syntax_tree/ast_shortcuts.rs @@ -11,6 +11,10 @@ pub struct Symbol { pub init: Option, } +pub struct TupleInit { + pub tuple_init: (AssignOp, Expression), +} + pub fn assign_with_op_shortcut( op: ExpressionInfixOpcode, meta: Meta, @@ -56,7 +60,7 @@ pub fn split_declaration_into_single_nodes( for symbol in symbols { let with_meta = meta.clone(); - let has_type = xtype; + let has_type = xtype.clone(); let name = symbol.name.clone(); let dimensions = symbol.is_array; let possible_init = symbol.init; @@ -67,9 +71,9 @@ pub fn split_declaration_into_single_nodes( let substitution = build_substitution(meta.clone(), symbol.name, vec![], op, init); initializations.push(substitution); } - // If the variable is not initialialized it is default initialized to 0. - // We remove this because we don't want this assignment to be flagged as - // an unused assignment by the side-effect analysis. + // If the variable is not initialized it is default initialized to 0 by + // Circom. We remove this because we don't want this assignment to be + // flagged as an unused assignment by the side-effect analysis. // else if xtype == Var { // let mut value = Expression::Number(meta.clone(), BigInt::from(0)); // for dim_expr in dimensions.iter().rev() { @@ -82,3 +86,48 @@ pub fn split_declaration_into_single_nodes( } build_initialization_block(meta, xtype, initializations) } + +pub fn split_declaration_into_single_nodes_and_multi_substitution( + meta: Meta, + xtype: VariableType, + symbols: Vec, + init: Option, +) -> Statement { + let mut initializations = Vec::new(); + let mut values = Vec::new(); + for symbol in symbols { + let with_meta = meta.clone(); + let has_type = xtype.clone(); + let name = symbol.name.clone(); + let dimensions = symbol.is_array; + debug_assert!(symbol.init.is_none()); + let single_declaration = + build_declaration(with_meta.clone(), has_type, name.clone(), dimensions.clone()); + initializations.push(single_declaration); + // Circom default initializes local arrays to 0. We remove this because + // we don't want these assignments to be flagged as unused assignments + // by the side-effect analysis. + // if xtype == Var && init.is_none() { + // let mut value = Expression::Number(meta.clone(), BigInt::from(0)); + // for dim_expr in dimensions.iter().rev() { + // value = build_uniform_array(meta.clone(), value, dim_expr.clone()); + // } + + // let substitution = + // build_substitution(meta.clone(), symbol.name, vec![], AssignOp::AssignVar, value); + // initializations.push(substitution); + // } + values.push(Expression::Variable { meta: with_meta.clone(), name, access: Vec::new() }) + } + if let Some(tuple) = init { + let (op, expression) = tuple.tuple_init; + let multi_sub = build_multi_substitution( + meta.clone(), + build_tuple(meta.clone(), values), + op, + expression, + ); + initializations.push(multi_sub); + } + build_initialization_block(meta, xtype, initializations) +} diff --git a/program_structure/src/abstract_syntax_tree/expression_builders.rs b/program_structure/src/abstract_syntax_tree/expression_builders.rs index cfe98d5..fdfa2db 100644 --- a/program_structure/src/abstract_syntax_tree/expression_builders.rs +++ b/program_structure/src/abstract_syntax_tree/expression_builders.rs @@ -38,13 +38,40 @@ pub fn build_variable(meta: Meta, name: String, access: Vec) -> Expressi } pub fn build_number(meta: Meta, value: BigInt) -> Expression { - Expression::Number(meta, value) + Number(meta, value) } pub fn build_call(meta: Meta, id: String, args: Vec) -> Expression { Call { meta, id, args } } +pub fn build_anonymous_component( + meta: Meta, + id: String, + params: Vec, + signals: Vec, + names: Option>, + is_parallel: bool, +) -> Expression { + AnonymousComponent { meta, id, params, signals, names, is_parallel } +} + pub fn build_array_in_line(meta: Meta, values: Vec) -> Expression { ArrayInLine { meta, values } } + +pub fn build_tuple(meta: Meta, values: Vec) -> Expression { + Tuple { meta, values } +} + +pub fn unzip_3( + vec: Vec<(String, AssignOp, Expression)>, +) -> (Vec<(AssignOp, String)>, Vec) { + let mut op_name = Vec::new(); + let mut exprs = Vec::new(); + for i in vec { + op_name.push((i.1, i.0)); + exprs.push(i.2); + } + (op_name, exprs) +} diff --git a/program_structure/src/abstract_syntax_tree/expression_impl.rs b/program_structure/src/abstract_syntax_tree/expression_impl.rs index 13fdac6..69ec8ab 100644 --- a/program_structure/src/abstract_syntax_tree/expression_impl.rs +++ b/program_structure/src/abstract_syntax_tree/expression_impl.rs @@ -1,8 +1,10 @@ -use super::ast::*; use std::fmt::{Debug, Display, Error, Formatter}; +use super::ast::*; +use super::expression_builders::build_anonymous_component; + impl Expression { - pub fn get_meta(&self) -> &Meta { + pub fn meta(&self) -> &Meta { use Expression::*; match self { InfixOp { meta, .. } @@ -12,10 +14,12 @@ impl Expression { | ParallelOp { meta, .. } | Number(meta, ..) | Call { meta, .. } - | ArrayInLine { meta, .. } => meta, + | AnonymousComponent { meta, .. } + | ArrayInLine { meta, .. } + | Tuple { meta, .. } => meta, } } - pub fn get_mut_meta(&mut self) -> &mut Meta { + pub fn meta_mut(&mut self) -> &mut Meta { use Expression::*; match self { InfixOp { meta, .. } @@ -25,7 +29,9 @@ impl Expression { | ParallelOp { meta, .. } | Number(meta, ..) | Call { meta, .. } - | ArrayInLine { meta, .. } => meta, + | AnonymousComponent { meta, .. } + | ArrayInLine { meta, .. } + | Tuple { meta, .. } => meta, } } @@ -68,14 +74,35 @@ impl Expression { use Expression::*; matches!(self, ParallelOp { .. }) } + + pub fn is_tuple(&self) -> bool { + use Expression::*; + matches!(self, Tuple { .. }) + } + + pub fn is_anonymous_component(&self) -> bool { + use Expression::*; + matches!(self, AnonymousComponent { .. }) + } + + pub fn make_anonymous_parallel(self) -> Expression { + use Expression::*; + match self { + AnonymousComponent { meta, id, params, signals, names, .. } => { + build_anonymous_component(meta, id, params, signals, names, true) + } + _ => self, + } + } } impl FillMeta for Expression { fn fill(&mut self, file_id: usize, element_id: &mut usize) { use Expression::*; - self.get_mut_meta().elem_id = *element_id; + self.meta_mut().elem_id = *element_id; *element_id += 1; match self { + Tuple { meta, values } => fill_tuple(meta, values, file_id, element_id), Number(meta, _) => fill_number(meta, file_id, element_id), Variable { meta, access, .. } => fill_variable(meta, access, file_id, element_id), InfixOp { meta, lhe, rhe, .. } => fill_infix(meta, lhe, rhe, file_id, element_id), @@ -88,6 +115,9 @@ impl FillMeta for Expression { ArrayInLine { meta, values, .. } => { fill_array_inline(meta, values, file_id, element_id) } + AnonymousComponent { meta, params, signals, .. } => { + fill_anonymous_component(meta, params, signals, file_id, element_id) + } } } } @@ -155,6 +185,29 @@ fn fill_array_inline( } } +fn fill_anonymous_component( + meta: &mut Meta, + params: &mut [Expression], + signals: &mut [Expression], + file_id: usize, + element_id: &mut usize, +) { + meta.set_file_id(file_id); + for param in params { + param.fill(file_id, element_id); + } + for signal in signals { + signal.fill(file_id, element_id); + } +} + +fn fill_tuple(meta: &mut Meta, values: &mut [Expression], file_id: usize, element_id: &mut usize) { + meta.set_file_id(file_id); + for value in values { + value.fill(file_id, element_id); + } +} + fn fill_parallel(meta: &mut Meta, rhe: &mut Expression, file_id: usize, element_id: &mut usize) { meta.set_file_id(file_id); rhe.fill(file_id, element_id); @@ -162,7 +215,19 @@ fn fill_parallel(meta: &mut Meta, rhe: &mut Expression, file_id: usize, element_ impl Debug for Expression { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { - write!(f, "{}", self) + use Expression::*; + match self { + InfixOp { .. } => write!(f, "Expression::InfixOp"), + PrefixOp { .. } => write!(f, "Expression::PrefixOp"), + InlineSwitchOp { .. } => write!(f, "Expression::InlineSwitchOp"), + Variable { .. } => write!(f, "Expression::Variable"), + ParallelOp { .. } => write!(f, "Expression::ParallelOp"), + Number(..) => write!(f, "Expression::Number"), + Call { .. } => write!(f, "Expression::Call"), + AnonymousComponent { .. } => write!(f, "Expression::AnonymousComponent"), + ArrayInLine { .. } => write!(f, "Expression::ArrayInline"), + Tuple { .. } => write!(f, "Expression::Tuple"), + } } } @@ -170,6 +235,7 @@ impl Display for Expression { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { use Expression::*; match self { + Tuple { values, .. } => write!(f, "({})", vec_to_string(values)), Number(_, value) => write!(f, "{}", value), Variable { name, access, .. } => { write!(f, "{name}")?; @@ -184,8 +250,11 @@ impl Display for Expression { InlineSwitchOp { cond, if_true, if_false, .. } => { write!(f, "({cond}? {if_true} : {if_false})") } - Call { id, args, .. } => write!(f, "{}({})", id, vec_to_string(args)), + Call { id, args, .. } => write!(f, "{id}({})", vec_to_string(args)), ArrayInLine { values, .. } => write!(f, "[{}]", vec_to_string(values)), + AnonymousComponent { id, params, signals, names, .. } => { + write!(f, "{id}({})({})", vec_to_string(params), signals_to_string(names, signals)) + } } } } @@ -242,3 +311,16 @@ impl Display for Access { fn vec_to_string(elems: &[Expression]) -> String { elems.iter().map(|arg| arg.to_string()).collect::>().join(", ") } + +fn signals_to_string(names: &Option>, signals: &[Expression]) -> String { + if let Some(names) = names { + names + .iter() + .zip(signals.iter()) + .map(|((op, name), signal)| format!("{name} {op} {signal}")) + .collect::>() + } else { + signals.iter().map(|signal| signal.to_string()).collect::>() + } + .join(", ") +} diff --git a/program_structure/src/abstract_syntax_tree/statement_builders.rs b/program_structure/src/abstract_syntax_tree/statement_builders.rs index a7ddfc1..dc66556 100644 --- a/program_structure/src/abstract_syntax_tree/statement_builders.rs +++ b/program_structure/src/abstract_syntax_tree/statement_builders.rs @@ -83,3 +83,21 @@ fn split_string(str: String) -> Vec { pub fn build_assert(meta: Meta, arg: Expression) -> Statement { Assert { meta, arg } } + +pub fn build_multi_substitution( + meta: Meta, + lhe: Expression, + op: AssignOp, + rhe: Expression, +) -> Statement { + MultiSubstitution { meta, lhe, op, rhe } +} + +pub fn build_anonymous_component_statement(meta: Meta, arg: Expression) -> Statement { + MultiSubstitution { + meta: meta.clone(), + lhe: crate::expression_builders::build_tuple(meta, Vec::new()), + op: AssignOp::AssignConstraintSignal, + rhe: arg, + } +} diff --git a/program_structure/src/abstract_syntax_tree/statement_impl.rs b/program_structure/src/abstract_syntax_tree/statement_impl.rs index 66853ad..85930a8 100644 --- a/program_structure/src/abstract_syntax_tree/statement_impl.rs +++ b/program_structure/src/abstract_syntax_tree/statement_impl.rs @@ -11,6 +11,7 @@ impl Statement { | Return { meta, .. } | Declaration { meta, .. } | Substitution { meta, .. } + | MultiSubstitution { meta, .. } | LogCall { meta, .. } | Block { meta, .. } | Assert { meta, .. } @@ -26,6 +27,7 @@ impl Statement { | Return { meta, .. } | Declaration { meta, .. } | Substitution { meta, .. } + | MultiSubstitution { meta, .. } | LogCall { meta, .. } | Block { meta, .. } | Assert { meta, .. } @@ -38,38 +40,52 @@ impl Statement { use Statement::*; matches!(self, IfThenElse { .. }) } + pub fn is_while(&self) -> bool { use Statement::*; matches!(self, While { .. }) } + pub fn is_return(&self) -> bool { use Statement::*; matches!(self, Return { .. }) } + pub fn is_initialization_block(&self) -> bool { use Statement::*; matches!(self, InitializationBlock { .. }) } + pub fn is_declaration(&self) -> bool { use Statement::*; matches!(self, Declaration { .. }) } + pub fn is_substitution(&self) -> bool { use Statement::*; matches!(self, Substitution { .. }) } + + pub fn is_multi_substitution(&self) -> bool { + use Statement::*; + matches!(self, MultiSubstitution { .. }) + } + pub fn is_constraint_equality(&self) -> bool { use Statement::*; matches!(self, ConstraintEquality { .. }) } + pub fn is_log_call(&self) -> bool { use Statement::*; matches!(self, LogCall { .. }) } + pub fn is_block(&self) -> bool { use Statement::*; matches!(self, Block { .. }) } + pub fn is_assert(&self) -> bool { use Statement::*; matches!(self, Assert { .. }) @@ -96,6 +112,9 @@ impl FillMeta for Statement { Substitution { meta, access, rhe, .. } => { fill_substitution(meta, access, rhe, file_id, element_id) } + MultiSubstitution { meta, lhe, rhe, .. } => { + fill_multi_substitution(meta, lhe, rhe, file_id, element_id); + } ConstraintEquality { meta, lhe, rhe } => { fill_constraint_equality(meta, lhe, rhe, file_id, element_id) } @@ -179,6 +198,18 @@ fn fill_substitution( } } +fn fill_multi_substitution( + meta: &mut Meta, + lhe: &mut Expression, + rhe: &mut Expression, + file_id: usize, + element_id: &mut usize, +) { + meta.set_file_id(file_id); + rhe.fill(file_id, element_id); + lhe.fill(file_id, element_id); +} + fn fill_constraint_equality( meta: &mut Meta, lhe: &mut Expression, @@ -221,16 +252,17 @@ impl Debug for Statement { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { use Statement::*; match self { - IfThenElse { .. } => f.write_str("Statement::IfThenElse"), - While { .. } => f.write_str("Statement::While"), - Return { .. } => f.write_str("Statement::Return"), - Declaration { .. } => f.write_str("Statement::Declaration"), - Substitution { .. } => f.write_str("Statement::Substitution"), - LogCall { .. } => f.write_str("Statement::LogCall"), - Block { .. } => f.write_str("Statement::Block"), - Assert { .. } => f.write_str("Statement::Assert"), - ConstraintEquality { .. } => f.write_str("Statement::ConstraintEquality"), - InitializationBlock { .. } => f.write_str("Statement::InitializationBlock"), + IfThenElse { .. } => write!(f, "Statement::IfThenElse"), + While { .. } => write!(f, "Statement::While"), + Return { .. } => write!(f, "Statement::Return"), + Declaration { .. } => write!(f, "Statement::Declaration"), + Substitution { .. } => write!(f, "Statement::Substitution"), + MultiSubstitution { .. } => write!(f, "Statement::MultiSubstitution"), + LogCall { .. } => write!(f, "Statement::LogCall"), + Block { .. } => write!(f, "Statement::Block"), + Assert { .. } => write!(f, "Statement::Assert"), + ConstraintEquality { .. } => write!(f, "Statement::ConstraintEquality"), + InitializationBlock { .. } => write!(f, "Statement::InitializationBlock"), } } } @@ -240,7 +272,7 @@ impl Display for Statement { use Statement::*; match self { IfThenElse { cond, else_case, .. } => match else_case { - Some(_) => write!(f, "if-else {cond}"), + Some(_) => write!(f, "if {cond} else"), None => write!(f, "if {cond}"), }, While { cond, .. } => write!(f, "while {cond}"), @@ -253,21 +285,11 @@ impl Display for Statement { } write!(f, " {op} {rhe}") } - LogCall { args, .. } => { - write!(f, "log(")?; - for (index, arg) in args.iter().enumerate() { - if index > 0 { - write!(f, ", ")?; - } - write!(f, "{arg}")?; - } - write!(f, ")") - } - // TODO: Remove this when switching to IR. + MultiSubstitution { lhe, op, rhe, .. } => write!(f, "{lhe} {op} {rhe}"), + LogCall { args, .. } => write!(f, "log({})", vec_to_string(args)), Block { .. } => Ok(()), Assert { arg, .. } => write!(f, "assert({arg})"), ConstraintEquality { lhe, rhe, .. } => write!(f, "{lhe} === {rhe}"), - // TODO: Remove this when switching to IR. InitializationBlock { .. } => Ok(()), } } @@ -290,14 +312,20 @@ impl Display for VariableType { use VariableType::*; match self { Var => write!(f, "var"), - Signal(signal_type, _) => { + Signal(signal_type, tag_list) => { if matches!(signal_type, Intermediate) { - write!(f, "signal") + write!(f, "signal")?; } else { - write!(f, "signal {signal_type}") + write!(f, "signal {signal_type}")?; + } + if !tag_list.is_empty() { + write!(f, " {{{}}}", tag_list.join("}} {{")) + } else { + Ok(()) } } Component => write!(f, "component"), + AnonymousComponent => write!(f, "anonymous component"), } } } @@ -322,3 +350,7 @@ impl Display for LogArgument { } } } + +fn vec_to_string(elems: &[T]) -> String { + elems.iter().map(|arg| arg.to_string()).collect::>().join(", ") +} diff --git a/program_structure/src/control_flow_graph/cfg.rs b/program_structure/src/control_flow_graph/cfg.rs index e702c8b..5151f34 100644 --- a/program_structure/src/control_flow_graph/cfg.rs +++ b/program_structure/src/control_flow_graph/cfg.rs @@ -9,7 +9,7 @@ use crate::ir::declarations::{Declaration, Declarations}; use crate::ir::degree_meta::{DegreeEnvironment, Degree, DegreeRange}; use crate::ir::value_meta::ValueEnvironment; use crate::ir::variable_meta::VariableMeta; -use crate::ir::{VariableName, VariableType}; +use crate::ir::{VariableName, VariableType, SignalType}; use crate::ssa::dominator_tree::DominatorTree; use crate::ssa::errors::SSAResult; use crate::ssa::{insert_phi_statements, insert_ssa_variables}; @@ -174,6 +174,30 @@ impl Cfg { self.declarations.iter().map(|(name, _)| name) } + /// Returns an iterator over the input signals of the CFG. + pub fn input_signals(&self) -> impl Iterator { + use SignalType::*; + use VariableType::*; + self.declarations.iter().filter_map(|(name, declaration)| { + match declaration.variable_type() { + Signal(Input, _) => Some(name), + _ => None, + } + }) + } + + /// Returns an iterator over the output signals of the CFG. + pub fn output_signals(&self) -> impl Iterator { + use SignalType::*; + use VariableType::*; + self.declarations.iter().filter_map(|(name, declaration)| { + match declaration.variable_type() { + Signal(Output, _) => Some(name), + _ => None, + } + }) + } + /// Returns the declaration of the given variable. #[must_use] pub fn get_declaration(&self, name: &VariableName) -> Option<&Declaration> { diff --git a/program_structure/src/control_flow_graph/errors.rs b/program_structure/src/control_flow_graph/errors.rs index d77a7f3..3be18b3 100644 --- a/program_structure/src/control_flow_graph/errors.rs +++ b/program_structure/src/control_flow_graph/errors.rs @@ -31,9 +31,9 @@ pub enum CFGError { pub type CFGResult = Result; impl CFGError { - pub fn produce_report(error: Self) -> Report { + pub fn into_report(self) -> Report { use CFGError::*; - match error { + match self { UndefinedVariableError { name, file_id, file_location } => { let mut report = Report::error( format!("The variable `{name}` is used before it is defined."), @@ -124,6 +124,6 @@ impl From for CFGError { impl From for Report { fn from(error: CFGError) -> Report { - CFGError::produce_report(error) + error.into_report() } } diff --git a/program_structure/src/control_flow_graph/parameters.rs b/program_structure/src/control_flow_graph/parameters.rs index e5f2fd9..990922a 100644 --- a/program_structure/src/control_flow_graph/parameters.rs +++ b/program_structure/src/control_flow_graph/parameters.rs @@ -19,7 +19,7 @@ impl Parameters { file_location: FileLocation, ) -> Parameters { Parameters { - param_names: param_names.iter().map(VariableName::from_name).collect(), + param_names: param_names.iter().map(VariableName::from_string).collect(), file_id, file_location, } diff --git a/program_structure/src/control_flow_graph/unique_vars.rs b/program_structure/src/control_flow_graph/unique_vars.rs index eb594c3..3a694ce 100644 --- a/program_structure/src/control_flow_graph/unique_vars.rs +++ b/program_structure/src/control_flow_graph/unique_vars.rs @@ -132,7 +132,9 @@ impl TryFrom<&Parameters> for DeclarationEnvironment { } } -/// Renames variables to ensure that variable names are globally unique. +/// Renames variables to ensure that variable names are globally unique. This +/// is done before the CFG is generated to ensure that different variables with +/// the same names are not identified by mistake. /// /// There are a number of different cases to consider. /// @@ -231,6 +233,10 @@ fn visit_statement( } visit_expression(rhe, env); } + MultiSubstitution { lhe, rhe, .. } => { + visit_expression(lhe, env); + visit_expression(rhe, env); + } LogCall { args, .. } => { use LogArgument::*; for arg in args { @@ -312,7 +318,7 @@ fn visit_expression(expr: &mut Expression, env: &DeclarationEnvironment) { visit_expression(arg, env); } } - ArrayInLine { values, .. } => { + Tuple { values, .. } | ArrayInLine { values, .. } => { for value in values { visit_expression(value, env); } @@ -320,15 +326,38 @@ fn visit_expression(expr: &mut Expression, env: &DeclarationEnvironment) { ParallelOp { rhe, .. } => { visit_expression(rhe, env); } + AnonymousComponent { params, signals, names, .. } => { + for param in params { + visit_expression(param, env) + } + for signal in signals { + visit_expression(signal, env) + } + if let Some(names) = names { + for (_, name) in names { + trace!("visiting variable '{name}'"); + *name = match env.get_current_version(name) { + Some(version) => { + trace!( + "renaming occurrence of variable `{name}` to `{name}.{version}`" + ); + format!("{name}.{version}") + } + None => name.clone(), + }; + } + } + } } } fn build_report(name: &str, primary_meta: &Meta, secondary_decl: &Declaration) -> Report { - CFGError::produce_report(CFGError::ShadowingVariableWarning { + CFGError::ShadowingVariableWarning { name: name.to_string(), primary_file_id: primary_meta.file_id, primary_location: primary_meta.file_location(), secondary_file_id: secondary_decl.file_id(), secondary_location: secondary_decl.file_location(), - }) + } + .into() } diff --git a/program_structure/src/intermediate_representation/degree_meta.rs b/program_structure/src/intermediate_representation/degree_meta.rs index 7089e49..ea486c0 100644 --- a/program_structure/src/intermediate_representation/degree_meta.rs +++ b/program_structure/src/intermediate_representation/degree_meta.rs @@ -457,6 +457,8 @@ impl DegreeEnvironment { } /// Sets the degree range of the given variable. Returns true on first update. + /// TODO: Should probably take the supremum of the given range and any + /// existing range. pub fn set_degree(&mut self, var: &VariableName, range: &DegreeRange) -> bool { if self.degree_ranges.insert(var.clone(), range.clone()).is_none() { trace!("setting degree range of `{var:?}` to {range:?}"); diff --git a/program_structure/src/intermediate_representation/expression_impl.rs b/program_structure/src/intermediate_representation/expression_impl.rs index 55d1a33..5880392 100644 --- a/program_structure/src/intermediate_representation/expression_impl.rs +++ b/program_structure/src/intermediate_representation/expression_impl.rs @@ -1,6 +1,6 @@ +#![deny(warnings)] use log::trace; use num_traits::Zero; -use std::collections::HashSet; use std::fmt; use std::hash::{Hash, Hasher}; @@ -392,11 +392,11 @@ impl VariableMeta for Expression { trace!("adding `{name:?}` to local variables read"); locals_read.insert(VariableUse::new(meta, name, &Vec::new())); } - Some(VariableType::Component) => { + Some(VariableType::Component | VariableType::AnonymousComponent) => { trace!("adding `{name:?}` to components read"); components_read.insert(VariableUse::new(meta, name, &Vec::new())); } - Some(VariableType::Signal(_)) => { + Some(VariableType::Signal(_, _)) => { trace!("adding `{name:?}` to signals read"); signals_read.insert(VariableUse::new(meta, name, &Vec::new())); } @@ -442,11 +442,11 @@ impl VariableMeta for Expression { trace!("adding `{var:?}` to local variables read"); locals_read.insert(VariableUse::new(meta, var, access)); } - Some(VariableType::Component) => { + Some(VariableType::Component | VariableType::AnonymousComponent) => { trace!("adding `{var:?}` to components read"); components_read.insert(VariableUse::new(meta, var, access)); } - Some(VariableType::Signal(_)) => { + Some(VariableType::Signal(_, _)) => { trace!("adding `{var:?}` to signals read"); signals_read.insert(VariableUse::new(meta, var, access)); } @@ -478,11 +478,11 @@ impl VariableMeta for Expression { trace!("adding `{var:?}` to local variables read"); locals_read.insert(VariableUse::new(meta, var, &Vec::new())); } - Some(VariableType::Component) => { + Some(VariableType::Component | VariableType::AnonymousComponent) => { trace!("adding `{var:?}` to components read"); components_read.insert(VariableUse::new(meta, var, &Vec::new())); } - Some(VariableType::Signal(_)) => { + Some(VariableType::Signal(_, _)) => { trace!("adding `{var:?}` to signals read"); signals_read.insert(VariableUse::new(meta, var, &Vec::new())); } @@ -535,66 +535,50 @@ impl ValueMeta for Expression { use ValueReduction::*; match self { InfixOp { meta, lhe, infix_op, rhe, .. } => { - let mut result = lhe.propagate_values(env) || rhe.propagate_values(env); - if let Some(value) = infix_op.propagate_values(lhe.value(), rhe.value(), env) { - result = result || meta.value_knowledge_mut().set_reduces_to(value) - } - result + let result = lhe.propagate_values(env) || rhe.propagate_values(env); + let value = infix_op.propagate_values(&lhe.value(), &rhe.value(), env); + result || meta.value_knowledge_mut().set_reduces_to(value) } PrefixOp { meta, prefix_op, rhe } => { - let mut result = rhe.propagate_values(env); - if let Some(value) = prefix_op.propagate_values(rhe.value(), env) { - result = result || meta.value_knowledge_mut().set_reduces_to(value) - } - result + let result = rhe.propagate_values(env); + let value = prefix_op.propagate_values(&rhe.value(), env); + result || meta.value_knowledge_mut().set_reduces_to(value) } SwitchOp { meta, cond, if_true, if_false } => { - let mut result = cond.propagate_values(env) + let result = cond.propagate_values(env) | if_true.propagate_values(env) | if_false.propagate_values(env); - match (cond.value(), if_true.value(), if_false.value()) { - ( - // The case true? value: _ - Some(Boolean { value: cond }), - Some(value), - _, - ) if *cond => { - result = result || meta.value_knowledge_mut().set_reduces_to(value.clone()) - } - ( - // The case false? _: value - Some(Boolean { value: cond }), - _, - Some(value), - ) if !cond => { - result = result || meta.value_knowledge_mut().set_reduces_to(value.clone()) - } - ( - // The case true? value: _ - Some(FieldElement { value: cond }), - Some(value), - _, - ) if !cond.is_zero() => { - result = result || meta.value_knowledge_mut().set_reduces_to(value.clone()) - } - ( - // The case false? _: value - Some(FieldElement { value: cond }), - _, - Some(value), - ) if cond.is_zero() => { - result = result || meta.value_knowledge_mut().set_reduces_to(value.clone()) - } - _ => {} - } result + || match (cond.value(), if_true.value(), if_false.value()) { + (Boolean(cond), t, f) => { + let value = match cond { + Some(true) => t, + Some(false) => f, + None => t.union(&f), + }; + meta.value_knowledge_mut().set_reduces_to(value) + } + + (FieldElement(cond), t, f) => { + let value = match cond.map(|c| !c.is_zero()) { + Some(true) => t, + Some(false) => f, + None => t.union(&f), + }; + + meta.value_knowledge_mut().set_reduces_to(value) + } + + (Unknown, t, f) => meta.value_knowledge_mut().set_reduces_to(t.union(&f)), + + (Impossible, _, _) => meta.value_knowledge_mut().set_reduces_to(Impossible), + } + } + Variable { meta, name, .. } => { + meta.value_knowledge_mut().set_reduces_to(env.get_variable(name)) } - Variable { meta, name, .. } => match env.get_variable(name) { - Some(value) => meta.value_knowledge_mut().set_reduces_to(value.clone()), - None => false, - }, Number(meta, value) => { - let value = FieldElement { value: value.clone() }; + let value = FieldElement(Some(value.clone())); meta.value_knowledge_mut().set_reduces_to(value) } Call { args, .. } => { @@ -634,35 +618,32 @@ impl ValueMeta for Expression { result } Phi { meta, args, .. } => { - // Only set the value of the phi expression if all arguments agree on the value. - let values = - args.iter().map(|name| env.get_variable(name)).collect::>>(); - match values { - Some(values) if values.len() == 1 => { - // This unwrap is safe since the size is non-zero. - let value = *values.iter().next().unwrap(); - meta.value_knowledge_mut().set_reduces_to(value.clone()) - } - _ => false, + // set the value of the phi expression to the union of all + // possible inputs + let mut value = ValueReduction::default(); + for name in args.iter() { + let v = env.get_variable(name); + value = value.union(&v); } + meta.value_knowledge_mut().set_reduces_to(value) } } } fn is_constant(&self) -> bool { - self.value().is_some() + self.value().is_constant() } fn is_boolean(&self) -> bool { - matches!(self.value(), Some(ValueReduction::Boolean { .. })) + matches!(self.value(), ValueReduction::Boolean(Some(_))) } fn is_field_element(&self) -> bool { - matches!(self.value(), Some(ValueReduction::FieldElement { .. })) + matches!(self.value(), ValueReduction::FieldElement(Some(_))) } - fn value(&self) -> Option<&ValueReduction> { - self.meta().value_knowledge().get_reduces_to() + fn value(&self) -> ValueReduction { + self.meta().value_knowledge().clone() } } @@ -703,102 +684,129 @@ impl ExpressionInfixOpcode { fn propagate_values( &self, - lhv: Option<&ValueReduction>, - rhv: Option<&ValueReduction>, + lhv: &ValueReduction, + rhv: &ValueReduction, env: &ValueEnvironment, - ) -> Option { + ) -> ValueReduction { let p = env.prime(); use ValueReduction::*; + match (lhv, rhv) { // lhv and rhv reduce to two field elements. - (Some(FieldElement { value: lhv }), Some(FieldElement { value: rhv })) => { + (FieldElement(Some(lhv)), FieldElement(Some(rhv))) => { use ExpressionInfixOpcode::*; match self { Mul => { let value = modular_arithmetic::mul(lhv, rhv, p); - Some(FieldElement { value }) + FieldElement(Some(value)) } Div => modular_arithmetic::div(lhv, rhv, p) .ok() - .map(|value| FieldElement { value }), + .map(|value| FieldElement(Some(value))) + .unwrap_or(Impossible), Add => { let value = modular_arithmetic::add(lhv, rhv, p); - Some(FieldElement { value }) + FieldElement(Some(value)) } Sub => { let value = modular_arithmetic::sub(lhv, rhv, p); - Some(FieldElement { value }) + FieldElement(Some(value)) } Pow => { let value = modular_arithmetic::pow(lhv, rhv, p); - Some(FieldElement { value }) + FieldElement(Some(value)) } IntDiv => modular_arithmetic::idiv(lhv, rhv, p) .ok() - .map(|value| FieldElement { value }), + .map(|value| FieldElement(Some(value))) + .unwrap_or(Impossible), Mod => modular_arithmetic::mod_op(lhv, rhv, p) .ok() - .map(|value| FieldElement { value }), + .map(|value| FieldElement(Some(value))) + .unwrap_or(Impossible), ShiftL => modular_arithmetic::shift_l(lhv, rhv, p) .ok() - .map(|value| FieldElement { value }), + .map(|value| FieldElement(Some(value))) + .unwrap_or(Impossible), ShiftR => modular_arithmetic::shift_r(lhv, rhv, p) .ok() - .map(|value| FieldElement { value }), + .map(|value| FieldElement(Some(value))) + .unwrap_or(Impossible), LesserEq => { let value = modular_arithmetic::lesser_eq(lhv, rhv, p); - Some(Boolean { value: modular_arithmetic::as_bool(&value, p) }) + Boolean(Some(modular_arithmetic::as_bool(&value, p))) } GreaterEq => { let value = modular_arithmetic::greater_eq(lhv, rhv, p); - Some(Boolean { value: modular_arithmetic::as_bool(&value, p) }) + Boolean(Some(modular_arithmetic::as_bool(&value, p))) } Lesser => { let value = modular_arithmetic::lesser(lhv, rhv, p); - Some(Boolean { value: modular_arithmetic::as_bool(&value, p) }) + Boolean(Some(modular_arithmetic::as_bool(&value, p))) } Greater => { let value = modular_arithmetic::greater(lhv, rhv, p); - Some(Boolean { value: modular_arithmetic::as_bool(&value, p) }) + Boolean(Some(modular_arithmetic::as_bool(&value, p))) } Eq => { let value = modular_arithmetic::eq(lhv, rhv, p); - Some(Boolean { value: modular_arithmetic::as_bool(&value, p) }) + Boolean(Some(modular_arithmetic::as_bool(&value, p))) } NotEq => { let value = modular_arithmetic::not_eq(lhv, rhv, p); - Some(Boolean { value: modular_arithmetic::as_bool(&value, p) }) + Boolean(Some(modular_arithmetic::as_bool(&value, p))) } BitOr => { let value = modular_arithmetic::bit_or(lhv, rhv, p); - Some(FieldElement { value }) + FieldElement(Some(value)) } BitAnd => { let value = modular_arithmetic::bit_and(lhv, rhv, p); - Some(FieldElement { value }) + FieldElement(Some(value)) } BitXor => { let value = modular_arithmetic::bit_xor(lhv, rhv, p); - Some(FieldElement { value }) + FieldElement(Some(value)) } // Remaining operations do not make sense. // TODO: Add report/error propagation here. - _ => None, + _ => Unknown, } } // lhv and rhv reduce to two booleans. - (Some(Boolean { value: lhv }), Some(Boolean { value: rhv })) => { + (Boolean(lhv), Boolean(rhv)) => { use ExpressionInfixOpcode::*; match self { - BoolAnd => Some(Boolean { value: *lhv && *rhv }), - BoolOr => Some(Boolean { value: *lhv || *rhv }), + BoolAnd => Boolean(match (lhv, rhv) { + (Some(true), r) => *r, + (Some(false), _) => Some(false), + (None, Some(false)) => Some(false), + _ => None, + }), + BoolOr => Boolean(match (lhv, rhv) { + (Some(false), r) => *r, + (Some(true), _) => Some(true), + (None, Some(true)) => Some(true), + _ => None, + }), // Remaining operations do not make sense. // TODO: Add report propagation here as well. - _ => None, + _ => Unknown, + } + } + _ => { + use ExpressionInfixOpcode::*; + // TODO: should we check the input types? + match self { + Mul | Div | Add | Sub | Pow | IntDiv | Mod | ShiftL | ShiftR | BitOr + | BitAnd | BitXor => FieldElement(None), + + LesserEq | GreaterEq | Lesser | Greater | Eq | NotEq | BoolAnd | BoolOr => { + Boolean(None) + } } } - _ => None, } } } @@ -817,43 +825,37 @@ impl ExpressionPrefixOpcode { } } - fn propagate_values( - &self, - rhe: Option<&ValueReduction>, - env: &ValueEnvironment, - ) -> Option { + fn propagate_values(&self, rhe: &ValueReduction, env: &ValueEnvironment) -> ValueReduction { let p = env.prime(); use ValueReduction::*; match rhe { // arg reduces to a field element. - Some(FieldElement { value: arg }) => { + FieldElement(arg) => { use ExpressionPrefixOpcode::*; match self { Sub => { - let value = modular_arithmetic::prefix_sub(arg, p); - Some(FieldElement { value }) - } - Complement => { - let value = modular_arithmetic::complement_256(arg, p); - Some(FieldElement { value }) + FieldElement(arg.as_ref().map(|arg| modular_arithmetic::prefix_sub(arg, p))) } + Complement => FieldElement( + arg.as_ref().map(|arg| modular_arithmetic::complement_256(arg, p)), + ), // Remaining operations do not make sense. // TODO: Add report propagation here as well. - _ => None, + _ => Unknown, } } // arg reduces to a boolean. - Some(Boolean { value: arg }) => { + Boolean(arg) => { use ExpressionPrefixOpcode::*; match self { - BoolNot => Some(Boolean { value: !arg }), + BoolNot => Boolean(arg.map(|x| !x)), // Remaining operations do not make sense. // TODO: Add report propagation here as well. - _ => None, + _ => Unknown, } } - None => None, + _ => Unknown, } } } @@ -1014,10 +1016,10 @@ mod tests { use ExpressionInfixOpcode::*; use ValueReduction::*; let mut lhe = Number(Meta::default(), 7u64.into()); - let mut rhe = Variable { meta: Meta::default(), name: VariableName::from_name("v") }; + let mut rhe = Variable { meta: Meta::default(), name: VariableName::from_string("v") }; let constants = UsefulConstants::new(&Curve::default()); let mut env = ValueEnvironment::new(&constants); - env.add_variable(&VariableName::from_name("v"), &FieldElement { value: 3u64.into() }); + env.add_variable(&VariableName::from_string("v"), &FieldElement(Some(3u64.into()))); lhe.propagate_values(&mut env); rhe.propagate_values(&mut env); @@ -1029,7 +1031,7 @@ mod tests { rhe: Box::new(rhe.clone()), }; expr.propagate_values(&mut env.clone()); - assert_eq!(expr.value(), Some(&FieldElement { value: 21u64.into() })); + assert_eq!(expr.value(), FieldElement(Some(21u64.into()))); // Infix addition. let mut expr = InfixOp { @@ -1039,7 +1041,7 @@ mod tests { rhe: Box::new(rhe.clone()), }; expr.propagate_values(&mut env.clone()); - assert_eq!(expr.value(), Some(&FieldElement { value: 10u64.into() })); + assert_eq!(expr.value(), FieldElement(Some(10u64.into()))); // Infix integer division. let mut expr = InfixOp { @@ -1049,6 +1051,6 @@ mod tests { rhe: Box::new(rhe.clone()), }; expr.propagate_values(&mut env.clone()); - assert_eq!(expr.value(), Some(&FieldElement { value: 2u64.into() })); + assert_eq!(expr.value(), FieldElement(Some(2u64.into()))); } } diff --git a/program_structure/src/intermediate_representation/ir.rs b/program_structure/src/intermediate_representation/ir.rs index c171679..5f7e1de 100644 --- a/program_structure/src/intermediate_representation/ir.rs +++ b/program_structure/src/intermediate_representation/ir.rs @@ -6,7 +6,7 @@ use crate::nonempty_vec::NonEmptyVec; use super::degree_meta::DegreeKnowledge; use super::type_meta::TypeKnowledge; -use super::value_meta::ValueKnowledge; +use super::value_meta::ValueReduction; use super::variable_meta::VariableKnowledge; type Index = usize; @@ -18,7 +18,7 @@ pub struct Meta { pub file_id: Option, degree_knowledge: DegreeKnowledge, type_knowledge: TypeKnowledge, - value_knowledge: ValueKnowledge, + value_knowledge: ValueReduction, variable_knowledge: VariableKnowledge, } @@ -30,7 +30,7 @@ impl Meta { file_id: *file_id, degree_knowledge: DegreeKnowledge::default(), type_knowledge: TypeKnowledge::default(), - value_knowledge: ValueKnowledge::default(), + value_knowledge: ValueReduction::default(), variable_knowledge: VariableKnowledge::default(), } } @@ -66,7 +66,7 @@ impl Meta { } #[must_use] - pub fn value_knowledge(&self) -> &ValueKnowledge { + pub fn value_knowledge(&self) -> &ValueReduction { &self.value_knowledge } @@ -86,7 +86,7 @@ impl Meta { } #[must_use] - pub fn value_knowledge_mut(&mut self) -> &mut ValueKnowledge { + pub fn value_knowledge_mut(&mut self) -> &mut ValueReduction { &mut self.value_knowledge } @@ -205,11 +205,14 @@ pub enum Expression { Phi { meta: Meta, args: Vec }, } +pub type TagList = Vec; + #[derive(Clone, PartialEq, Eq, Hash)] pub enum VariableType { Local, Component, - Signal(SignalType), + AnonymousComponent, + Signal(SignalType, TagList), } impl fmt::Display for VariableType { @@ -218,12 +221,17 @@ impl fmt::Display for VariableType { use VariableType::*; match self { Local => write!(f, "var"), - Component => write!(f, "component"), - Signal(signal_type) => { + AnonymousComponent | Component => write!(f, "component"), + Signal(signal_type, tag_list) => { if matches!(signal_type, Intermediate) { - write!(f, "signal") + write!(f, "signal")?; + } else { + write!(f, "signal {signal_type}")?; + } + if !tag_list.is_empty() { + write!(f, " {{{}}}", tag_list.join(", ")) } else { - write!(f, "signal {signal_type}") + Ok(()) } } } @@ -270,7 +278,7 @@ pub struct VariableName { impl VariableName { /// Returns a new variable name with the given name (without suffix or version). #[must_use] - pub fn from_name(name: N) -> VariableName { + pub fn from_string(name: N) -> VariableName { VariableName { name: name.to_string(), suffix: None, version: None } } @@ -359,7 +367,7 @@ pub enum AssignOp { AssignLocalOrComponent, } -#[derive(Copy, Clone, Hash, Eq, PartialEq)] +#[derive(Copy, Clone, Hash, Eq, PartialEq, Debug)] pub enum ExpressionInfixOpcode { Mul, Div, diff --git a/program_structure/src/intermediate_representation/lifting.rs b/program_structure/src/intermediate_representation/lifting.rs index 96d27af..029fd58 100644 --- a/program_structure/src/intermediate_representation/lifting.rs +++ b/program_structure/src/intermediate_representation/lifting.rs @@ -45,9 +45,9 @@ impl From for Declarations { } // Attempt to convert an AST statement into an IR statement. This will fail on -// statements that need to be handled manually (`While` and `IfThenElse`), as -// well as statements that have no direct IR counterparts (like `Declaration`, -// `Block` and `InitializationBlock`). +// statements that need to be handled manually (like `While`, `IfThenElse`, and +// `MultiSubstitution`), as well as statements that have no direct IR +// counterparts (like `Declaration`, `Block`, and `InitializationBlock`). impl TryLift<()> for ast::Statement { type IR = ir::Statement; type Error = IRError; @@ -113,6 +113,7 @@ impl TryLift<()> for ast::Statement { ast::Statement::Block { .. } | ast::Statement::While { .. } | ast::Statement::IfThenElse { .. } + | ast::Statement::MultiSubstitution { .. } | ast::Statement::InitializationBlock { .. } => { // These need to be handled by the caller. panic!("failed to convert AST statement to IR") @@ -121,7 +122,9 @@ impl TryLift<()> for ast::Statement { } } -// Attempt to convert an AST expression to an IR expression. +// Attempt to convert an AST expression to an IR expression. This will fail on +// expressions that need to be handled directly by the caller (like `Tuple` and +// `AnonymousComponent`). impl TryLift<()> for ast::Expression { type IR = ir::Expression; type Error = IRError; @@ -185,6 +188,10 @@ impl TryLift<()> for ast::Expression { // TODO: We currently treat `ParallelOp` as transparent and simply // lift the underlying expression. Should this be added to the IR? ast::Expression::ParallelOp { rhe, .. } => rhe.try_lift((), reports), + ast::Expression::Tuple { .. } | ast::Expression::AnonymousComponent { .. } => { + // These need to be handled by the caller. + panic!("failed to convert AST expression to IR") + } } } } @@ -206,10 +213,11 @@ impl TryLift<()> for ast::VariableType { fn try_lift(&self, _: (), reports: &mut ReportCollection) -> IRResult { match self { - ast::VariableType::Component => Ok(ir::VariableType::Component), ast::VariableType::Var => Ok(ir::VariableType::Local), - ast::VariableType::Signal(signal_type, _) => { - Ok(ir::VariableType::Signal(signal_type.try_lift((), reports)?)) + ast::VariableType::Component => Ok(ir::VariableType::Component), + ast::VariableType::AnonymousComponent => Ok(ir::VariableType::AnonymousComponent), + ast::VariableType::Signal(signal_type, tag_list) => { + Ok(ir::VariableType::Signal(signal_type.try_lift((), reports)?, tag_list.clone())) } } } @@ -238,8 +246,8 @@ impl TryLift<&ast::Meta> for String { // We assume that the input string uses '.' to separate the name from the suffix. let tokens: Vec<_> = self.split('.').collect(); match tokens.len() { - 1 => Ok(ir::VariableName::from_name(tokens[0])), - 2 => Ok(ir::VariableName::from_name(tokens[0]).with_suffix(tokens[1])), + 1 => Ok(ir::VariableName::from_string(tokens[0])), + 2 => Ok(ir::VariableName::from_string(tokens[0]).with_suffix(tokens[1])), // Either the original name from the AST contains `.`, or the suffix // added when ensuring uniqueness contains `.`. Neither case should // occur, so we return an error here instead of producing a report. diff --git a/program_structure/src/intermediate_representation/statement_impl.rs b/program_structure/src/intermediate_representation/statement_impl.rs index 8a14ac6..28fd862 100644 --- a/program_structure/src/intermediate_representation/statement_impl.rs +++ b/program_structure/src/intermediate_representation/statement_impl.rs @@ -47,7 +47,7 @@ impl Statement { Declaration { names, var_type, .. } => { for name in names.iter() { // Since we disregard accesses, components are treated as signals. - if matches!(var_type, Signal(_) | Component) { + if matches!(var_type, Signal(_, _) | Component | AnonymousComponent) { result = result || env.set_degree(name, &Linear.into()); } env.set_type(name, var_type); @@ -95,15 +95,13 @@ impl Statement { } result } - Substitution { meta, var, rhe, .. } => { + Substitution { var, rhe, .. } => { let mut result = rhe.propagate_values(env); // TODO: Handle array values. if !matches!(rhe, Update { .. }) { - if let Some(value) = rhe.value() { - env.add_variable(var, value); - result = result || meta.value_knowledge_mut().set_reduces_to(value.clone()); - } + let value = rhe.value(); + result = result || env.add_variable(var, &value) } trace!("Substitution returned {result}"); result @@ -304,7 +302,7 @@ impl VariableMeta for Statement { trace!("adding `{var:?}` to local variables written"); locals_written.insert(VariableUse::new(meta, var, &access)); } - Some(VariableType::Signal(_)) => { + Some(VariableType::Signal(_, _)) => { trace!("adding `{var:?}` to signals written"); signals_written.insert(VariableUse::new(meta, var, &access)); if matches!(op, AssignOp::AssignConstraintSignal) { @@ -314,7 +312,7 @@ impl VariableMeta for Statement { signals_read.insert(VariableUse::new(meta, var, &access)); } } - Some(VariableType::Component) => { + Some(VariableType::Component | VariableType::AnonymousComponent) => { trace!("adding `{var:?}` to components written"); components_written.insert(VariableUse::new(meta, var, &access)); } diff --git a/program_structure/src/intermediate_representation/type_meta.rs b/program_structure/src/intermediate_representation/type_meta.rs index 803b1fd..65ca254 100644 --- a/program_structure/src/intermediate_representation/type_meta.rs +++ b/program_structure/src/intermediate_representation/type_meta.rs @@ -34,10 +34,13 @@ impl TypeKnowledge { TypeKnowledge::default() } + // Sets the variable type of a node representing a variable. pub fn set_variable_type(&mut self, var_type: &VariableType) { self.var_type = Some(var_type.clone()); } + /// For declared variables, this returns the type. For undeclared variables + /// and other expression nodes this returns `None`. #[must_use] pub fn variable_type(&self) -> Option<&VariableType> { self.var_type.as_ref() @@ -52,12 +55,12 @@ impl TypeKnowledge { /// Returns true if the node is a signal. #[must_use] pub fn is_signal(&self) -> bool { - matches!(self.var_type, Some(VariableType::Signal(_))) + matches!(self.var_type, Some(VariableType::Signal(_, _))) } - /// Returns true if the node is a component. + /// Returns true if the node is a (possibly anonymous) component. #[must_use] pub fn is_component(&self) -> bool { - matches!(self.var_type, Some(VariableType::Component)) + matches!(self.var_type, Some(VariableType::Component | VariableType::AnonymousComponent)) } } diff --git a/program_structure/src/intermediate_representation/value_meta.rs b/program_structure/src/intermediate_representation/value_meta.rs index 07f87ed..8386401 100644 --- a/program_structure/src/intermediate_representation/value_meta.rs +++ b/program_structure/src/intermediate_representation/value_meta.rs @@ -1,4 +1,6 @@ -use num_bigint::BigInt; +#![deny(warnings)] +use num_bigint::{BigInt, Sign}; +use num_traits::Zero; use std::collections::HashMap; use std::fmt; @@ -17,24 +19,21 @@ impl ValueEnvironment { ValueEnvironment { constants: constants.clone(), reduces_to: HashMap::new() } } - /// Set the value of the given variable. Returns `true` on first update. - /// - /// # Panics - /// - /// This function panics if the caller attempts to set two different values - /// for the same variable. + /// Set the value of the given variable. Returns `true` on updates. pub fn add_variable(&mut self, name: &VariableName, value: &ValueReduction) -> bool { - if let Some(previous) = self.reduces_to.insert(name.clone(), value.clone()) { - assert_eq!(previous, *value); - false - } else { + let prev_value = self.reduces_to.get(name).cloned().unwrap_or_default(); + let new_value = prev_value.intersect(value); + if new_value != prev_value { + self.reduces_to.insert(name.clone(), new_value); true + } else { + false } } #[must_use] - pub fn get_variable(&self, name: &VariableName) -> Option<&ValueReduction> { - self.reduces_to.get(name) + pub fn get_variable(&self, name: &VariableName) -> ValueReduction { + self.reduces_to.get(name).cloned().unwrap_or_default() } /// Returns the prime used. @@ -62,68 +61,122 @@ pub trait ValueMeta { /// Returns the value if the node reduces to a constant, and None otherwise. #[must_use] - fn value(&self) -> Option<&ValueReduction>; + fn value(&self) -> ValueReduction; } #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub enum ValueReduction { - Boolean { value: bool }, - FieldElement { value: BigInt }, + Unknown, + Boolean(Option), + FieldElement(Option), + Impossible, +} + +impl Default for ValueReduction { + fn default() -> Self { + Self::Unknown + } } impl fmt::Display for ValueReduction { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { use ValueReduction::*; match self { - Boolean { value } => write!(f, "{value}"), - FieldElement { value } => write!(f, "{value}"), + Boolean(Some(value)) => write!(f, "{value}"), + Boolean(None) => write!(f, ""), + FieldElement(Some(value)) => write!(f, "{value}"), + FieldElement(None) => write!(f, ""), + Unknown => write!(f, ""), + Impossible => write!(f, ""), } } } -#[derive(Default, Clone)] -pub struct ValueKnowledge { - reduces_to: Option, -} - -impl ValueKnowledge { +impl ValueReduction { #[must_use] - pub fn new() -> ValueKnowledge { - ValueKnowledge::default() + pub fn new() -> Self { + Self::default() } - /// Sets the value of the node. Returns `true` on the first update. - #[must_use] - pub fn set_reduces_to(&mut self, reduces_to: ValueReduction) -> bool { - let result = self.reduces_to.is_none(); - self.reduces_to = Some(reduces_to); - result + // if we know a variable is either `a` OR `b`, then our overall + // knowledge is `a.union(b)` + pub fn union(&self, b: &Self) -> Self { + use ValueReduction::*; + match (self, b) { + (Unknown, _) => Unknown, + (_, Unknown) => Unknown, + + (l, Impossible) => l.clone(), + (Impossible, r) => r.clone(), + + (Boolean(_), FieldElement(_)) => Unknown, + (FieldElement(_), Boolean(_)) => Unknown, + + (FieldElement(av), FieldElement(bv)) if av == bv => FieldElement(av.clone()), + (FieldElement(_), FieldElement(_)) => FieldElement(None), + + (Boolean(av), Boolean(bv)) if av == bv => Boolean(*av), + (Boolean(_), Boolean(_)) => Boolean(None), + } } - /// Gets the value of the node. Returns `None` if the value is unknown. + // if we know a variable is both `a` AND `b`, then our overall + // knowledge is `a.intersect(b)` + pub fn intersect(&self, b: &Self) -> Self { + use ValueReduction::*; + + let bool_felt_merge = |b: &Option, fe: &Option| match (b, fe) { + // TODO: does this make sense? Should `true` be treated as + // incompatible with `1`? It seems like it shouldn't be. + (Some(false), Some(n)) if n.is_zero() => Unknown, + (Some(true), Some(n)) if n == &BigInt::from_bytes_le(Sign::Plus, &[1]) => Unknown, + (Some(_), Some(_)) => Impossible, + _ => Unknown, + }; + + match (self, b) { + (Unknown, r) => r.clone(), + (l, Unknown) => l.clone(), + + (_, Impossible) => Impossible, + (Impossible, _) => Impossible, + + (Boolean(b), FieldElement(fe)) => bool_felt_merge(b, fe), + (FieldElement(fe), Boolean(b)) => bool_felt_merge(b, fe), + + (FieldElement(av), FieldElement(bv)) if av == bv => FieldElement(av.clone()), + (FieldElement(_), FieldElement(_)) => Impossible, + + (Boolean(av), Boolean(bv)) if av == bv => Boolean(*av), + (Boolean(_), Boolean(_)) => Impossible, + } + } + + /// Restricts the value of the node. Returns `true` on the first update. #[must_use] - pub fn get_reduces_to(&self) -> Option<&ValueReduction> { - self.reduces_to.as_ref() + pub fn set_reduces_to(&mut self, reduces_to: ValueReduction) -> bool { + let new_val = self.intersect(&reduces_to); + let result = self != &new_val; + *self = new_val; + result } /// Returns `true` if the value of the node is known. #[must_use] pub fn is_constant(&self) -> bool { - self.reduces_to.is_some() + matches!(self, Self::FieldElement(Some(_)) | Self::Boolean(Some(_))) } /// Returns `true` if the value of the node is a boolean. #[must_use] pub fn is_boolean(&self) -> bool { - use ValueReduction::*; - matches!(self.reduces_to, Some(Boolean { .. })) + matches!(self, ValueReduction::Boolean(_)) } /// Returns `true` if the value of the node is a field element. #[must_use] pub fn is_field_element(&self) -> bool { - use ValueReduction::*; - matches!(self.reduces_to, Some(FieldElement { .. })) + matches!(self, ValueReduction::FieldElement(_)) } } @@ -133,23 +186,20 @@ mod tests { use crate::ir::value_meta::ValueReduction; - use super::ValueKnowledge; - #[test] fn test_value_knowledge() { - let mut value = ValueKnowledge::new(); - assert!(matches!(value.get_reduces_to(), None)); + use ValueReduction::*; + let mut value = ValueReduction::new(); + assert!(matches!(value, Unknown)); - let number = ValueReduction::FieldElement { value: BigInt::from(1) }; + let number = ValueReduction::FieldElement(Some(BigInt::from(1))); assert!(value.set_reduces_to(number)); - assert!(matches!(value.get_reduces_to(), Some(ValueReduction::FieldElement { .. }))); + assert!(matches!(value, ValueReduction::FieldElement(Some(_)))); assert!(value.is_field_element()); assert!(!value.is_boolean()); - let boolean = ValueReduction::Boolean { value: true }; - assert!(!value.set_reduces_to(boolean)); - assert!(matches!(value.get_reduces_to(), Some(ValueReduction::Boolean { .. }))); - assert!(!value.is_field_element()); - assert!(value.is_boolean()); + let boolean = ValueReduction::Boolean(Some(true)); + assert!(value.set_reduces_to(boolean)); + assert!(matches!(value, ValueReduction::Unknown)); } } diff --git a/program_structure/src/program_library/program_archive.rs b/program_structure/src/program_library/program_archive.rs index a330836..1636399 100644 --- a/program_structure/src/program_library/program_archive.rs +++ b/program_structure/src/program_library/program_archive.rs @@ -126,7 +126,7 @@ impl ProgramArchive { pub fn get_public_inputs_main_component(&self) -> &Vec { &self.public_inputs } - pub fn get_main_expression(&self) -> &Expression { + pub fn main_expression(&self) -> &Expression { &self.initial_template_call } // FileLibrary functions diff --git a/program_structure/src/program_library/report_code.rs b/program_structure/src/program_library/report_code.rs index b1c4185..dd20b5b 100644 --- a/program_structure/src/program_library/report_code.rs +++ b/program_structure/src/program_library/report_code.rs @@ -62,6 +62,8 @@ pub enum ReportCode { NonQuadratic, NonConstantArrayLength, NonComputableExpression, + AnonymousComponentError, + TupleError, // Constraint analysis codes UnconstrainedSignal, OneConstraintIntermediate, @@ -85,6 +87,7 @@ pub enum ReportCode { UnconstrainedDivision, BN128SpecificCircuit, UnderConstrainedSignal, + UnusedOutputSignal, } impl ReportCode { @@ -151,6 +154,8 @@ impl ReportCode { NonConstantArrayLength => "T20463", NonComputableExpression => "T20464", WrongNumberOfArguments(..) => "T20465", + AnonymousComponentError => "TAC01", + TupleError => "TAC02", // Constraint analysis codes UnconstrainedSignal => "CA01", OneConstraintIntermediate => "CA02", @@ -174,6 +179,7 @@ impl ReportCode { UnconstrainedDivision => "CS0015", BN128SpecificCircuit => "CS0016", UnderConstrainedSignal => "CS0017", + UnusedOutputSignal => "CS0018", } .to_string() } @@ -186,6 +192,8 @@ impl ReportCode { CompilerVersionError => "compiler-version-error", WrongTypesInAssignOperation => "wrong-types-in-assign-operation", WrongNumberOfArguments(..) => "wrong-number-of-arguments", + AnonymousComponentError => "anonymous-component-error", + TupleError => "tuple-error", UndefinedFunction => "undefined-function", UndefinedTemplate => "undefined-template", UninitializedSymbolInExpression => "uninitialized-symbol-in-expression", @@ -262,6 +270,7 @@ impl ReportCode { UnconstrainedDivision => "unconstrained-division", BN128SpecificCircuit => "bn128-specific-circuit", UnderConstrainedSignal => "under-constrained-signal", + UnusedOutputSignal => "unused-output-signal", } .to_string() } @@ -282,7 +291,7 @@ impl ReportCode { TooManyArguments => Some("overly-complex-function-or-template"), UnnecessarySignalAssignment => Some("unnecessary-signal-assignment"), UnconstrainedLessThan => Some("unconstrained-less-than"), - UnconstrainedDivision => Some("unconstrained-devision"), + UnconstrainedDivision => Some("unconstrained-division"), BN128SpecificCircuit => Some("bn128-specific-circuit"), UnderConstrainedSignal => Some("under-constrained-signal"), // We only provide a URL for Circomspect specific issues. diff --git a/program_structure/src/program_library/template_data.rs b/program_structure/src/program_library/template_data.rs index 6b4dcfb..60dc364 100644 --- a/program_structure/src/program_library/template_data.rs +++ b/program_structure/src/program_library/template_data.rs @@ -1,11 +1,13 @@ use super::ast; -use super::ast::{FillMeta, SignalElementType, Statement}; +use super::ast::{FillMeta, Statement}; use super::file_definition::FileID; use crate::file_definition::FileLocation; -use std::collections::hash_map::HashMap; +use std::collections::{HashMap, HashSet, BTreeMap}; +pub type TagInfo = HashSet; pub type TemplateInfo = HashMap; -type SignalInfo = HashMap; +type SignalInfo = BTreeMap; +type SignalDeclarationOrder = Vec<(String, usize)>; #[derive(Clone)] pub struct TemplateData { @@ -19,6 +21,9 @@ pub struct TemplateData { output_signals: SignalInfo, is_parallel: bool, is_custom_gate: bool, + // Only used to know the order in which signals are declared. + input_declarations: SignalDeclarationOrder, + output_declarations: SignalDeclarationOrder, } impl TemplateData { @@ -37,7 +42,15 @@ impl TemplateData { body.fill(file_id, elem_id); let mut input_signals = SignalInfo::new(); let mut output_signals = SignalInfo::new(); - fill_inputs_and_outputs(&body, &mut input_signals, &mut output_signals); + let mut input_declarations = SignalDeclarationOrder::new(); + let mut output_declarations = SignalDeclarationOrder::new(); + fill_inputs_and_outputs( + &body, + &mut input_signals, + &mut output_signals, + &mut input_declarations, + &mut output_declarations, + ); TemplateData { name, file_id, @@ -49,42 +62,54 @@ impl TemplateData { output_signals, is_parallel, is_custom_gate, + input_declarations, + output_declarations, } } + pub fn get_file_id(&self) -> FileID { self.file_id } + pub fn get_body(&self) -> &Statement { &self.body } + pub fn get_body_as_vec(&self) -> &Vec { match &self.body { Statement::Block { stmts, .. } => stmts, _ => panic!("Function body should be a block"), } } + pub fn get_mut_body(&mut self) -> &mut Statement { &mut self.body } + pub fn get_mut_body_as_vec(&mut self) -> &mut Vec { match &mut self.body { Statement::Block { stmts, .. } => stmts, _ => panic!("Function body should be a block"), } } + pub fn get_num_of_params(&self) -> usize { self.num_of_params } + pub fn get_param_location(&self) -> FileLocation { self.param_location.clone() } + pub fn get_name_of_params(&self) -> &Vec { &self.name_of_params } - pub fn get_input_info(&self, name: &str) -> Option<&(usize, SignalElementType)> { + + pub fn get_input_info(&self, name: &str) -> Option<&(usize, TagInfo)> { self.input_signals.get(name) } - pub fn get_output_info(&self, name: &str) -> Option<&(usize, SignalElementType)> { + + pub fn get_output_info(&self, name: &str) -> Option<&(usize, TagInfo)> { self.output_signals.get(name) } pub fn get_inputs(&self) -> &SignalInfo { @@ -93,6 +118,12 @@ impl TemplateData { pub fn get_outputs(&self) -> &SignalInfo { &self.output_signals } + pub fn get_declaration_inputs(&self) -> &SignalDeclarationOrder { + &self.input_declarations + } + pub fn get_declaration_outputs(&self) -> &SignalDeclarationOrder { + &self.output_declarations + } pub fn get_name(&self) -> &str { &self.name } @@ -108,41 +139,80 @@ fn fill_inputs_and_outputs( template_statement: &Statement, input_signals: &mut SignalInfo, output_signals: &mut SignalInfo, + input_declarations: &mut SignalDeclarationOrder, + output_declarations: &mut SignalDeclarationOrder, ) { match template_statement { Statement::IfThenElse { if_case, else_case, .. } => { - fill_inputs_and_outputs(if_case, input_signals, output_signals); + fill_inputs_and_outputs( + if_case, + input_signals, + output_signals, + input_declarations, + output_declarations, + ); if let Option::Some(else_value) = else_case { - fill_inputs_and_outputs(else_value, input_signals, output_signals); + fill_inputs_and_outputs( + else_value, + input_signals, + output_signals, + input_declarations, + output_declarations, + ); } } Statement::Block { stmts, .. } => { for stmt in stmts.iter() { - fill_inputs_and_outputs(stmt, input_signals, output_signals); + fill_inputs_and_outputs( + stmt, + input_signals, + output_signals, + input_declarations, + output_declarations, + ); } } Statement::While { stmt, .. } => { - fill_inputs_and_outputs(stmt, input_signals, output_signals); + fill_inputs_and_outputs( + stmt, + input_signals, + output_signals, + input_declarations, + output_declarations, + ); } Statement::InitializationBlock { initializations, .. } => { for initialization in initializations.iter() { - fill_inputs_and_outputs(initialization, input_signals, output_signals); + fill_inputs_and_outputs( + initialization, + input_signals, + output_signals, + input_declarations, + output_declarations, + ); } } Statement::Declaration { - xtype: ast::VariableType::Signal(stype, tag), + xtype: ast::VariableType::Signal(stype, tag_list), name, dimensions, .. } => { let signal_name = name.clone(); - let dim = dimensions.len(); + let dimensions = dimensions.len(); + let mut tag_info = HashSet::new(); + for tag in tag_list { + tag_info.insert(tag.clone()); + } + match stype { ast::SignalType::Input => { - input_signals.insert(signal_name, (dim, *tag)); + input_signals.insert(signal_name.clone(), (dimensions, tag_info)); + input_declarations.push((signal_name, dimensions)); } ast::SignalType::Output => { - output_signals.insert(signal_name, (dim, *tag)); + output_signals.insert(signal_name.clone(), (dimensions, tag_info)); + output_declarations.push((signal_name, dimensions)); } _ => {} //no need to deal with intermediate signals } diff --git a/program_structure_tests/Cargo.toml b/program_structure_tests/Cargo.toml index ee79e1e..c8a9f81 100644 --- a/program_structure_tests/Cargo.toml +++ b/program_structure_tests/Cargo.toml @@ -1,12 +1,13 @@ [package] name = "circomspect-program-structure-tests" -version = "0.6.1" +version = "0.8.0" edition = "2021" +rust-version = "1.65" [dependencies] -parser = { package = "circomspect-parser", version = "2.0.10", path = "../parser" } -program_structure = { package = "circomspect-program-structure", version = "2.0.10", path = "../program_structure"} +parser = { package = "circomspect-parser", version = "2.1.2", path = "../parser" } +program_structure = { package = "circomspect-program-structure", version = "2.1.2", path = "../program_structure" } [dev-dependencies] -parser = { package = "circomspect-parser", version = "2.0.10", path = "../parser" } -program_structure = { package = "circomspect-program-structure", version = "2.0.10", path = "../program_structure"} +parser = { package = "circomspect-parser", version = "2.1.2", path = "../parser" } +program_structure = { package = "circomspect-program-structure", version = "2.1.2", path = "../program_structure" } diff --git a/program_structure_tests/src/control_flow_graph.rs b/program_structure_tests/src/control_flow_graph.rs index 54c9f9d..44afb45 100644 --- a/program_structure_tests/src/control_flow_graph.rs +++ b/program_structure_tests/src/control_flow_graph.rs @@ -528,8 +528,8 @@ fn lift(name: &str) -> VariableName { // We assume that the input string uses '.' to separate the name from the suffix. let tokens: Vec<_> = name.split('.').collect(); match tokens.len() { - 1 => VariableName::from_name(tokens[0]), - 2 => VariableName::from_name(tokens[0]).with_suffix(tokens[1]), + 1 => VariableName::from_string(tokens[0]), + 2 => VariableName::from_string(tokens[0]).with_suffix(tokens[1]), _ => panic!("invalid variable name"), } }