diff --git a/.gitignore b/.gitignore index e2a8c8ed..edcd046b 100644 --- a/.gitignore +++ b/.gitignore @@ -55,3 +55,12 @@ obj/table.c.o /regr_smlp/data/smlp_s2_tx.csv /regr_smlp/master/Test130_*.* /regr_smlp/master/Test131_*.* +/result +/data +.idea +src/logs.log +*.pb +*.h5 +*.onnx +variables/ +saved_model/ diff --git a/Extending_SMLP_with_solvers_for_neural_network_verification___Konstantinos_Konstantopoulos (3).pdf b/Extending_SMLP_with_solvers_for_neural_network_verification___Konstantinos_Konstantopoulos (3).pdf new file mode 100644 index 00000000..124c1e95 Binary files /dev/null and b/Extending_SMLP_with_solvers_for_neural_network_verification___Konstantinos_Konstantopoulos (3).pdf differ diff --git a/Marabou-Abstract_Solvers.md b/Marabou-Abstract_Solvers.md new file mode 100644 index 00000000..2807caae --- /dev/null +++ b/Marabou-Abstract_Solvers.md @@ -0,0 +1,36 @@ + +# INFO: File Structure + +- src/smlp_py/NN_verifiers: contains scripts to test the marabou models and verify the validity of the conversion of pb files into h5 files. +- src/smlp_py/marabou: helper files that contain examples of maraboupy commands. +- src/smlp_py/smtlib: + - parser.py & smt_to_pysmt.py contain helper functions. They are not used in SMLP's pipeline. + - text_to_sympy.py: Contains all the logic of converting, simplifying and reformatting expressions to a state that is easily translated into marabou expressions. +- src/smlp_py/solvers: This folder will contain all the logic of the external neural network verifiers that are going to be integrated into the SMLP pipeline. +- src/smlp_py/vnnlib: Contains the logic for for the need in the future to utilise the VNNLIB format (solver agnostic, several solvers support this format) to interact with the solvers. + + +# INFO: The abstract solver +abstract_solver.py defines an abstract solver class that is used to interface all the functionalities that all integrated solvers must support. This is because the main flow has been updated to reference the abstract solver functionalities, and thus all functions must be overridden by every new solver. + +Some methods are optionally overridden as their content usually cover most use cases. + +# HOW: Integrating multiple solvers +In the solvers/ folder, each solver must have it own subfolder. Currently, z3 and marabou are the only supported solvers. Each solver must have 2 files: +1) Operations.py : this file contains the functions that are utilised during the formula building and processing part of the workflow. For example, the method #smlp_and contains the conjunction logic utilised within the formula building phase, in marabou's case, the conjunction operations take place by using pysmt. Whereas is z3's case, the operators library is used to manage the formulas. +2) solver.py : This file contains the class that extends our abstract solver and operations class. Consequently, it will have to override all the functions mentioned in the abstract solver class in order to function properly. + + +# INFO: How the universal solver class is used in the main worklfow +The current PR contains changes in multiple files that contain core functionalities used in the main SMLP workflow. Re-usable parts of the flow or certain cases that are handled differently for each solver have been moved into the z3 solver and have been replaced with the universal solver class's functions. +The universal solver (universal_solver.py) acts as the intermediary and eventually point to the specified solver (marabou, z3), depending on the given "version" argument. Possible values are: + +``` +Solver.Version.PYSMT +Solver.Version.FORM2 + +``` + +# INFO: Pysmt processing +The file #text_to_sympy .py contains all the logic required to transform the formulas into marabou queries. +My dissertation can be used as a reference point to understand the underlying methodologies used inside each function. \ No newline at end of file diff --git a/src/smlp_py/NN_verifiers/compare_models.py b/src/smlp_py/NN_verifiers/compare_models.py new file mode 100755 index 00000000..b9187c82 --- /dev/null +++ b/src/smlp_py/NN_verifiers/compare_models.py @@ -0,0 +1,75 @@ +import tensorflow as tf +import numpy as np +import h5py + + +def read_h5_weights(h5_file_path): + with h5py.File(h5_file_path, 'r') as f: + weights = {} + for layer_name in f.keys(): + layer = f[layer_name] + for weight_name in layer.keys(): + weights[layer_name + '/' + weight_name] = np.array(layer[weight_name]) + return weights + + +def verify_pb_file(pb_file_path): + # Verify the protobuf file is valid + try: + with tf.io.gfile.GFile(pb_file_path, "rb") as f: + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(f.read()) + return True + except tf.errors.InvalidArgumentError as e: + print(f"Error verifying the PB file: {e}") + return False + + +def read_pb_weights(pb_file_path): + # Verify the file first + if not verify_pb_file(pb_file_path): + raise ValueError(f"The file at {pb_file_path} is not a valid TensorFlow protobuf file.") + + # Load the protobuf graph + graph_def = tf.compat.v1.GraphDef() + with tf.io.gfile.GFile(pb_file_path, "rb") as f: + graph_def.ParseFromString(f.read()) + + # Import the graph and get the weights + with tf.compat.v1.Graph().as_default() as graph: + tf.import_graph_def(graph_def, name="") + + weights = {} + with tf.compat.v1.Session(graph=graph) as sess: + for op in graph.get_operations(): + if op.type == "Const": + weights[op.name] = sess.run(op.outputs[0]) + return weights + + +def compare_weights(h5_weights, pb_weights): + for key in h5_weights: + if key in pb_weights: + if np.allclose(h5_weights[key], pb_weights[key]): + print(f"Weights for {key} match.") + else: + print(f"Weights for {key} do not match.") + else: + print(f"Weight {key} not found in PB model.") + + for key in pb_weights: + if key not in h5_weights: + print(f"Weight {key} not found in H5 model.") + + +def check_weights(h5_file_path, pb_file_path): + h5_weights = read_h5_weights(h5_file_path) + pb_weights = read_pb_weights(pb_file_path) + compare_weights(h5_weights, pb_weights) + + +# Example usage: +# check_weights('path_to_model.h5', 'path_to_model.pb') + +# Example usage: +check_weights("/home/ntinouldinho/Desktop/smlp/result/abc_smlp_toy_basic_nn_keras_model_complete.h5", "/home/ntinouldinho/Desktop/smlp/src/smlp_py/NN_verifiers/saved_model.pb") diff --git a/src/smlp_py/NN_verifiers/test_marabou.py b/src/smlp_py/NN_verifiers/test_marabou.py new file mode 100755 index 00000000..66902c95 --- /dev/null +++ b/src/smlp_py/NN_verifiers/test_marabou.py @@ -0,0 +1,190 @@ +from src.smlp_py.NN_verifiers.verifiers import MarabouVerifier +from src.smlp_py.smtlib.text_to_sympy import TextToPysmtParser + +from pysmt.shortcuts import Symbol, And, Not, Or, Implies, simplify, LT, Real, Times, Minus, Plus, Equals, GE, ToReal, LE +from pysmt.typing import * +import tf2onnx +import numpy as np +from pysmt.shortcuts import Symbol, Times, Minus, Div, Real +from pysmt.smtlib.parser import get_formula +# from pysmt.oracles import get_logic +from pysmt.typing import REAL +from z3 import simplify, parse_smt2_string +import z3 +from pysmt.smtlib.script import smtlibscript_from_formula +from io import StringIO + +from maraboupy.MarabouPythonic import * + + +if __name__ == "__main__": + import numpy as np + from tensorflow.keras.models import load_model + import os + + script_dir = os.path.dirname(os.path.abspath(__file__)) + relative_h5_path = os.path.join(script_dir, '../../../result/abc_smlp_toy_basic_nn_keras_model_complete.h5') + absolute_h5_path = os.path.normpath(relative_h5_path) + # Load the model from the .h5 file + model = load_model(absolute_h5_path) + + # Prepare your input data + input_data = np.array([[1.043789425, 0, 0.191919192, 0]]) # Example input data + + # Pass the inputs to the model and get the outputs + outputs = model.predict(input_data) + + # Print the outputs + print("Model outputs:", outputs) + + from keras.models import load_model + + # model = load_model("/home/kkon/Desktop/smlp/result/abc_smlp_toy_basic_nn_keras_model_complete.h5") + # model_proto, external_tensor_storage = tf2onnx.convert.from_keras(model, opset=13, output_path="smlp_toy.onnx") + print("SAVING TO ONNX") + parser = TextToPysmtParser() + parser.init_variables(symbols=[("x1", "real"), ('x2', 'real'), ('p1', 'real'), ('p2', 'real'), + ('y1', 'real'), ('y2', 'real')]) + + mb = MarabouVerifier(parser=parser) + mb.init_variables(inputs=[("x1", "Real"), ('x2', 'Integer'), ('p1', 'Real'), ('p2', 'Integer')], + outputs=[('y1', 'Real'), ('y2', 'Real')]) + mb.initialize() + + smlp_formula = '(let ((|:0| (* (/ 281474976710656 2944425288877159) (- y1 (/ 1080863910568919 4503599627370496))))) (let ((|:1| (* (/ 281474976710656 2559564553220679) (- (* (/ 1 2) (+ y1 y2)) (/ 1170935903116329 1125899906842624))))) (>= (ite (< |:0| |:1|) |:0| |:1|) 1)))' + smlp_str = f""" + (declare-fun y1 () Real) + (declare-fun y2 () Real) + (assert {smlp_formula}) + """ + + smlp_parsed = z3.parse_smt2_string(smlp_str) + smlp_simplified = z3.simplify(smlp_parsed[0]) + ex = parser.parse(str(smlp_simplified)) + # ex = parser.replace_constants_with_floats_and_evaluate(ex) + marabou_formula = parser.convert_ite_to_conjunctions_disjunctions(ex) + print(marabou_formula.serialize()) + + + + + + y1 = parser.get_symbol("y1") + y2 = parser.get_symbol("y2") + p1 = parser.get_symbol("p1") + p2 = parser.get_symbol("p2") + x1 = parser.get_symbol("x1") + x2 = parser.get_symbol("x2") + + x2_int = parser.create_integer_disjunction("x2_unscaled", (-1, 1)) + p2_int = parser.create_integer_disjunction("p2_unscaled", (3, 7)) + # alpha = (((-1 <= x2) & (0.0 <= x1) & (x2 <= 1) & (x1 <= 10.0)) & (((p2 < 5) & (x1 == 10.0)) & (x2 < 12))) + # beta = ((4 <= y1) & (6 <= y2)) + + + # with x as input: y1==6.847101329531717 & y2==10.31207527363552 + # with x as knob: y1==4.120704402283359 & + solution = And( + Equals(x1, Real(10)), + Equals(x2, Real(-1)), + Equals(p1, Real(2)), + Equals(p2, Real(3)) + ) + + print(smtlibscript_from_formula(solution)) + theta = And( + GE(p1, Real(6.8)), + GE(p2, Real(3.8)), + LE(p1, Real(7.2)), + LE(p2, Real(4.2)) + ) + alpha = And( + GE(x2, Real(-1)), + LE(x2, Real(1)), + GE(x1, Real(0.0)), + LE(x1, Real(10.0)), + And( + LT(p2, Real(5)), + Equals(x1, Real(10.0)), + LT(x2, Real(12)) + ) + ) + + beta = And( + GE(y1, Real(4)), + GE(y2, Real(8)), + ) + + not_beta = Or( + LT(y1, Real(4)), + LT(y2, Real(8)) + ) + eta = And( + GE(p1, Real(0.0)), + LE(p1, Real(10.0)), + GE(p2, Real(3)), + LE(p2, Real(7)), + Or( + p1.Equals(Real(2.0)), + p1.Equals(Real(4.0)), + p1.Equals(Real(7.0)) + ) + ) + script = smtlibscript_from_formula(eta) + + outstream = StringIO() + script.serialize(outstream) + output = outstream.getvalue() + smlp_parsed = z3.parse_smt2_string(output) + smlp_simplified = z3.simplify(smlp_parsed[0]) + mb.apply_restrictions(x2_int) + mb.apply_restrictions(p2_int) + # mb.apply_restrictions(beta) + # mb.apply_restrictions(alpha) + # mb.apply_restrictions(eta) + # mb.apply_restrictions(marabou_formula) + mb.apply_restrictions(solution) + + # mb.apply_restrictions(theta) + + witness= mb.solve() + print(witness) + +################## TEST PARSER ########################### +# if __name__ == "__main__": +# parser = TextToPysmtParser() +# parser.init_variables(inputs=[("x1", "real"), ('x2', 'int'), ('p1', 'real'), ('p2', 'int'), +# ('y1', 'real'), ('y2', 'real')]) +# +# mb = MarabouVerifier(parser=parser) +# mb.init_variables(inputs=[("x1", "Real"), ('x2', 'Integer'), ('p1', 'Integer'), ('p2', 'Integer')], +# outputs=[('y1', 'Real'), ('y2', 'Real')]) +# +# +# # (1<=(ite)) and (y<=4) and (y>=8) +# # ite_without_ite = Or(And(c, t), And(Not(c), f)) +# +# y1 = parser.get_symbol("y1") +# y2 = parser.get_symbol("y2") +# +# ex = parser.parse('(y1+y2)/2') +# +# c = y1 > y2 +# t = y1 +# f = y2 +# # ite_without_ite = Or(And(c, t), And(Not(c), f)) +# +# condition_true = Times(ToReal(c), y1) # y1 if y1 > y2 +# condition_false = Times(ToReal(Not(c)), y2) # y2 if y1 <= y2 +# +# # Combine them +# ite_without_ite = Plus(condition_true, condition_false) +# +# # Final expression (ite_without_ite >= 1) +# inequality = GE(ite_without_ite, Real(1)) +# +# # Combine with the inequality 1 <= ITE(c, t, f) +# # inequality = Real(1) <= ite_without_ite +# print(inequality) + + diff --git a/src/smlp_py/NN_verifiers/verifiers.py b/src/smlp_py/NN_verifiers/verifiers.py new file mode 100755 index 00000000..d1c91dbd --- /dev/null +++ b/src/smlp_py/NN_verifiers/verifiers.py @@ -0,0 +1,635 @@ +import os +from abc import ABC, abstractmethod +from enum import Enum +from typing import List, Dict, Optional, Tuple + +from maraboupy import Marabou +from maraboupy import MarabouCore +from maraboupy import MarabouUtils +import tensorflow as tf +from pysmt.shortcuts import Symbol, And, Not, Or, Implies, simplify, LT, Real, Times, Minus, Plus, Equals, Int, ToReal +from pysmt.typing import BOOL, REAL, INT +import numpy as np +from maraboupy.MarabouPythonic import * +from pysmt.walkers import IdentityDagWalker +from fractions import Fraction +import smlp + +from src.smlp_py.smtlib.smt_to_pysmt import smtlib_to_pysmt +from src.smlp_py.smtlib.text_to_sympy import TextToPysmtParser +from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 +import json + +_operators_ = [">=", "<=", "<", ">"] + +convert_comparison_operators = { + "=": MarabouCore.Equation.EQ, + "<=": MarabouCore.Equation.LE, + ">=": MarabouCore.Equation.GE + } + +class Verifier(ABC): + @abstractmethod + def add_disjunction(self): + pass + + +class Variable: + _input_index = 0 + _output_index = 0 + + class Type(Enum): + Real = 0 + Int = 1 + + class Bounds: + def __init__(self, lower=-np.inf, upper=np.inf): + self.lower = lower + self.upper = upper + + def __init__(self, form: Type, index=None, name="", is_input=True): + self.index = index + self.form = form + self.name = name + self.is_input = is_input + self.bounds = Variable.Bounds() + + @staticmethod + def get_index(direction="output"): + return Variable._input_index if direction == "input" else Variable._output_index + + def set_lower_bound(self, lower): + self.bounds.lower = lower + + def set_upper_bound(self, upper): + self.bounds.upper = upper + +class MarabouVerifier(Verifier): + def __init__(self, parser=None, data_bounds_file=None, model_file_prefix=None, variable_ranges=None, is_temp=False): + # MarabouNetwork containing network instance + self.network = None + + # Dictionary containing variables + self.bounds = {} + + # List of MarabouCommon.Equation currently applied to network query + self.equations = [] + + # List of variables + self.variables = [] + + self.variable_ranges = variable_ranges + + self.unscaled_variables = [] + + self.model_file_path = "./" + self.model_file_prefix = model_file_prefix + # self.data_bounds_file = self.find_file_path("../../../result/abc_smlp_toy_basic_data_bounds.json") + self.data_bounds_file = self.find_file_path('../../'+data_bounds_file) + self.data_bounds = None + # Adds conjunction of equations between bounds in form: + # e.g. Int(var), var >= 0, var <= 3 -> Or(var == 0, var == 1, var == 2, var == 3) + + self.input_index = 0 + self.output_index = 0 + + self.parser = parser + self.network_num_vars = None + self.init_variables(is_temp=is_temp) + + self.applied_equations = [] + + if self.variable_ranges: + self.initialize() + + + def initialize(self, variable_ranges=None): + if variable_ranges: + self.variable_ranges = variable_ranges + + self.model_file_path = self.find_file_path('../../'+ self.model_file_prefix +'_model_complete.h5') + self.convert_to_pb() + self.load_json() + self.network_num_vars = self.network.numVars + self.add_unscaled_variables() + self.create_integer_range() + + def reset(self): + self.network.clear() + self.network = Marabou.read_tf('model.pb') + self.unscaled_variables = [] + self.add_unscaled_variables() + self.applied_equations = [] + # Default bounds for network + for equation in self.equations: + self.apply_restrictions(equation) + + + def load_json(self): + with open(self.data_bounds_file, 'r') as file: + self.data_bounds = json.load(file) + + def epsilon(self, e, direction): + if direction == 'down': + return np.nextafter(e, -np.inf) + elif direction == 'up': + return np.nextafter(e, np.inf) + else: + raise ValueError("Direction must be 'up' or 'down'") + + def find_file_path(self, relative_path): + script_dir = os.path.dirname(os.path.abspath(__file__)) + relative_h5_path = os.path.join(script_dir, relative_path) + absolute_h5_path = os.path.normpath(relative_h5_path) + return absolute_h5_path + + def convert_to_pb(self, output_model_file_path="."): + model = tf.keras.models.load_model(self.model_file_path) + tf.saved_model.save(model, output_model_file_path) + # Load the SavedModel + model = tf.saved_model.load(output_model_file_path) + concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] + print("converted h5 to pb...") + + # Convert to ConcreteFunction + frozen_func = convert_variables_to_constants_v2(concrete_func) + graph_def = frozen_func.graph.as_graph_def() + + # Save the frozen graph + with tf.io.gfile.GFile('model.pb', 'wb') as f: + f.write(graph_def.SerializeToString()) + + self.network = Marabou.read_tf('model.pb') + + + + def init_variables(self, is_temp=False) -> None: + self.create_variables(is_input=True, is_temp=is_temp) + self.create_variables(is_input=False, is_temp=is_temp) + + def create_variables(self, is_input=True, is_temp=False): + store = self.parser.inputs if is_input else self.parser.outputs + for var in store: + name, type = var + var_type = Variable.Type.Real if type.lower() in ["real", "float"] else Variable.Type.Int + if name.startswith(('x', 'p', 'y')) and name.find("_scaled") == -1: + index = self.input_index if is_input else self.output_index + self.variables.append(Variable(var_type, name=name, index=index, is_input=is_input)) + + if is_input: + self.input_index += 1 + else: + self.output_index += 1 + + def create_integer_range(self): + integer_variables = [variable for variable in self.variables if variable.form == Variable.Type.Int] + for variable in integer_variables: + int_range = self.variable_ranges.get(variable.name) + if not int_range: + raise Exception(f"Need integer rangers for variable {variable.name}") + ranges = int_range['interval'] + lower, upper = ranges[0], ranges[-1] + variable.bounds = Variable.Bounds(lower=lower, upper=upper) + integer_formula = self.parser.create_integer_disjunction(f'{variable.name}_unscaled', (lower, upper)) + self.add_permanent_constraint(integer_formula) + + + def add_unscaled_variables(self): + for variable in self.variables: + unscaled_variable = self.network.getNewVariable() + self.unscaled_variables.append(Variable(Variable.Type.Real, index=unscaled_variable, name=f"{variable.name}_unscaled", is_input=True)) + + self.convert_scaled_unscaled() + + + def convert_scaled_unscaled(self): + for scaled_var, unscaled_var in zip(self.variables, self.unscaled_variables): + if scaled_var.name.find("_scaled") != -1: + continue + bounds = self.data_bounds[scaled_var.name] + min_value, max_value = bounds["min"], bounds["max"] + + scaling_factor = max_value - min_value + + _, scaled_var_index = self.get_variable_by_name(scaled_var.name) + _, unscaled_var_index = self.get_variable_by_name(unscaled_var.name) + + # Create an equation representing (x_max - x_min) * x_scaled - x_unscaled = - x_min + eq = MarabouUtils.Equation(MarabouCore.Equation.EQ) + eq.addAddend(scaling_factor, scaled_var_index) + eq.addAddend(-1, unscaled_var_index) + eq.setScalar(-min_value) + + # Add the equation to the network + # self.add_permanent_constraint(eq) + self.network.addEquation(eq) + + + def get_variable_by_name(self, name: str) -> Optional[Tuple[Variable, int]]: + is_output = name.startswith("y") + is_unscaled = name.find("_unscaled") != -1 + is_scaled = name.find("_scaled") != -1 + repository = self.unscaled_variables if is_unscaled else self.variables + + if is_scaled: + return None + + for index, variable in enumerate(repository): + if variable.name == name: + if is_unscaled: + return variable, variable.index + elif is_output: + index -= self.input_index + index = self.network.outputVars[0][0][index] if is_output else self.network.inputVars[0][0][index] + return variable, index + return None + + + def add_permanent_constraint(self, formula): + self.equations.append(formula) + self.apply_restrictions(formula) + + def add_bound(self, variable:str, value, direction="upper", strict=True): + var, var_index = self.get_variable_by_name(f"{variable}_unscaled") + if var is None: + return None + + epsilon_direction = "down" if direction == "upper" else "up" + value = self.epsilon(value, epsilon_direction) if strict else value + + if direction == "upper": + self.network.setUpperBound(var_index, value) + var.set_upper_bound(value) + elif direction == "lower": + self.network.setLowerBound(var_index, value) + var.set_lower_bound(value) + + def add_equality(self, variable, value): + var, var_index = self.get_variable_by_name(f"{variable}_unscaled") + + eq = MarabouUtils.Equation(MarabouCore.Equation.EQ) + eq.addAddend(1, var_index) + eq.setScalar(value) + self.network.addEquation(eq) + + + def apply_restrictions(self, formula, need_simplification=False): + self.applied_equations.append(formula) + formula = self.parser.simplify(formula) + conjunctions, disjunctions = self.process_formula(formula) + + for conjunction in conjunctions: + self.process_comparison(conjunction, need_simplification) + + self.process_disjunctions(disjunctions, need_simplification) + + def transform_pysmt_to_marabou_equation(self, formula): + symbols, comparator, scalar = formula + equation_type = None + + if comparator in convert_comparison_operators: + equation_type = convert_comparison_operators[comparator] + else: + if comparator == '<': + equation_type = MarabouCore.Equation.LE + scalar = self.epsilon(scalar, "down") + elif comparator == '>': + equation_type = MarabouCore.Equation.GE + scalar = self.epsilon(scalar, "up") + + equation = MarabouUtils.Equation(equation_type) + equation.setScalar(scalar) + + for parameter in symbols: + coefficient, symbol = parameter + # TODO: do not enforce the unscaled variables + name = str(symbol) + if name.find("_unscaled") == -1: + name += "_unscaled" + symbol, index = self.get_variable_by_name(name) + equation.addAddend(coefficient, index) + + return equation + + def is_negation_of_ite(self, formula): + if formula.is_and(): + if len(formula.args()) == 2: + # this is a custom logical block for handling the negation of objectives which yield a formula that looks like + # Or(A,B,C) where C = And(Or(D,E),Or(F,G)) (1) , which needs to be translated into: + # let K = Or(D,E), then C=And(K, Or(F,G)), which is equivalent to: Or(And(K,F),And(K,G)) (2). + # Then, using (2): And(K,F) = And(F, Or(D,E)), which is equivalent to: Or(And(F,D), And(F,E)) (3) + # Same applied to And(G, K) = Or(And(G,D), And(G,E)) (4) + # Finally: Or(And(K,F),And(K,G)) = Or(Or(And(F,D), And(F,E)), Or(And(G,D), And(G,E))), + # Which can be simplified to Or(And(F,D), And(F,E), And(G,D), And(G,E)) + left = formula.args()[0] + right = formula.args()[1] + if left.is_or() and len(left.args()) == 2 and right.is_or() and len(right.args()) == 2: + eq_1, eq_2 = left.args()[0], left.args()[1] + eq_3, eq_4 = right.args()[0], right.args()[1] + return True, [And(eq_1, eq_3), And(eq_1, eq_4), And(eq_2, eq_3), And(eq_2, eq_4)] + return False, [] + + def create_equation(self, formula, from_and=False, need_simplification=False): + equations = [] + formula = self.parser.simplify(formula) if not need_simplification else self.parser.z3_simplify(formula) + + if formula.is_and(): + equation = [self.create_equation(eq, from_and=True) for eq in formula.args()] + return equation + elif formula.is_le() or formula.is_lt() or formula.is_equals(): + res = self.parser.extract_components(formula, need_simplification) + equations.append(self.transform_pysmt_to_marabou_equation(res)) + elif formula.is_not(): + negation = self.parser.propagate_negation(formula) + res = self.parser.extract_components(negation, need_simplification) + equations.append(self.transform_pysmt_to_marabou_equation(res)) + + return equations[0] if from_and else equations + + def process_disjunctions(self, disjunctions, need_simplification=False): + marabou_disjunction = [] + for disjunction in disjunctions: + # split the disjunction into separate formulas + for formula in disjunction.args(): + res, formulas = self.is_negation_of_ite(formula) + formulas = formulas if res else [formula] + for formula in formulas: + # formula = self.parser.z3_simplify(formula) + equation = self.create_equation(formula, from_and=False, need_simplification=need_simplification) + if equation: + marabou_disjunction.append(equation) + + if len(marabou_disjunction) > 0: + self.network.addDisjunctionConstraint(marabou_disjunction) + + def process_formula(self, formula): + conjunctions = [] + disjunctions = [] + + def traverse(node, source=[]): + if node.is_and(): + # conjunctions.extend(node.args()) + for arg in node.args(): + traverse(arg, conjunctions) + elif node.is_or(): + disjunctions.append(node) + elif node.is_le() or node.is_lt() or node.is_equals(): + source.append(node) + else: + # Leaf nodes (symbols, literals, etc.) are not conjunctions or disjunctions + pass + + traverse(formula) + return conjunctions, disjunctions + + def process_comparison(self, formula, need_simplification=False): + formula = self.parser.z3_simplify(formula) + if formula.is_le() or formula.is_lt() or formula.is_equals(): + symbols, comparison, constant = self.parser.extract_components(formula, need_simplification) + + if len(symbols) > 1: + equation = self.transform_pysmt_to_marabou_equation((symbols, comparison, constant)) + self.network.addEquation(equation) + else: + _, symbol = symbols[0] + symbol = str(symbol) + + if comparison == "<=": + self.add_bound(symbol, constant, direction="upper", strict=False) + elif comparison == "<": + self.add_bound(symbol, constant, direction="upper", strict=True) + if comparison == ">=": + self.add_bound(symbol, constant, direction="lower", strict=False) + elif comparison == ">": + self.add_bound(symbol, constant, direction="lower", strict=True) + elif comparison == "=": + # TODO: add a marabou equation instead + self.add_equality(symbol, constant) + # self.add_bound(symbol, constant, direction="lower", strict=False) + # self.add_bound(symbol, constant, direction="upper", strict=False) + else: + return + + + def find_witness(self, witness): + answers = {"result":"SAT", "witness":{}, 'witness_var':{}} + for variable in self.unscaled_variables: + _, unscaled_index = self.get_variable_by_name(variable.name) + name = variable.name.replace("_unscaled", "") + scaled_var, _ = self.get_variable_by_name(name) + answers['witness_var'][scaled_var] = witness[unscaled_index] + answers['witness'][scaled_var.name] = witness[unscaled_index] + print(answers['witness']) + return answers + + def solve(self): + try: + # Options(verbose=0, cores=5) + results = self.network.solve() + if results and results[0] == 'unsat': + return "UNSAT", {"result":"UNSAT", "witness": {}} + else: # sat + return "SAT", self.find_witness(results[1]) + except Exception as e: + print(e) + return None + + def add_disjunction(self,): + return + + + + +if __name__ == "__main__": + parser = TextToPysmtParser() + # p2 is an int not a real + parser.init_variables(symbols=[("x1", "real"), ('x2', 'int'), ('p1', 'real'), ('p2', 'real'), + ('y1', 'real'), ('y2', 'real')]) + + mb = MarabouVerifier(parser=parser) + mb.init_variables(inputs=[("x1", "Real"),('x2', 'Integer'), ('p1', 'Integer'), ('p2', 'Integer')], outputs=[('y1', 'Real'), ('y2', 'Real')]) + + + def linearize(expr): + """ + Linearize the given expression, ensuring it is in a linear format. + """ + if expr.is_real_constant(): + return expr, 0 + elif expr.is_symbol(): + return expr, 0 + elif expr.is_plus(): + lhs, lhs_const = linearize(expr.arg(0)) + rhs, rhs_const = linearize(expr.arg(1)) + return Plus(lhs, rhs), lhs_const + rhs_const + elif expr.is_minus(): + lhs, lhs_const = linearize(expr.arg(0)) + rhs, rhs_const = linearize(expr.arg(1)) + return Minus(lhs, rhs), lhs_const - rhs_const + elif expr.is_times(): + const_part = 1 + var_part = None + for arg in expr.args(): + if arg.is_real_constant(): + const_part *= arg.constant_value() + else: + var_expr, var_const = linearize(arg) + if var_const != 0: + raise ValueError(f"Non-linear term detected: {expr}") + if var_part is None: + var_part = var_expr + else: + raise ValueError(f"Non-linear term detected: {expr}") + return Times(Real(const_part), var_part), 0 + else: + raise ValueError(f"Unsupported operation: {expr}") + + + def simplify_to_linear(formula): + """ + Simplify a given formula to a linear format if possible. + """ + if formula.is_lt() or formula.is_le(): + lhs, lhs_const = linearize(formula.arg(0)) + rhs, rhs_const = linearize(formula.arg(1)) + return LT(Plus(lhs, Real(lhs_const - rhs_const)), rhs) + elif formula.is_gt() or formula.is_ge(): + lhs, lhs_const = linearize(formula.arg(0)) + rhs, rhs_const = linearize(formula.arg(1)) + return LT(rhs, Plus(lhs, Real(lhs_const - rhs_const))) + elif formula.is_equals(): + lhs, lhs_const = linearize(formula.arg(0)) + rhs, rhs_const = linearize(formula.arg(1)) + return Equals(Plus(lhs, Real(lhs_const - rhs_const)), rhs) + else: + raise ValueError(f"Unsupported formula type: {formula}") + + y1 = parser.get_symbol("y1") + y2 = parser.get_symbol("y2") + p1 = parser.get_symbol("p1") + p2 = parser.get_symbol("p2") + + # formula = ( (-1 <= 5*x2) | ( (0.0 == x1) & (x2 > 1) ) ) + # Construct the left-hand side: 0.1 * (x1 - 0.2) + lhs = Times(Real(0.1), Minus(y1, Real(0.2))) + + # Construct the right-hand side: 0.3 * (0.4 * (x2 - x1) - 0.5) + inner_term = Minus(y2, y1) + scaled_inner_term = Times(Real(0.4), inner_term) + rhs_inner = Minus(scaled_inner_term, Real(0.5)) + rhs = Times(Real(0.3), rhs_inner) + + # Construct the inequality: lhs < rhs + inequality = LT(lhs, rhs) + # f = simplify_to_linear(inequality) + # formula = parser.parse("p1==4.0 or (p1==8.0 and p2 > 3)") + # formula = parser.parse("((3 <= p2) & (p2 <= 4) & (7656119366529843/1125899906842624 <= p1) & (p1 <= 8106479329266893/1125899906842624))") + # formula = "(let ((|:0| (- p1 7))) (let ((|:1| (- p2 4))) (and (and true (<= (ite (< |:0| 0) (- |:0|) |:0|) (/ 1 5))) (<= (ite (< |:1| 0) (- |:1|) |:1|) (/ 1 5)))))" + + # formula = ((3 <= p2) & (p2 <= 4) & (7656119366529843/1125899906842624 <= p1) & (p1 <= 8106479329266893/1125899906842624)) + formula = And(p1.Equals(Real(4)), Or(p1.Equals(Real(8)), And(LT(Real(3), p2), p1.Equals(Real(5))))) + var_types = { + 'y1': 'REAL', + 'y2': 'REAL', + 'p1': 'REAL', + 'p2': 'INT', + 'x1': 'REAL', + 'x2': 'INT' + } + # formula = smtlib_to_pysmt(formula, var_types) + # mb.apply_restrictions(formula) + + + # mb.add_bounds("x1", (0,10)) + # mb.add_bounds("x2", (-1, 1), num="int") + # mb.add_bounds("p1", (0, 10), num="grid", grid=[2, 4, 7]) + # mb.add_bounds("p2", (3, 7), num="int") + # mb.alpha() + # + # for var in mb.network.outputVars[0][0]: + # print(var) + # + exitCode1, vals1, stats1 = mb.solve() + print(exitCode1) + + +# TODO: CHECK IF MARABOU NATIVELY SUPPORTS INTEGERS: it does not + # def add_bounds(self, variable, bounds=None, num="real", grid=None): + # var, is_output = self.get_variable_by_name(variable) + # if var is None: + # return None + # + # # TODO: handle case when one of the two is None + # if bounds: + # lower, upper = bounds + # self.network.setLowerBound(var.index, lower) + # self.network.setUpperBound(var.index, upper) + # + # if num == "int": + # # add all distinct integer values + # grid = range(lower, upper+1) + # + # if num in ["int", "grid"] and grid is not None: + # disjunction = [] + # for i in grid: + # eq1 = MarabouUtils.Equation(MarabouCore.Equation.EQ) + # eq1.addAddend(1, var.index) + # eq1.setScalar(i) + # disjunction.append([eq1]) + # + # self.network.addDisjunctionConstraint(disjunction) + + +# def alpha(self): +# # (((-1 <= x2) & (0.0 <= x1) & (x2 <= 1) & (x1 <= 10.0)) & (((p2 < 5) & (x1 = 10.0)) & (x2 < 12))) +# # p2<5 and x1==10 and x2<12 +# # (p2≥5)∨(x1#10)∨(x2≥12) +# +# p1, is_output = self.get_variable_by_name("p1") +# p2, is_output = self.get_variable_by_name("p2") +# x1, is_output = self.get_variable_by_name("x1") +# x2, is_output = self.get_variable_by_name("x2") +# y1, is_output = self.get_variable_by_name("y1") +# y2, is_output = self.get_variable_by_name("y2") +# +# # +# # self.network.setUpperBound(p2.index, 5-epsilon) +# v = Var(p2.index) +# +# # self.network.addConstraint(v <= self.epsilon(5, "down")) +# # +# # self.network.setUpperBound(x1.index, self.epsilon(10,'up')) +# # self.network.setLowerBound(x1.index, self.epsilon(10, "down")) +# # +# # self.network.setUpperBound(x2.index, self.epsilon(12, "down")) +# # +# # self.network.setLowerBound(y1.index, 4) +# # self.network.setUpperBound(y2.index, 8) +# +# # p1==4.0 or (p1==8.0 and p2 > 3) +# eq1 = MarabouUtils.Equation(MarabouCore.Equation.EQ) +# eq1.addAddend(1, p1.index) +# eq1.setScalar(4) +# +# eq2 = MarabouUtils.Equation(MarabouCore.Equation.EQ) +# eq2.addAddend(1, p1.index) +# eq2.setScalar(8) +# +# eq3 = MarabouUtils.Equation(MarabouCore.Equation.GE) +# eq3.addAddend(1, p2.index) +# eq3.setScalar(self.epsilon(3, "up")) +# +# self.network.addDisjunctionConstraint([[eq1], [eq2, eq3]]) +# +# # b1 = self.network.getNewVariable() +# # +# # # Define the epsilon value +# # epsilon = 1e-5 +# # +# # # Constraint for (y1 + y2) / 2 > 1 when b1 = 1 +# # # This is equivalent to y1 + y2 > 2 +# # self.network.addInequality([y1, y2, b1], [1, 1, -2], -epsilon) # y1 + y2 - 2*b1 > 0 -> y1 + y2 > 2 when b1 = 1 +# # +# # # Ensure b1 is binary +# # self.network.setLowerBound(b1, 0) +# # self.network.setUpperBound(b1, 1) \ No newline at end of file diff --git a/src/smlp_py/marabou/marabou.py b/src/smlp_py/marabou/marabou.py new file mode 100755 index 00000000..695ad657 --- /dev/null +++ b/src/smlp_py/marabou/marabou.py @@ -0,0 +1,141 @@ +import sys +sys.path.append('/home/Desktop/Marabou/maraboupy') +from maraboupy import Marabou, MarabouCore +from maraboupy.MarabouCore import * +import numpy as np + +import os + + +class ONNXNetwork: + def __init__(self): + filename = "../../test.onnx" + # filename = "test.onnx" + self.network = Marabou.read_onnx(filename) + + def beta(self): + # self.network.setLowerBound(4, 4) + # self.network.setUpperBound(4, 10) + # + # self.network.setLowerBound(5, 8) + # self.network.setUpperBound(5, 20) + + # BEST SOLUTION + self.network.setLowerBound(4, 0.24) + self.network.setUpperBound(4, 10.7007) + + self.network.setLowerBound(5, 1.12) + self.network.setUpperBound(5, 12.02) + + def alpha(self): + # p2<5 and x1==10 and x2<12 + # (p2≥5)∨(x1#10)∨(x2≥12) + + epsilon = 1e-12 + + eq1 = MarabouCore.Equation(MarabouCore.Equation.GE) + eq1.addAddend(1, 3) + eq1.setScalar(5) + + eq2 = MarabouCore.Equation(MarabouCore.Equation.GE) + eq2.addAddend(1, 1) + eq2.setScalar(12) + + eq3 = MarabouCore.Equation(MarabouCore.Equation.GE) + eq3.addAddend(1, 0) + eq3.setScalar(10+epsilon) + + eq4 = MarabouCore.Equation(MarabouCore.Equation.LE) + eq4.addAddend(1, 0) + eq4.setScalar(10 - 1e-12) + + self.network.addDisjunctionConstraint([[eq1], [eq2], [eq3], [eq4]]) + + def add_bounds(self, var, bounds, num="real", grid=None): + lower, upper = bounds + self.network.setLowerBound(var, lower) + self.network.setUpperBound(var, upper) + + if num == "in": + disjunction = [] + + for i in range(lower, upper+1): + eq1 = MarabouCore.Equation(MarabouCore.Equation.EQ) + eq1.addAddend(1, var) + eq1.setScalar(i) + disjunction.append([eq1]) + + self.network.addDisjunctionConstraint(disjunction) + + if grid is not None: + disjunction = [] + + for num in grid: + eq1 = MarabouCore.Equation(MarabouCore.Equation.EQ) + eq1.addAddend(1, var) + eq1.setScalar(num) + disjunction.append([eq1]) + + self.network.addDisjunctionConstraint(disjunction) + + + def run_marabou(self): + options = Marabou.createOptions(verbosity = 10) + + grid = [2, 4, 7] + for var in self.network.inputVars[0][0]: + # if var == 0: + # self.add_bounds(var, (0, 10)) + # elif var == 1: + # self.add_bounds(var, (-1, 1), num="int") + # elif var == 2: + # self.add_bounds(var, (0, 10), num="int", grid=grid) + # elif var == 3: + # self.add_bounds(var, (3, 7), num="int") + + # BEST SOLUTION + if var == 0: + self.add_bounds(var, (-0.8218, 9.546)) + elif var == 1: + self.add_bounds(var, (-1, 1), num="int") + elif var == 2: + self.add_bounds(var, (0.1, 10), num="int", grid=grid) + elif var == 3: + self.add_bounds(var, (3, 7), num="int") + + # self.alpha() + self.beta() + + exitCode, vals, stats = self.network.solve(options = options) + + # Test Marabou equations against onnxruntime at an example input point + # inputPoint = np.ones(inputVars.shape) + # marabouEval = network.evaluateWithMarabou([inputPoint], options = options)[0] + # onnxEval = network.evaluateWithoutMarabou([inputPoint])[0] + print(exitCode, vals, stats) +# ONNXNetwork().run_marabou() + +# if __name__ == "__main__": +# onnx_file = "/home/ntinouldinho/Desktop/Marabou/data/test.onnx" +# # property_filename = "/home/ntinouldinho/Desktop/Marabou/data/model_constraints.vnnlib" +# # onnx_file = "/home/ntinouldinho/Desktop/smlp/src/test.onnx" +# property_filename = "/home/ntinouldinho/Desktop/smlp/src/query.vnnlib" +# +# network = Marabou.read_onnx(onnx_file) +# network.saveQuery("./query.txt") +# +# try: +# ipq = Marabou.load_query("./query.txt") +# # MarabouCore.loadProperty(ipq, property_filename) +# exitCode_ipq, vals_ipq, _ = Marabou.solve_query(ipq, propertyFilename=property_filename, filename="res.log") +# print(exitCode_ipq, vals_ipq) +# +# except Exception as e: +# print(e) + +if __name__ == "__main__": + network = Marabou.read_tf("/home/ntinouldinho/Desktop/smlp/result/abc_smlp_toy_basic_model_checkpoint.h5") + + + + diff --git a/src/smlp_py/marabou/query.txt b/src/smlp_py/marabou/query.txt new file mode 100755 index 00000000..d7d7c6f5 --- /dev/null +++ b/src/smlp_py/marabou/query.txt @@ -0,0 +1,51 @@ +30 +12 +0 +14 +12 +4 +0,0 +1,1 +2,2 +3,3 +2 +0,28 +1,29 +12,0.0000000000 +13,0.0000000000 +14,0.0000000000 +15,0.0000000000 +16,0.0000000000 +17,0.0000000000 +18,0.0000000000 +19,0.0000000000 +24,0.0000000000 +25,0.0000000000 +26,0.0000000000 +27,0.0000000000 +0,0,0.0302654002,0,-0.1437274814,1,0.1061730981,2,-0.0870333686,3,-0.4645415545,4,-1.0000000000 +1,0,-0.0128128808,0,-0.5780213475,1,0.5900486112,2,0.5305879116,3,0.3497457504,5,-1.0000000000 +2,0,-0.0007434468,0,-0.6913169622,1,-0.1507877558,2,0.6713727117,3,-0.1944409609,6,-1.0000000000 +3,0,0.0132133318,0,-0.0250194836,1,0.1154540256,2,-0.4708667994,3,-0.4788347185,7,-1.0000000000 +4,0,0.0000000000,0,-0.5775315166,1,-0.1357627511,2,-0.3575040698,3,-0.5444254875,8,-1.0000000000 +5,0,-0.0274824426,0,0.5064330697,1,-0.6230512857,2,-0.4707299173,3,-0.3514576852,9,-1.0000000000 +6,0,-0.0666948929,0,0.2382464856,1,-0.3667045534,2,0.1506942511,3,0.5816352963,10,-1.0000000000 +7,0,-0.0941732377,0,0.5513240099,1,-0.2898558974,2,-0.6173257232,3,0.7682082653,11,-1.0000000000 +8,0,-0.1418720335,12,-0.4838356972,13,-0.2879619300,14,0.7321704030,15,-0.1759769320,16,-0.6894207597,17,0.3923604190,18,0.1298776269,19,0.0196413193,20,-1.0000000000 +9,0,0.0638892725,12,0.6492301226,13,0.2203886509,14,-0.2999479175,15,0.5637680888,16,-0.0014058352,17,0.3970938921,18,-0.6249498725,19,0.3417502046,21,-1.0000000000 +10,0,0.0540842414,12,-0.5587263703,13,0.2551203668,14,0.5727752447,15,0.4214664698,16,-0.2815549374,17,-0.5591005087,18,-0.0682931393,19,-0.7146613598,22,-1.0000000000 +11,0,-0.0614044182,12,-0.4547419548,13,0.5524955988,14,0.1784126908,15,-0.2580326498,16,-0.0371657610,17,0.5687257051,18,0.2459676862,19,0.7129698992,23,-1.0000000000 +12,0,0.0996098146,24,0.0004564780,25,0.8553681374,26,-0.3106124997,27,0.8385750651,28,-1.0000000000 +13,0,-0.1071689427,24,0.4485777915,25,0.1472451091,26,-0.8218147159,27,0.9324899912,29,-1.0000000000 +0,relu,12,4 +1,relu,13,5 +2,relu,14,6 +3,relu,15,7 +4,relu,16,8 +5,relu,17,9 +6,relu,18,10 +7,relu,19,11 +8,relu,24,20 +9,relu,25,21 +10,relu,26,22 +11,relu,27,23 \ No newline at end of file diff --git a/src/smlp_py/marabou/res.log b/src/smlp_py/marabou/res.log new file mode 100755 index 00000000..99c4c78f --- /dev/null +++ b/src/smlp_py/marabou/res.log @@ -0,0 +1,167 @@ +Engine::processInputQuery: Input query (before preprocessing): 14 equations, 30 variables +Engine::processInputQuery: Input query (after preprocessing): 26 equations, 35 variables + +Input bounds: + x0: [ 10.0000, 10.0000] [FIXED] + x1: [ -1.0000, 12.0000] + x2: [ 4.0000, 4.0000] [FIXED] + x3: [ 3.0000, 7.0000] + +Branching heuristics set to LargestInterval + +Engine::solve: Initial statistics + +10:56:59 Statistics update: + --- Time Statistics --- + Total time elapsed: 2 milli (00:00:00) + Main loop: 0 milli (00:00:00) + Preprocessing time: 2 milli (00:00:00) + Unknown: 0 milli (00:00:00) + Breakdown for main loop: + [0.00%] Simplex steps: 0 milli + [0.00%] Explicit-basis bound tightening: 0 milli + [0.00%] Constraint-matrix bound tightening: 0 milli + [0.00%] Degradation checking: 0 milli + [0.00%] Precision restoration: 0 milli + [0.00%] Statistics handling: 0 milli + [0.00%] Constraint-fixing steps: 0 milli + [0.00%] Valid case splits: 0 milli. Average per split: 0.00 milli + [0.00%] Applying stored bound-tightening: 0 milli + [0.00%] SMT core: 0 milli + [0.00%] Symbolic Bound Tightening: 0 milli + [0.00%] SoI-based local search: 0 milli + [0.00%] SoI-based local search: 0 milli + [0.00%] Unaccounted for: 0 milli + --- Preprocessor Statistics --- + Number of preprocessor bound-tightening loop iterations: 5 + Number of eliminated variables: 9 + Number of constraints removed due to variable elimination: 7 + Number of equations removed due to variable elimination: 0 + --- Engine Statistics --- + Number of main loop iterations: 1 + 0 iterations were simplex steps. Total time: 0 milli. Average: 0.00 milli. + 0 iterations were constraint-fixing steps. Total time: 0 milli. Average: 0.00 milli + Number of active piecewise-linear constraints: 4 / 5 + Constraints disabled by valid splits: 1. By SMT-originated splits: 0 + Last reported degradation: 0.0000000000. Max degradation so far: 0.0000000000. Restorations so far: 0 + Number of simplex pivots we attempted to skip because of instability: 0. + Unstable pivots performed anyway: 0 + --- Tableau Statistics --- + Total number of pivots performed: 0 + Real pivots: 0. Degenerate: 0 (0.00%) + Degenerate pivots by request (e.g., to fix a PL constraint): 0 (0.00%) + Average time per pivot: 0.00 milli + Total number of fake pivots performed: 0 + Total number of rows added: 0. Number of merged columns: 0 + Current tableau dimensions: M = 26, N = 61 + --- SMT Core Statistics --- + Total depth is 0. Total visited states: 1. Number of splits: 0. Number of pops: 0 + Max stack depth: 0 + --- Bound Tightening Statistics --- + Number of tightened bounds: 0. + Number of rows examined by row tightener: 0. Consequent tightenings: 0 + Number of explicit basis matrices examined by row tightener: 0. Consequent tightenings: 0 + Number of bound tightening rounds on the entire constraint matrix: 0. Consequent tightenings: 0 + Number of bound notifications sent to PL constraints: 30. Tightenings proposed: 0 + --- Basis Factorization statistics --- + Number of basis refactorizations: 2 + --- Projected Steepest Edge Statistics --- + Number of iterations: 0. + Number of resets to reference space: 1. Avg. iterations per reset: 0 + --- SBT --- + Number of tightened bounds: 11 + --- SoI-based local search --- + Number of proposed phase pattern update: 0. Number of accepted update: 0 [0.00%] + Total time (% of local search time) updating SoI phase pattern : 0 milli [0.00%] + Total time obtaining current assignment: 0 milli [0.00%] + Total time getting SoI phase pattern : 0 milli [0.00%] + --- Context dependent statistics --- + Number of pushes / pops: 0 / 0 + [0.00%] Pre-Push hook: 0 milli + [0.00%] Push : 0 milli + [0.00%] Post-Pop hook: 0 milli + [0.00%] Pop : 0 milli + [0.00%] Total context-switching time: 0 milli + --- Proof Certificate --- + Number of certified leaves: 0 + Number of leaves to delegate: 0 + +--- +Before declaring sat, recomputing... + +Engine::solve: sat assignment found + +10:56:59 Statistics update: + --- Time Statistics --- + Total time elapsed: 2 milli (00:00:00) + Main loop: 0 milli (00:00:00) + Preprocessing time: 2 milli (00:00:00) + Unknown: 0 milli (00:00:00) + Breakdown for main loop: + [21.85%] Simplex steps: 0 milli + [9.66%] Explicit-basis bound tightening: 0 milli + [0.00%] Constraint-matrix bound tightening: 0 milli + [0.00%] Degradation checking: 0 milli + [0.00%] Precision restoration: 0 milli + [2.10%] Statistics handling: 0 milli + [0.00%] Constraint-fixing steps: 0 milli + [4.62%] Valid case splits: 0 milli. Average per split: 0.00 milli + [0.84%] Applying stored bound-tightening: 0 milli + [0.00%] SMT core: 0 milli + [387.39%] Symbolic Bound Tightening: 0 milli + [0.00%] SoI-based local search: 0 milli + [0.00%] SoI-based local search: 0 milli + [7750732804079643648.00%] Unaccounted for: 0 milli + --- Preprocessor Statistics --- + Number of preprocessor bound-tightening loop iterations: 5 + Number of eliminated variables: 9 + Number of constraints removed due to variable elimination: 7 + Number of equations removed due to variable elimination: 0 + --- Engine Statistics --- + Number of main loop iterations: 8 + 5 iterations were simplex steps. Total time: 0 milli. Average: 0.00 milli. + 0 iterations were constraint-fixing steps. Total time: 0 milli. Average: 0.00 milli + Number of active piecewise-linear constraints: 4 / 5 + Constraints disabled by valid splits: 1. By SMT-originated splits: 0 + Last reported degradation: 0.0000000000. Max degradation so far: 0.0000000000. Restorations so far: 0 + Number of simplex pivots we attempted to skip because of instability: 0. + Unstable pivots performed anyway: 0 + --- Tableau Statistics --- + Total number of pivots performed: 5 + Real pivots: 5. Degenerate: 0 (0.00%) + Degenerate pivots by request (e.g., to fix a PL constraint): 0 (0.00%) + Average time per pivot: 0.00 milli + Total number of fake pivots performed: 0 + Total number of rows added: 0. Number of merged columns: 0 + Current tableau dimensions: M = 26, N = 61 + --- SMT Core Statistics --- + Total depth is 0. Total visited states: 1. Number of splits: 0. Number of pops: 0 + Max stack depth: 0 + --- Bound Tightening Statistics --- + Number of tightened bounds: 2. + Number of rows examined by row tightener: 5. Consequent tightenings: 1 + Number of explicit basis matrices examined by row tightener: 1. Consequent tightenings: 2 + Number of bound tightening rounds on the entire constraint matrix: 0. Consequent tightenings: 0 + Number of bound notifications sent to PL constraints: 60. Tightenings proposed: 0 + --- Basis Factorization statistics --- + Number of basis refactorizations: 2 + --- Projected Steepest Edge Statistics --- + Number of iterations: 5. + Number of resets to reference space: 1. Avg. iterations per reset: 5 + --- SBT --- + Number of tightened bounds: 13 + --- SoI-based local search --- + Number of proposed phase pattern update: 0. Number of accepted update: 0 [0.00%] + Total time (% of local search time) updating SoI phase pattern : 0 milli [0.00%] + Total time obtaining current assignment: 0 milli [0.00%] + Total time getting SoI phase pattern : 0 milli [0.00%] + --- Context dependent statistics --- + Number of pushes / pops: 0 / 0 + [0.00%] Pre-Push hook: 0 milli + [0.00%] Push : 0 milli + [0.00%] Post-Pop hook: 0 milli + [0.00%] Pop : 0 milli + [0.00%] Total context-switching time: 0 milli + --- Proof Certificate --- + Number of certified leaves: 0 + Number of leaves to delegate: 0 diff --git a/src/smlp_py/smlp_flows.py b/src/smlp_py/smlp_flows.py index da965373..53765810 100644 --- a/src/smlp_py/smlp_flows.py +++ b/src/smlp_py/smlp_flows.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # This file is part of smlp. +import time # imports from SMLP modules from smlp_py.smlp_logs import SmlpLogger, SmlpTracer @@ -19,6 +20,9 @@ from smlp_py.smlp_optimize import SmlpOptimize from smlp_py.smlp_refine import SmlpRefine +from src.smlp_py.solvers.universal_solver import Solver + + # Combining simulation results, optimization, uncertainty analysis, sequential experiments # https://foqus.readthedocs.io/en/3.1.0/chapt_intro/index.html @@ -128,6 +132,8 @@ def __init__(self, argv): self.optInst.set_tracer(self.tracer, self.args.trace_runtime, self.args.trace_precision, self.args.trace_anonymize) self.queryInst.set_lemma_precision(self.args.lemma_precision) + + self.use_pysmt = self.args.use_pysmt # TODO !!!: is this the right place to define data_fname and new_data_fname and error_file ??? @@ -317,7 +323,12 @@ def smlp_flow(self): # sanity check that the order of features in model_features_dict, feat_names, X_train, X_test, X is # the same; this is mostly important for model exploration modes self.modelInst.model_features_sanity_check(model_features_dict, feat_names, X_train, X_test, X) - + + Solver(specs=(feat_names, resp_names, self.modelTernaInst._specInst.get_spec_domain_dict), + data_bounds_file= self.dataInst.data_bounds_file, + model_file_prefix= self.dataInst.model_file_prefix, + version=Solver.Version.PYSMT if args.use_pysmt else Solver.Version.FORM2) + if args.analytics_mode == 'verify': if True or len(self.specInst.get_spec_knobs)> 0: if config_dict is None: @@ -370,6 +381,8 @@ def smlp_flow(self): args.approximate_fractions, args.fraction_precision, self.dataInst.data_bounds_file, bounds_factor=None, T_resp_bounds_csv_path=None) elif args.analytics_mode == 'optimize': + start = time.time() + use_pysmt = args.use_pysmt self.optInst.smlp_optimize(syst_expr_dict, args.model, model, self.dataInst.unscaled_training_features, self.dataInst.unscaled_training_responses, model_features_dict, feat_names, resp_names, objv_names, objv_exprs, args.optimize_pareto, @@ -378,8 +391,9 @@ def smlp_flow(self): args.solver_logic, args.vacuity_check, args.data_scaler, args.scale_features, args.scale_responses, args.scale_objectives, args.approximate_fractions, args.fraction_precision, - self.dataInst.data_bounds_file, bounds_factor=None, T_resp_bounds_csv_path=None) - + self.dataInst.data_bounds_file, bounds_factor=None, T_resp_bounds_csv_path=None, use_pysmt=use_pysmt) + end = time.time() + print(f"TOTAL TIME IS {end-start}") #self.logger.info('self.optInst.best_config_dict {}'.format(str(self.optInst.best_config_dict))) if syst_expr_dict is not None: if 'final' in self.optInst.best_config_dict: diff --git a/src/smlp_py/smlp_optimize.py b/src/smlp_py/smlp_optimize.py old mode 100644 new mode 100755 index 907846ec..c9844b57 --- a/src/smlp_py/smlp_optimize.py +++ b/src/smlp_py/smlp_optimize.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # This file is part of smlp. - +import pysmt import smlp +from oauthlib.uri_validate import query from smlp_py.smlp_terms import SmlpTerms, ModelTerms, ScalerTerms from smlp_py.smlp_query import SmlpQuery from smlp_py.smlp_utils import (str_to_bool, np_JSONEncoder) @@ -13,6 +14,9 @@ import pandas as pd #import keras import numpy as np +from src.smlp_py.smtlib.text_to_sympy import TextToPysmtParser +from pysmt.shortcuts import Real +from src.smlp_py.solvers.universal_solver import Solver # single or multi-objective optimization, with stability constraints and any user # given constraints on free input, control (knob) and output variables satisfied. @@ -54,6 +58,7 @@ def __init__(self): self._DEF_APPROXIMATE_FRACTIONS:bool = True self._DEF_FRACTION_PRECISION:int = 64 self._DEF_OPTIMIZATION_STRATEGY:str = 'lazy' # TODO !!! define enum type for lazy/eager and API functions to get the strategy value, so neither strings no enum types will be ued outside this file + self._ENABLE_PYSMT = False # Formulae alpha, beta, eta are used in single and pareto optimization tasks. # They are used to constrain control variables x and response variables y as follows: @@ -448,7 +453,7 @@ def optimize_single_objective_eager(self, model_full_term_dict:dict, objv_name:s def optimize_single_objective(self, model_full_term_dict:dict, objv_name:str, objv_expr:str, objv_term:smlp.term2, epsilon:float, smlp_domain:smlp.domain, eta:smlp.form2, theta_radii_dict:dict, alpha:smlp.form2, beta:smlp.form2, delta:float, solver_logic:str, scale_objectives:bool, orig_objv_name:str, objv_bounds:dict, call_info=None, sat_approx=False, sat_precision=64, save_trace=False, - l0=None, u0=None, l=(-np.inf), u=np.inf): + l0=None, u0=None, l=(-np.inf), u=np.inf, pysmt_min_objs=None): self._opt_logger.info('Optimize single objective ' + str(objv_name) + ': Start') # initial lower bound l0 and initial upper bound u0 @@ -479,12 +484,18 @@ def optimize_single_objective(self, model_full_term_dict:dict, objv_name:str, ob else: T = (l + u) / 2 #quer_form = objv_term > smlp.Cnst(T) - #quer_form = objv_term >= smlp.Cnst(T) - # TODO !!!! use the following, avoid usage of >= - quer_form = self._smlpTermsInst.smlp_ge(objv_term, smlp.Cnst(T)); + quer_form = objv_term >= Solver.smlp_cnst(T) + quer_form = Solver.handle_ite_formula(formula=quer_form) + # quer_form = solver.create_query() + quer_expr = '{} >= {}'.format(objv_expr, str(T)) if objv_expr is not None else None quer_name = objv_name + '_' + str(T) - quer_and_beta = self._smlpTermsInst.smlp_and(quer_form, beta) if not beta == smlp.true else quer_form + if not beta == Solver._instance.smlp_true: + quer_and_beta = Solver.smlp_and(quer_form, beta) + # self._modelTermsInst.verifier.parser.and_(quer_form, beta) if self._ENABLE_PYSMT else self._smlpTermsInst.smlp_and(quer_form, beta) + # quer_and_beta = solver.create_query_and_beta(quer_form, beta) + else: + quer_and_beta = quer_form #print('quer_and_beta', quer_and_beta) 'u0_l0_u_l_T' self._opt_tracer.info('objective_thresholds_u0_l0_u_l_T, {} : {} : {} : {} : {}'.format(str(u0),str(l0),str(u),str(l),str(T))) quer_res = self._queryInst.query_condition( @@ -542,8 +553,20 @@ def optimize_single_objective(self, model_full_term_dict:dict, objv_name:str, ob #print('objv_term', objv_term, flush=True); print('stable_witness_terms', stable_witness_terms, flush=True) l_prev = l # save the value of l, it is for reporting only. #if objv_expr is not None: # the objective is not a symbolic max_min term, we may need its value, at least to see search progress - objv_witn_val_term = smlp.subst(objv_term, stable_witness_terms); #print('objv_witn_val_term', objv_witn_val_term) - #using objective values as lower bounds is not sound since objective value in sat model is the ceneter-point value + + objv_witn_val_term = Solver.substitute_objective_with_witness(stable_witness_terms=stable_witness_terms, objv_term=objv_term) + # if self._ENABLE_PYSMT: + # substitution = {} + # for symbol, value in stable_witness_terms.items(): + # symbol = self._modelTermsInst.verifier.parser.get_symbol(symbol) + # substitution[symbol] = Real(value) + # # Apply the substitution + # objv_witn_val_term = self._modelTermsInst.verifier.parser.simplify(pysmt_min_objs.substitute(substitution)) + # else: + # objv_witn_val_term = smlp.subst(objv_term, stable_witness_terms); #print('objv_witn_val_term', objv_witn_val_term) + + + #using objective values as lower bounds is not sound since objective value in sat model is the ceneter-point value # and the objective's value is not guaranteed to be a lower bound in entire stability region #objv_witn_val = self._smlpTermsInst.ground_smlp_expr_to_value(objv_witn_val_term, sat_approx, sat_precision) #assert objv_witn_val >= T @@ -601,7 +624,7 @@ def optimize_single_objectives(self, feat_names:list, resp_names:list, #X:pd.Dat #assert scale_objectives objv_terms_dict, orig_objv_terms_dict, scaled_objv_terms_dict = \ self._modelTermsInst.compute_objectives_terms(objv_names, objv_exprs, objv_bounds_dict, scale_objectives) - + # TODO: set sat_approx to False once dump and load with Fractions will work opt_conf = {} for i, (objv_name, objv_term) in enumerate(list(objv_terms_dict.items())): @@ -639,22 +662,34 @@ def optimize_single_objectives(self, feat_names:list, resp_names:list, #X:pd.Dat def active_objectives_max_min_bounds(self, model_full_term_dict:dict, objv_terms_dict:dict, t:list[float], smlp_domain:smlp.domain, alpha:smlp.form2, beta:smlp.form2, eta:smlp.form2, theta_radii_dict, epsilon:float, delta:float, solver_logic:str, strategy:str, direction, scale_objectives, objv_bounds, update_thresholds_dict, - sat_approx:bool, sat_precision:int, save_trace:bool): + sat_approx:bool, sat_precision:int, save_trace:bool, pysmt_objv_terms_dict=None): assert direction == 'up' eta_F_t = eta min_objs = None + pysmt_min_objs = None min_name = '' #print('thresholds t', t, 'objv_terms_dict', objv_terms_dict) for j, (objv_name, objv_term) in enumerate(objv_terms_dict.items()): if t[j] is not None: - eta_F_t = self._smlpTermsInst.smlp_and(eta_F_t, objv_term > smlp.Cnst(t[j])) + eta_F_t = Solver.calculate_eta_F_t(eta=eta_F_t, term=objv_term, val=t[j]) + # if self._ENABLE_PYSMT: + # eta_F_t = TextToPysmtParser.and_(eta_F_t, pysmt_objv_term > Real(t[j])) + # else: + # eta_F_t = self._smlpTermsInst.smlp_and(eta_F_t, objv_term > smlp.Cnst(t[j])) else: min_name = min_name + '_' + objv_name if min_name != '' else objv_name + # if self._ENABLE_PYSMT: + # pysmt_objv_term = pysmt_objv_terms_dict[objv_name] + # if pysmt_min_objs is not None: + # pysmt_min_objs = TextToPysmtParser.ite_(pysmt_objv_term < pysmt_min_objs, pysmt_objv_term, pysmt_min_objs) + # else: + # pysmt_min_objs = pysmt_objv_term + # else: if min_objs is not None: - min_objs = smlp.Ite(objv_term < min_objs, objv_term, min_objs) + min_objs = Solver.smlp_ite(objv_term < min_objs, objv_term, min_objs) else: min_objs = objv_term - + # When active_objectives_max_min_bounds() is called for the first time from # optimize_pareto_objectives(), the list t which represents the proven lower # bounds of objectives, is composed of None's, and the proven lower @@ -701,7 +736,7 @@ def active_objectives_max_min_bounds(self, model_full_term_dict:dict, objv_terms r = self.optimize_single_objective(model_full_term_dict, min_name, None, min_objs, epsilon, smlp_domain, eta_F_t, theta_radii_dict, alpha, beta, delta, solver_logic, scale_objectives, min_name, objv_bounds, update_thresholds_dict, - sat_approx, sat_precision, save_trace, l0, u0, l, u) + sat_approx, sat_precision, save_trace, l0, u0, l, u, pysmt_min_objs) elif strategy == 'eager': if self._tmp_return_eager_tuples: r = {} @@ -918,6 +953,10 @@ def optimize_pareto_objectives(self, feat_names:list[str], resp_names:list[str], objv_terms_dict, orig_objv_terms_dict, scaled_objv_terms_dict = \ self._modelTermsInst.compute_objectives_terms(objv_names, objv_exprs, objv_bounds_dict, scale_objectives) + pysmt_objv_terms_dict = None + # pysmt_objv_terms_dict, pysmt_orig_objv_terms_dict, pysmt_scaled_objv_terms_dict = \ + # self._modelTermsInst.pysmt_compute_objectives_terms(objv_names, objv_exprs, objv_bounds_dict, + # scale_objectives) objv_count = len(objv_names) objv_enum = range(objv_count) @@ -951,7 +990,7 @@ def sanity_check_fixed_objv_thresholds(t:list[float], fixed_onjv_dict): self._opt_tracer.info('pareto_iteration,{},{},{}'.format(str(call_n), '__'.join(objv_names), '__'.join([str(e) for e in s]))) c_lo, c_up, witness = self.active_objectives_max_min_bounds(model_full_term_dict, objv_terms_dict, s, smlp_domain, alpha, beta, eta, theta_radii_dict, epsilon, delta, solver_logic, strategy, direction, - scale_objectives, objv_bounds_dict, call_info_dict, sat_approx, sat_precision, save_trace) + scale_objectives, objv_bounds_dict, call_info_dict, sat_approx, sat_precision, save_trace, pysmt_objv_terms_dict) #print('c_lo', c_lo, 'c_up', c_up); print('witness', witness); assert c_lo != np.inf @@ -1016,12 +1055,23 @@ def sanity_check_fixed_objv_thresholds(t:list[float], fixed_onjv_dict): self._opt_logger.info('Checking whether to fix objective {} at threshold {}...\n'.format(str(j), str(s[j]))) self._opt_tracer.info('activity check, objective {} threshold {}'.format(str(objv_names[j]), str(s[j]))) #print('objv_terms_dict', objv_terms_dict) - quer_form = smlp.true + # quer_form = pysmt.shortcuts.TRUE() if self._ENABLE_PYSMT else smlp.true + quer_form = Solver._instance.smlp_true for i in objv_enum: #print('obv i', list(objv_terms_dict.keys())[i]) - quer_form = self._smlpTermsInst.smlp_and(quer_form, list(objv_terms_dict.values())[i] > smlp.Cnst(t[i])) + # if self._ENABLE_PYSMT: + # quer_form = self._modelTermsInst.parser.and_(quer_form, + # list(pysmt_objv_terms_dict.values())[i] > pysmt.shortcuts.Real(t[i])) + # else: + quer_form = Solver.smlp_and(quer_form, list(objv_terms_dict.values())[i] > Solver.smlp_cnst(t[i])) #print('queryform', quer_form) - quer_and_beta = self._smlpTermsInst.smlp_and(quer_form, beta) if not beta == smlp.true else quer_form + if not beta == Solver._instance.smlp_true: + quer_and_beta = Solver.smlp_and(quer_form, beta) + else: + quer_and_beta = quer_form + + quer_and_beta = Solver.z3_simplify(quer_and_beta) + opt_quer_name = 'thresholds_' + '_'.join(str(x) for x in t) + '_check' quer_res = self._queryInst.query_condition(True, model_full_term_dict, opt_quer_name, 'True', quer_and_beta, smlp_domain, eta, alpha, theta_radii_dict, delta, solver_logic, True, sat_approx, sat_precision) @@ -1108,7 +1158,7 @@ def smlp_optimize(self, syst_expr_dict:dict, algo:str, model:dict, X:pd.DataFram quer_names:list[str], quer_exprs, delta:float, epsilon:float, alph_expr:str, beta_expr:str, eta_expr:str, theta_radii_dict:dict, solver_logic:str, vacuity:bool, data_scaler:str, scale_feat:bool, scale_resp:bool, scale_objv:bool, - float_approx=True, float_precision=64, data_bounds_json_path=None, bounds_factor=None, T_resp_bounds_csv_path=None): + float_approx=True, float_precision=64, data_bounds_json_path=None, bounds_factor=None, T_resp_bounds_csv_path=None, use_pysmt=False): self.objv_names = objv_names self.objv_exprs = objv_exprs self.feat_names = feat_names @@ -1118,7 +1168,12 @@ def smlp_optimize(self, syst_expr_dict:dict, algo:str, model:dict, X:pd.DataFram # output to user initial values of mode status with open(self.optimization_results_file+'.json', 'w') as f: json.dump(self.mode_status_dict, f, indent='\t', cls=np_JSONEncoder) - + + # initialise Solver + Solver(specs=(feat_names, resp_names, self._modelTermsInst._specInst.get_spec_domain_dict), + version=Solver.Version.PYSMT if use_pysmt else Solver.Version.FORM2) + + domain, syst_term_dict, model_full_term_dict, eta, alpha, beta, interface_consistent, model_consistent = \ self._modelTermsInst.create_model_exploration_base_components( syst_expr_dict, algo, model, model_features_dict, feat_names, resp_names, diff --git a/src/smlp_py/smlp_query.py b/src/smlp_py/smlp_query.py index 4061f38d..1fbed730 100644 --- a/src/smlp_py/smlp_query.py +++ b/src/smlp_py/smlp_query.py @@ -8,6 +8,11 @@ from smlp_py.smlp_terms import ModelTerms, SmlpTerms from smlp_py.smlp_utils import np_JSONEncoder #, str_to_bool +from src.smlp_py.NN_verifiers.verifiers import MarabouVerifier +from src.smlp_py.solvers.universal_solver import Solver + +from src.smlp_py.solvers.marabou.solver import Pysmt_Solver + class SmlpQuery: def __init__(self): @@ -42,6 +47,7 @@ def __init__(self): self._trace_runtime = None self._trace_precision = None self._trace_anonymize = None + self._ENABLE_PYSMT = False def set_logger(self, logger): self._query_logger = logger @@ -97,12 +103,11 @@ def synthesis_results_file(self): return self.report_file_prefix + '_synthesize_results.json' def find_candidate(self, solver): - #res = solver.check() - res = self._modelTermsInst.smlp_solver_check(solver, 'ca', self._lemma_precision) - if self._modelTermsInst.solver_status_unknown(res): # isinstance(res, smlp.unknown): + res, witness = self._modelTermsInst.smlp_solver_check(solver, 'ca', self._lemma_precision) + if res == "unknown": return None else: - return res + return res, witness def update_consistecy_results(self, mode_status_dict, interface_consistent, model_consistent, mode_status, mode_results_file): @@ -145,14 +150,29 @@ def get_model_exploration_base_components(self, mode_status_dict, results_file, # just for small potential speedup. def check_concrete_witness_consistency(self, domain:smlp.domain, model_full_term_dict:dict, alpha:smlp.form2, eta:smlp.form2, query:smlp.form2, witn_form:smlp.form2, solver_logic:str): - solver = self._modelTermsInst.create_model_exploration_instance_from_smlp_components( - domain, model_full_term_dict, True, solver_logic) - solver.add(alpha); #print('alpha', alpha) - solver.add(eta); #print('eta', eta) - solver.add(witn_form); #print('witn_form', witn_form); print('query', query) + # solver = self._modelTermsInst.create_model_exploration_instance_from_smlp_components( + # domain, model_full_term_dict, True, solver_logic) + # solver.add(alpha); #print('alpha', alpha) + # solver.add(eta); #print('eta', eta) + # solver.add(witn_form); #print('witn_form', witn_form); print('query', query) + + candidate_solver = Solver.create_solver( + create_solver=self._modelTermsInst.create_model_exploration_instance_from_smlp_components, + domain=domain, + model_full_term_dict=model_full_term_dict, + incremental=True, + solver_logic=solver_logic + ) + candidate_solver.add_formula(eta, need_simplification=True) + candidate_solver.add_formula(alpha) + candidate_solver.add_formula(witn_form, need_simplification=True) + + printer = {'alpha':alpha, 'eta':eta,'witn_form':witn_form} if query is not None: - solver.add(query) - res = self._modelTermsInst.smlp_solver_check(solver, 'witness_consistency') + printer['query'] = query + # solver.add(query) + candidate_solver.add_formula(query, need_simplification=True) + res = self._modelTermsInst.smlp_solver_check(candidate_solver, 'witness_consistency', equations=printer ) #res = solver.check(); #print('res', res) return res @@ -167,14 +187,24 @@ def check_concrete_witness_consistency(self, domain:smlp.domain, model_full_term # theta x y /\ alpha y /\ ! ( beta y /\ obj y >= T) def find_candidate_counter_example(self, universal, domain:smlp.domain, cand:dict, query:smlp.form2, model_full_term_dict:dict, alpha:smlp.form2, theta_radii_dict:dict, solver_logic:str): #, beta:smlp.form2 - solver = self._modelTermsInst.create_model_exploration_instance_from_smlp_components( - domain, model_full_term_dict, False, solver_logic) - theta = self._modelTermsInst.compute_stability_formula_theta(cand, None, theta_radii_dict, universal) - solver.add(theta); #print('adding theta', theta) - solver.add(alpha); #print('adding alpha', alpha) - solver.add(self._smlpTermsInst.smlp_not(query)); #print('adding negated quert', query) - return self._modelTermsInst.smlp_solver_check(solver, 'ce', self._lemma_precision) - #return solver.check() + + theta = self._modelTermsInst.compute_stability_formula_theta(cand, None, theta_radii_dict, universal) + solver = Solver.create_counter_example( + create_solver=self._modelTermsInst.create_model_exploration_instance_from_smlp_components, + domain=domain, + model_full_term_dict=model_full_term_dict, + incremental=False, + solver_logic=solver_logic, + formulas=[alpha, theta], + query=query + ) + + + + return self._modelTermsInst.smlp_solver_check(solver, 'ce', self._lemma_precision, + {'alpha': alpha, 'theta': theta, + 'not_query': self._smlpTermsInst.smlp_not(query)}, temp=True) + # Enhancement !!!: at least add here the delta condition def generalize_counter_example(self, coex): @@ -190,28 +220,31 @@ def validate_witness_smt(self, universal:bool, model_full_term_dict:dict, quer_n self._query_logger.info('Verifying assertion {} <-> {}'.format(str(quer_name), str(quer_expr))) else: self._query_logger.info('Certifying stability of witness for query ' + str(quer_name) + ':\n ' + str(witn_dict)) - candidate_solver = self._modelTermsInst.create_model_exploration_instance_from_smlp_components( - domain, model_full_term_dict, True, solver_logic) - - cond_feasible = None - # add the remaining user constraints and the query - candidate_solver.add(eta); #print('adding eta', eta) - candidate_solver.add(alpha); #print('adding alpha', alpha) - #candidate_solver.add(beta) - candidate_solver.add(quer); #print('adding quer', quer) - #print('adding witn_dict', witn_dict) + + candidate_solver = Solver.create_solver( + create_solver=self._modelTermsInst.create_model_exploration_instance_from_smlp_components, + domain=domain, + model_full_term_dict=model_full_term_dict, + incremental=True, + solver_logic=solver_logic + ) + candidate_solver.add_formula(eta, need_simplification=True) + candidate_solver.add_formula(alpha) + candidate_solver.add_formula(quer, need_simplification=True) + + for var,val in witn_dict.items(): #candidate_solver.add(smlp.Var(var) == smlp.Cnst(val)) - candidate_solver.add(self._smlpTermsInst.smlp_eq(smlp.Var(var), smlp.Cnst(val))) + candidate_solver.add_formula(Solver.smlp_eq(Solver.smlp_var(var), Solver.smlp_cnst(val))) - candidate_check_res = self._modelTermsInst.smlp_solver_check(candidate_solver, 'ca') - if self._modelTermsInst.solver_status_sat(candidate_check_res): #isinstance(candidate_check_res, smlp.sat): + candidate_check_res, _ = self._modelTermsInst.smlp_solver_check(candidate_solver, 'ca', equations={'alpha':alpha, 'eta':eta,'quer':quer}) + if candidate_check_res == "sat": #isinstance(candidate_check_res, smlp.sat): cond_feasible = True if universal: self._query_logger.info('The configuration is consistent with assertion ' + str(quer_name)) else: self._query_logger.info('Witness to query ' + str(quer_name) + ' is a valid witness; checking its stability') - elif self._modelTermsInst.solver_status_unsat(candidate_check_res): #isinstance(candidate_check_res, smlp.unsat): + elif candidate_check_res == "unsat": #isinstance(candidate_check_res, smlp.unsat): cond_feasible = False if universal: # Assertion cannot be satisfied (is constant False) given the knob configuration and the constraints. @@ -237,20 +270,20 @@ def validate_witness_smt(self, universal:bool, model_full_term_dict:dict, quer_n # checking stability of a valid witness to the query witn_term_dict = self._smlpTermsInst.witness_const_to_term(witn_dict) - ce = self.find_candidate_counter_example(universal, domain, witn_term_dict, quer, model_full_term_dict, alpha, + ce, ce_witness = self.find_candidate_counter_example(universal, domain, witn_term_dict, quer, model_full_term_dict, alpha, theta_radii_dict, solver_logic) - if self._modelTermsInst.solver_status_sat(ce): #isinstance(ce, smlp.sat): + if ce == "sat": #isinstance(ce, smlp.sat): if universal: self._query_logger.info('Completed with result: FAIL') #self._query_logger.info('Assertion ' + str(quer_name) + ' fails (for stability radii ' + str(theta_radii_dict)) #status = 'FAIL' if cond_feasible else 'FAIL VACUOUSLY' - ce_model = self._modelTermsInst.get_solver_model(ce) + ce_model = self._modelTermsInst.get_solver_model(ce, ce_witness) return {'assertion_status':'FAIL', 'asrt': False, 'assertion_feasible': cond_feasible, 'counter_example':self._smlpTermsInst.witness_term_to_const(ce_model, approximate=sat_approx, precision=sat_precision)} else: self._query_logger.info('Witness to query ' + str(quer_name) + ' is not stable for radii ' + str(theta_radii_dict)) return 'witness, not stable' - elif self._modelTermsInst.solver_status_unsat(ce): #isinstance(ce, smlp.unsat): + elif ce == "unsat": #isinstance(ce, smlp.unsat): if universal: self._query_logger.info('Completed with result: PASS') #self._query_logger.info('Assertion ' + str(quer_name) + ' passes (for stability radii ' + str(theta_radii_dict)) @@ -404,15 +437,15 @@ def validate_witness(self, universal:bool, syst_expr_dict:dict, algo:str, model: self._query_logger.info('Verifying consistency of configuration for assertion ' + str(quer_name) + ':\n ' + str(witn_form)) else: self._query_logger.info('Certifying consistency of witness for query ' + str(quer_name) + ':\n ' + str(witn_form)) - witn_status = self.check_concrete_witness_consistency(domain, model_full_term_dict, + witn_status, _ = self.check_concrete_witness_consistency(domain, model_full_term_dict, alpha, eta, None, witn_form, solver_logic) - if self._modelTermsInst.solver_status_sat(witn_status): #isinstance(witn_status, smlp.sat): + if witn_status == "sat": #isinstance(witn_status, smlp.sat): if universal: self._query_logger.info('Input, knob and configuration constraints are consistent') else: self._query_logger.info('Input, knob and concrete witness constraints are consistent') mode_status_dict[quer_name][CONSISTENCY] = 'true' - elif self._modelTermsInst.solver_status_unsat(witn_status): #isinstance(witn_status, smlp.unsat): + elif witn_status == "unsat": #isinstance(witn_status, smlp.unsat): if universal: self._query_logger.info('Input, knob and configuration constraints are inconsistent') else: @@ -521,15 +554,38 @@ def query_condition(self, universal, model_full_term_dict:dict, quer_name:str, q else: self._query_logger.info('Querying condition {} <-> {}'.format(str(quer_name), str(quer))) #print('query', quer, 'eta', eta, 'delta', delta) - candidate_solver = self._modelTermsInst.create_model_exploration_instance_from_smlp_components( - domain, model_full_term_dict, True, solver_logic) - - # add the remaining user constraints and the query - candidate_solver.add(eta) - candidate_solver.add(alpha) - #candidate_solver.add(beta) - candidate_solver.add(quer) - #print('eta', eta); print('alpha', alpha); print('quer', quer); + + # if not self._ENABLE_PYSMT: + # candidate_solver = self._modelTermsInst.create_model_exploration_instance_from_smlp_components( + # domain, model_full_term_dict, True, solver_logic) + # + # # add the remaining user constraints and the query + # candidate_solver.add(eta) + # candidate_solver.add(alpha) + # #candidate_solver.add(beta) + # candidate_solver.add(quer) + # else: + # self._modelTermsInst.verifier.reset() + # self._modelTermsInst.verifier.apply_restrictions(eta, need_simplification=True) + # self._modelTermsInst.verifier.apply_restrictions(alpha) + # self._modelTermsInst.verifier.apply_restrictions(quer, need_simplification=True) + + candidate_solver = Solver.create_solver( + create_solver=self._modelTermsInst.create_model_exploration_instance_from_smlp_components, + domain=domain, + model_full_term_dict=model_full_term_dict, + incremental=True, + solver_logic=solver_logic + ) + candidate_solver.add_formula(eta, need_simplification=True) + candidate_solver.add_formula(alpha) + candidate_solver.add_formula(quer, need_simplification=True) + print("IN QUERY CONDITION") + # res = self.smlp_solver_check(solver, + # 'interface_consistency' if model_full_term_dict is None else 'model_consistency', + # equations={'alpha': alpha, 'eta': eta}) + + #print('eta', eta); print('alpha', alpha); print('quer', quer); #print('solving query', quer) self._query_tracer.info('{},{}'.format('synthesis' if universal else 'query', str(quer_name))) #, str(quer_expr) ,{} use_approxiamted_fractions = self._lemma_precision != 0 @@ -539,13 +595,22 @@ def query_condition(self, universal, model_full_term_dict:dict, quer_name:str, q while True: # solve Ex. eta x /\ Ay. theta x y -> alpha y -> (beta y /\ query) print('searching for a candidate', flush=True) - - ca = self.find_candidate(candidate_solver) - - if self._modelTermsInst.solver_status_sat(ca): # isinstance(ca, smlp.sat): + if isinstance(candidate_solver, Pysmt_Solver): + print('PYSMT FORMULA', {'alpha': alpha, 'eta': eta, 'quer': quer.serialize()}) + else: + print('FORM2 FORMULA', {'alpha': alpha, 'eta': eta, 'quer': quer}) + + result, ca = self.find_candidate(candidate_solver) + + # condition_sat = self._modelTermsInst.solver_status_sat(ca) + # condition_unsat = self._modelTermsInst.solver_status_unsat(ca["result"]) if self._ENABLE_PYSMT else self._modelTermsInst.solver_status_unsat(ca) + # condition_unknown = self._modelTermsInst.solver_status_unsat(ca["result"]) if self._ENABLE_PYSMT else self._modelTermsInst.solver_status_unknown(ca) + + if result == "sat": # isinstance(ca, smlp.sat): print('candidate found -- checking stability', flush=True) #print('ca', ca_model) - ca_model = self._modelTermsInst.get_solver_model(ca) #ca.model + # ca_model = self._modelTermsInst.get_solver_model(ca) #ca.model + ca_model = Solver.get_witness(result=result, witness=ca, interface=self._modelTermsInst._specInst.get_spec_interface) #ca.model if use_approxiamted_fractions: ca_model_approx = self._smlpTermsInst.approximate_witness_term(ca_model, self._lemma_precision) #print('ca_model_approx -------------', ca_model_approx) @@ -560,14 +625,21 @@ def query_condition(self, universal, model_full_term_dict:dict, quer_name:str, q #print('ca_model_approx', ca_model_approx) feasible = True if use_approxiamted_fractions: - ce = self.find_candidate_counter_example(universal, domain, ca_model_approx, quer, model_full_term_dict, alpha, + c_result, ce = self.find_candidate_counter_example(universal, domain, ca_model_approx, quer, model_full_term_dict, alpha, theta_radii_dict, solver_logic) else: - ce = self.find_candidate_counter_example(universal, domain, ca_model, quer, model_full_term_dict, alpha, + c_result, ce = self.find_candidate_counter_example(universal, domain, ca_model, quer, model_full_term_dict, alpha, theta_radii_dict, solver_logic) - if self._modelTermsInst.solver_status_sat(ce): #isinstance(ce, smlp.sat): + + # is_sat = self._modelTermsInst.solver_status_sat(ce["result"]) if self._ENABLE_PYSMT else self._modelTermsInst.solver_status_sat(ce) + # is_unsat = self._modelTermsInst.solver_status_unsat(ce["result"]) if self._ENABLE_PYSMT else self._modelTermsInst.solver_status_unsat(ce) + + if c_result == "sat": #isinstance(ce, smlp.sat): print('candidate not stable -- continue search', flush=True) - ce_model = self._modelTermsInst.get_solver_model(ce) #ce.model + ce_model = Solver.get_witness(result=result, witness=ca, + interface=self._modelTermsInst._specInst.get_spec_interface) # ca.model + + cem = ce_model.copy(); #print('ce model', cem) # drop Assignements to responses from ce for var in ce_model.keys(): @@ -589,13 +661,23 @@ def query_condition(self, universal, model_full_term_dict:dict, quer_name:str, q else: lemma = self.generalize_counter_example(cem); #print('lemma', lemma) theta = self._modelTermsInst.compute_stability_formula_theta(lemma, delta, theta_radii_dict, universal) - candidate_solver.add(self._smlpTermsInst.smlp_not(theta)) + Solver.apply_theta(solver=candidate_solver, formula=theta) + # if self._ENABLE_PYSMT: + # theta_negation = self._modelTermsInst.parser.propagate_negation(theta) + # # self._modelTermsInst.verifier.add_permanent_constraint(theta_negation) + # self._modelTermsInst.verifier.apply_restrictions(theta_negation) + # print("PYSMT THETA ADDED ", theta_negation) + # else: + # candidate_solver.add(self._smlpTermsInst.smlp_not(theta)) continue - elif self._modelTermsInst.solver_status_unsat(ce): #isinstance(ce, smlp.unsat): + elif c_result == "unsat": #isinstance(ce, smlp.unsat): #print('candidate stable -- return candidate') self._query_logger.info('Query completed with result: STABLE_SAT (satisfiable)') if witn: # export witness (use numbers as values, not terms) - ca_model = self._modelTermsInst.get_solver_model(ca) # ca.model + ca_model = Solver.get_witness(result=result, witness=ca, + interface=self._modelTermsInst._specInst.get_spec_interface) # ca.model + + witness_vals_dict = self._smlpTermsInst.witness_term_to_const(ca_model, sat_approx, sat_precision) #print('domain witness_vals_dict', witness_vals_dict) # sanity check: the value of query in the sat assignment should be true @@ -605,14 +687,18 @@ def query_condition(self, universal, model_full_term_dict:dict, quer_name:str, q return {'query_status':'STABLE_SAT', 'witness':witness_vals_dict, 'feasible':feasible} else: return {'query_status':'STABLE_SAT', 'witness':ca_model, 'feasible':feasible} - elif self._modelTermsInst.solver_status_unsat(ca): #isinstance(ca, smlp.unsat): + # if self._ENABLE_PYSMT: + # return {'query_status':'STABLE_SAT', 'witness':ca['witness'], 'feasible':feasible} + # else: + # return {'query_status':'STABLE_SAT', 'witness':ca_model, 'feasible':feasible} + elif result == "unsat": self._query_logger.info('Query completed with result: UNSAT (unsatisfiable)') if feasible is None: feasible = False #print('candidate does not exist -- query unsuccessful') #print('query unsuccessful: witness does not exist (query is unsat)') return {'query_status':'UNSAT', 'witness':None, 'feasible':feasible} - elif self._modelTermsInst.solver_status_unknown(ca): #isinstance(ca, smlp.unknown): + elif result == "unknown": #isinstance(ca, smlp.unknown): self._opt_logger.info('Completed with result: {}'.format('UNKNOWN')) return {'query_status':'UNKNOWN', 'witness':None, 'feasible':feasible} #raise Exception('UNKNOWN return value in candidate search is currently not supported for queries') diff --git a/src/smlp_py/smlp_solver.py b/src/smlp_py/smlp_solver.py index 4528ce49..472874b8 100644 --- a/src/smlp_py/smlp_solver.py +++ b/src/smlp_py/smlp_solver.py @@ -12,6 +12,7 @@ def __init__(self): self._DEF_SOLVER = 'z3' self._DEF_SOLVER_PATH = None self._DEF_SOLVER_LOGIC = 'ALL' + self.use_pysmt = False #self._DEF_SOLVER_INCREMENTAL = True ''' @@ -35,6 +36,9 @@ def __init__(self): 'help':'SMT2-lib theory with respect to which to solve model exploration task at hand, ' + 'in modes "verify," "query", "optimize" and "optsyn". ' + '[default: {}]'.format(str(self._DEF_SOLVER_LOGIC))}, + 'use_pysmt': {'abbr': 'use_pysmt', 'default': self.use_pysmt, 'type': str_to_bool, + 'help': 'Solver to use in model exploration modes "verify," "query", "optimize" and "optsyn". ' + + '[default: {}]'.format(str(self.use_pysmt))}, #'solver_incr': {'abbr':'solver_incr', 'default': self._DEF_SOLVER_INCREMENTAL, 'type':str_to_bool, # 'help':'Should sover be used in incremental mode? ' + # '[default: {}]'.format(str(self._DEF_SOLVER_INCREMENTAL))} diff --git a/src/smlp_py/smlp_terms.py b/src/smlp_py/smlp_terms.py index c5252040..459fb247 100644 --- a/src/smlp_py/smlp_terms.py +++ b/src/smlp_py/smlp_terms.py @@ -5,6 +5,8 @@ import numpy as np import pandas as pd import keras +from Cython.Compiler.TreePath import operations +from pysmt.fnode import FNode from sklearn.tree import _tree import json import ast @@ -18,10 +20,19 @@ from enum import Enum import smlp -from smlp_py.smlp_utils import (np_JSONEncoder, lists_union_order_preserving_without_duplicates, +from src.smlp_py.smlp_utils import (np_JSONEncoder, lists_union_order_preserving_without_duplicates, list_subtraction_set, get_expression_variables, str_to_bool) #from smlp_py.smlp_spec import SmlpSpec +from pysmt.shortcuts import Real as pysmtReal +from src.smlp_py.NN_verifiers.verifiers import MarabouVerifier +import pysmt +from src.smlp_py.smtlib.text_to_sympy import TextToPysmtParser +from src.smlp_py.solvers.universal_solver import Solver + +from src.smlp_py.solvers.z3.operations import SMLPOperations + +from src.smlp_py.solvers.z3.solver import Form2_Solver # TODO !!! create a parent class for TreeTerms, PolyTerms, NNKerasTerms. # setting logger, report_file_prefix, model_file_prefix can go to that class to work for all above three classes @@ -51,16 +62,7 @@ # to solver instance separately (as many as required, depending on whether all responses are analysed together). -USE_CACHE = False -def conditional_cache(func): - """Custom decorator to conditionally apply @functools.cache.""" - if USE_CACHE: - # Apply caching - return functools.cache(func) - else: - # Return the original function without caching - return func ''' def conditional_cache(func): @@ -78,7 +80,7 @@ def wrapper(self, *args, **kwargs): # Class SmlpTerms has methods for generating terms, and classes TreeTerms, PolyTerms and NNKerasTerms are inherited # from it but this inheritance is probably not implemented in the best way: TODO !!!: see if that can be improved. -class SmlpTerms: +class SmlpTerms(SMLPOperations): def __init__(self): self._smlp_terms_logger = None self.report_file_prefix = None @@ -127,8 +129,9 @@ def __init__(self): ast.LtE: self.smlp_le, ast.Gt: self.smlp_gt, ast.GtE: self.smlp_ge, ast.And: self.smlp_and, ast.Or: self.smlp_or, ast.Not: self.smlp_not, ast.IfExp: self.smlp_ite - } - + } + self._ENABLE_PYSMT = False + # set logger from a caller script def set_logger(self, logger): self._smlp_terms_logger = logger @@ -585,8 +588,17 @@ def ground_smlp_expr_to_value(self, ground_term:smlp.term2, approximate=False, p # Can also be applied to a dictionary where values are terms. def witness_term_to_const(self, witness, approximate=False, precision=64): witness_vals_dict = {} + # if self._ENABLE_PYSMT: + # return witness for k,t in witness.items(): - witness_vals_dict[k] = self.ground_smlp_expr_to_value(t, approximate, precision) + if isinstance(t, smlp.term2): + new_value = self.ground_smlp_expr_to_value(t, approximate, precision) + elif isinstance(t, pysmt.fnode.FNode) and t.is_constant(): + new_value = float(t.constant_value()) + else: + new_value = float(t) + + witness_vals_dict[k] = new_value return witness_vals_dict # computes and returns sat assignment witness_approx which approximates input witness/sat assignment @@ -1693,22 +1705,23 @@ def _unscaled_name(self, name): # x_scaled obtained from x using min_max scaler to range [0, 1] (which is the same as normalizin x), # orig_min stands for min(x) and orig_max stands for max(x). Note that 1 / (max(x) - min(x)) is a # rational constant, it is defined to smlp instance as a fraction (thus there is no loss of precision). - def feature_scaler_to_term(self, orig_feat_name, scaled_feat_name, orig_min, orig_max): + def feature_scaler_to_term(self, orig_feat_name, scaled_feat_name, orig_min, orig_max, allow_solver=False): #print('feature_scaler_to_term', 'orig_min', orig_min, type(orig_min), 'orig_max', orig_max, type(orig_max), flush=True) + operations = Solver if allow_solver else self if orig_min == orig_max: - return self.smlp_cnst(0) #smlp.Cnst(0) # same as returning smlp.Cnst(smlp.Q(0)) + return operations.smlp_cnst(0) #smlp.Cnst(0) # same as returning smlp.Cnst(smlp.Q(0)) else: - return self.smlp_mult( - self.smlp_cnst(self.smlp_q(1) / self.smlp_q(orig_max - orig_min)), - (self.smlp_var(orig_feat_name) - self.smlp_cnst(orig_min))) + return operations.smlp_mult( + operations.smlp_cnst(operations.smlp_q(1) / operations.smlp_q(orig_max - orig_min)), + (operations.smlp_var(orig_feat_name) - operations.smlp_cnst(orig_min))) ####return self.smlp_div(self.smlp_var(orig_feat_name) - self.smlp_cnst(orig_min), self.smlp_cnst(orig_max) - self.smlp_cnst(orig_min)) ####return smlp.Cnst(smlp.Q(1) / smlp.Q(orig_max - orig_min)) * (smlp.Var(orig_feat_name) - smlp.Cnst(orig_min)) - + # Computes dictionary with features as keys and scaler terms as values - def feature_scaler_terms(self, data_bounds, feat_names): - return dict([(self._scaled_name(feat), self.feature_scaler_to_term(feat, self._scaled_name(feat), - data_bounds[feat]['min'], data_bounds[feat]['max'])) for feat in feat_names]) - + def feature_scaler_terms(self, data_bounds, feat_names, allow_solver=False): + return dict([(self._scaled_name(feat), self.feature_scaler_to_term(feat, self._scaled_name(feat), + data_bounds[feat]['min'], data_bounds[feat]['max'], allow_solver=allow_solver)) for feat in feat_names]) + # Computes term x from column x_scaled using expression x = x_scaled * (max_x - min_x) + x_min. # Argument orig_feat_name is name for column x, argument scaled_feat_name is the name of scaled column # x_scaled obtained earlier from x using min_max scaler to range [0, 1] (same as normalization of x), @@ -1794,6 +1807,16 @@ def __init__(self): # 'help':'Should terms be cached along building terms and formulas in model exploration modes? ' + # '[default {}]'.format(str(self._DEF_CACHE_TERMS))} } + + # self.parser = TextToPysmtParser() + # self.parser.init_variables(symbols=[("x1", "real", True), ('x2', 'int', True), ('p1', 'real', True), ('p2', 'int', True), + # ('y1', 'real', False), ('y2', 'real', False)]) + # + # self.verifier = MarabouVerifier(parser=self.parser) + # + # self._ENABLE_PYSMT = False + # self._RETURN_PYSMT = False + # set logger from a caller script def set_logger(self, logger): @@ -1965,7 +1988,8 @@ def _compute_model_terms_dict(self, algo, model, feat_names, resp_names, data_bo if tree_flat_encoding or nn_keras_flat_encoding: model_term = [self.smlp_cnst_fold(form, {feat_name: feat_term}) for form in model_term] else: - model_term = self.smlp_cnst_fold(model_term, {feat_name: feat_term}) #self.smlp_subst + model_term = self.smlp_cnst_fold(model_term, {feat_name: feat_term}) + # model_term = Solver.substitute(var=model_term, substitutions={feat_name: feat_term}) #self.smlp_subst #print('model term after', model_term, flush=True) model_term_dict[resp_name] = model_term #print('model_term_dict with unscaled features', model_term_dict, flush=True) @@ -2084,18 +2108,24 @@ def compute_objectives_terms(self, objv_names, objv_exprs, objv_bounds, scale_ob #print('objv_exprs', objv_exprs) if objv_exprs is None: return None, None, None, None - orig_objv_terms_dict = dict([(objv_name, self.ast_expr_to_term(objv_expr)) \ + orig_objv_terms_dict = dict([(objv_name, Solver.parse_ast(parser=self.ast_expr_to_term, expression=objv_expr)) \ for objv_name, objv_expr in zip(objv_names, objv_exprs)]) #self._smlpTermsInst. #print('orig_objv_terms_dict', orig_objv_terms_dict) + if scale_objv: - scaled_objv_terms_dict = self.feature_scaler_terms(objv_bounds, objv_names) #._scalerTermsInst + # allow_pysmt = isinstance(next(iter(orig_objv_terms_dict), FNode) + + scaled_objv_terms_dict = self.feature_scaler_terms(objv_bounds, objv_names, allow_solver=True) # ._scalerTermsInst + #print('scaled_objv_terms_dict', scaled_objv_terms_dict) objv_terms_dict = {} for i, (k, v) in enumerate(scaled_objv_terms_dict.items()): #print('k', k, 'v', v, type(v)); x = list(orig_objv_terms_dict.keys())[i]; #print('x', x); print('arg', orig_objv_terms_dict[x]) - objv_terms_dict[k] = self.smlp_cnst_fold(v, {x: orig_objv_terms_dict[x]}) #self.smlp_subst + + objv_terms_dict[k] = Solver.substitute(var=v, substitutions={x: orig_objv_terms_dict[x]}) + # objv_terms_dict[k] = self.smlp_cnst_fold(v, {x: orig_objv_terms_dict[x]}) #objv_terms_dict = scaled_objv_terms_dict else: objv_terms_dict = orig_objv_terms_dict @@ -2106,7 +2136,41 @@ def compute_objectives_terms(self, objv_names, objv_exprs, objv_bounds, scale_ob for objv_name in objv_names] #print('objv_terms_dict', objv_terms_dict) return objv_terms_dict, orig_objv_terms_dict, scaled_objv_terms_dict - + + def pysmt_compute_objectives_terms(self, objv_names, objv_exprs, objv_bounds, scale_objv): + # print('objv_exprs', objv_exprs) + if objv_exprs is None: + return None, None, None, None + + pysmt_objv_terms_dict = dict([(objv_name, self.parser.parse(objv_expr)) \ + for objv_name, objv_expr in zip(objv_names, objv_exprs)]) + + if scale_objv: + scaled_objv_terms_dict = self.feature_scaler_terms(objv_bounds, objv_names, + self.parser) # ._scalerTermsInst + + # print('scaled_objv_terms_dict', scaled_objv_terms_dict) + objv_terms_dict = {} + for i, (k, v) in enumerate(scaled_objv_terms_dict.items()): + # print('k', k, 'v', v, type(v)); + x = list(pysmt_objv_terms_dict.keys())[i]; + # print('x', x); print('arg', orig_objv_terms_dict[x]) + # if self._ENABLE_PYSMT: + substitution = {self.parser.get_symbol(x): pysmt_objv_terms_dict[x]} + # Apply the substitution + objv_terms_dict[k] = self.parser.simplify(v.substitute(substitution)) + # else: + # objv_terms_dict = scaled_objv_terms_dict + else: + objv_terms_dict = pysmt_objv_terms_dict + scaled_objv_terms_dict = None + + if scaled_objv_terms_dict is not None: + assert list(scaled_objv_terms_dict.keys()) == [self._scaled_name(objv_name) # ._scalerTermsInst + for objv_name in objv_names] + # print('objv_terms_dict', objv_terms_dict) + return objv_terms_dict, pysmt_objv_terms_dict, scaled_objv_terms_dict + # Compute stability region theta; used also in generating lemmas during search for a stable solution. # cex is assignement of values to knobs. Even if cex contains assignements to inputs, such assignements @@ -2122,8 +2186,10 @@ def compute_stability_formula_theta(self, cex, delta_dict:dict, radii_dict, univ assert delta_rel >= 0 else: delta_rel = delta_abs = None - - theta_form = self.smlp_true + + # theta_form = pysmt.shortcuts.TRUE() if self._ENABLE_PYSMT else self.smlp_true + theta_form = Solver._instance.smlp_true + #print('radii_dict', radii_dict) radii_dict_local = radii_dict.copy() knobs = radii_dict_local.keys(); #print('knobs', knobs); print('cex', cex); print('delta', delta_dict) @@ -2135,19 +2201,24 @@ def compute_stability_formula_theta(self, cex, delta_dict:dict, radii_dict, univ radii_dict_local[cex_var] = {'rad-abs':0, 'rad-rel': None} # delta for var,radii in radii_dict_local.items(): - var_term = self.smlp_var(var) + var_term = Solver.smlp_var(var) # either rad-abs or rad-rel must be None -- for each var wr declare only one of these if radii['rad-abs'] is not None: rad = radii['rad-abs']; #print('rad', rad); if delta_rel is not None: # we are generating a lemma rad = rad * (1 + delta_rel) + delta_abs - rad_term = self.smlp_cnst(rad) + + # if self._ENABLE_PYSMT: + # rad_term = float(rad) + # else: + # rad_term = self.smlp_cnst(rad) + rad_term = Solver.get_rad_term(rad=rad) + elif radii['rad-rel'] is not None: rad = radii['rad-rel']; #print('rad', rad) if delta_rel is not None: # we are generating a lemma rad = rad * (1 + delta_rel) + delta_abs - rad_term = self.smlp_cnst(rad) - + # TODO !!! issue a warning when candidates become closer and closer # TODO !!!!!!! warning when distance between previous and current candidate # TODO !!!!!! warning when FINAL rad + delta is 0, as part of sanity checking options @@ -2166,37 +2237,83 @@ def compute_stability_formula_theta(self, cex, delta_dict:dict, radii_dict, univ # candidate; this is a matter of definition of relative radius, and seems cleaner than computing actual radius # from relative radius based on variable values in the counter-exaples to candidate rather than variable values # in the candidate itself. - if delta_rel is not None: # radius for a lemma -- cex holds values of candidate counter-example - rad_term = rad_term * abs(var_term) - else: # radius for excluding a candidate -- cex holds values of the candidate - rad_term = rad_term * abs(cex[var]) - elif delta_dict is not None: - raise exception('When delta dictionary is provided, either absolute or relative or delta must be specified') - theta_form = self.smlp_and(theta_form, ((abs(var_term - cex[var])) <= rad_term)) + + # if self._ENABLE_PYSMT: + # rad_term = float(rad) + # else: + # rad_term = self.smlp_cnst(rad) + # if delta_rel is not None: # radius for a lemma -- cex holds values of candidate counter-example + # rad_term = rad_term * abs(var_term) + # else: # radius for excluding a candidate -- cex holds values of the candidate + # rad_term = rad_term * abs(cex[var]) + + rad_term = Solver.generate_rad_term(rad=rad, delta_rel=delta_rel, var_term=var_term,candidate=cex[var]) + + + elif delta_dict is not None: + raise exception('When delta dictionary is provided, either absolute or relative or delta must be specified') + # if self._ENABLE_PYSMT: + # value = float(cex[var]) + # PYSMT_var = self.parser.get_symbol(var) + # type = pysmt.shortcuts.Int if str(PYSMT_var.get_type()) == "Int" else pysmtReal + # calc_type = int if str(PYSMT_var.get_type()) == "Int" else float + # lower = calc_type(value - rad_term) + # lower = type(lower) + # upper = calc_type(value + rad_term) + # upper = type(upper) + # theta_form = self.parser.and_(theta_form, PYSMT_var >= lower, PYSMT_var <= upper) + # + # else: + # theta_form = self.smlp_and(theta_form, ((abs(var_term - cex[var])) <= rad_term)) + + theta_form = Solver.create_theta_form(theta_form=theta_form, witness=cex[var], var=var, var_term=var_term, rad_term=rad_term) + #print('theta_form', theta_form) return theta_form - + # Creates eta constraints on control parameters (knobs) from the spec. # Covers grid as well as range/interval constraints. def compute_grid_range_formulae_eta(self): #print('generate eta constraint') - eta_grid_form = self.smlp_true + eta_grid_form = Solver._instance.smlp_true eta_grids_dict = self._specInst.get_spec_eta_grids_dict; #print('eta_grids_dict', eta_grids_dict) for var,grid in eta_grids_dict.items(): - eta_grid_disj = self.smlp_false - var_term = self.smlp_var(var) + eta_grid_disj = Solver._instance.smlp_false + var_term = Solver.smlp_var(var) for gv in grid: # iterate over grid values - if eta_grid_disj == self.smlp_false: - eta_grid_disj = var_term == self.smlp_cnst(gv) + if eta_grid_disj == Solver._instance.smlp_false: + eta_grid_disj = Solver.smlp_eq(var_term, Solver.smlp_cnst(gv)) else: - eta_grid_disj = self.smlp_or(eta_grid_disj, var_term == self.smlp_cnst(gv)) - if eta_grid_form == self.smlp_true: + eta_grid_disj = Solver.smlp_or(eta_grid_disj, Solver.smlp_eq(var_term, Solver.smlp_cnst(gv))) + if eta_grid_form == Solver._instance.smlp_true: eta_grid_form = eta_grid_disj else: - eta_grid_form = self.smlp_and(eta_grid_form, eta_grid_disj) + eta_grid_form = Solver.smlp_and(eta_grid_form, eta_grid_disj) #print('eta_grid_form', eta_grid_form); return eta_grid_form - + + def pysmt_compute_grid_range_formulae_eta(self): + # print('generate eta constraint') + eta_grid_form = pysmt.shortcuts.TRUE() + eta_grids_dict = self._specInst.get_spec_eta_grids_dict; # print('eta_grids_dict', eta_grids_dict) + for var, grid in eta_grids_dict.items(): + # self.verifier.add_bounds(var, grid=grid) + eta_grid_disj = pysmt.shortcuts.FALSE() + var_term = self.parser.get_symbol(var) + symbol_type = var_term.get_type() + for gv in grid: # iterate over grid values + if eta_grid_disj == pysmt.shortcuts.FALSE(): + eta_grid_disj = self.parser.eq_(var_term, self.parser.cast_number(symbol_type, gv)) + else: + eta_grid_disj = self.parser.or_(eta_grid_disj, + self.parser.eq_(var_term, self.parser.cast_number(symbol_type, gv))) + if eta_grid_form == pysmt.shortcuts.TRUE(): + eta_grid_form = eta_grid_disj + else: + eta_grid_form = self.parser.and_(eta_grid_form, eta_grid_disj) + # print('eta_grid_form', eta_grid_form); + return eta_grid_form + # Compute formulae alpha, beta, eta from respective expression string. def compute_input_ranges_formula_alpha(self, model_inputs): @@ -2225,8 +2342,13 @@ def compute_input_ranges_formula_alpha(self, model_inputs): assert False return alpha_form - def compute_input_ranges_formula_alpha_eta(self, alpha_vs_eta, model_inputs): - alpha_or_eta_form = self.smlp_true + def compute_input_ranges_formula_alpha_eta(self, alpha_vs_eta, model_inputs, specs=None): + # self.ENABLE_PYSMT + # alpha_or_eta_form = self.smlp_true + # smt_form = pysmt.shortcuts.TRUE() + + alpha_or_eta_form = Solver._instance.smlp_true + if alpha_vs_eta == 'alpha': alpha_or_eta_ranges_dict = self._specInst.get_spec_alpha_bounds_dict elif alpha_vs_eta == 'eta': @@ -2243,19 +2365,36 @@ def compute_input_ranges_formula_alpha_eta(self, alpha_vs_eta, model_inputs): #print('mn', mn, 'mx', mx) if mn is not None and mx is not None: if self._declare_domain_interface_only: - if self._encode_input_range_as_disjunction and alpha_vs_eta == 'alpha' and v in self._specInst.get_spec_inputs: - rng = self.smlp_or_multi([self.smlp_eq(self.smlp_var(v), self.smlp_cnst(i)) for i in range(mn, mx+1)]) - else: - rng = self.smlp_and(self.smlp_var(v) >= self.smlp_cnst(mn), self.smlp_var(v) <= self.smlp_cnst(mx)) - alpha_or_eta_form = self.smlp_and(alpha_or_eta_form, rng) + # if self._ENABLE_PYSMT: + # symbol_v = self.parser.get_symbol(v) + # form = self.parser.and_(symbol_v >= mn, symbol_v <= mx) + # smt_form = self.parser.and_(smt_form, form) + # + # + # if self._encode_input_range_as_disjunction and alpha_vs_eta == 'alpha' and v in self._specInst.get_spec_inputs: + # rng = self.smlp_or_multi([self.smlp_eq(self.smlp_var(v), self.smlp_cnst(i)) for i in range(mn, mx+1)]) + # else: + # rng = self.smlp_and(self.smlp_var(v) >= self.smlp_cnst(mn), self.smlp_var(v) <= self.smlp_cnst(mx)) + # alpha_or_eta_form = self.smlp_and(alpha_or_eta_form, rng) + + alpha_or_eta_form = Solver.create_alpha_or_eta_form( + alpha_or_eta_form=alpha_or_eta_form, + v=v, + mn=mn, + mx=mx, + is_alpha=alpha_vs_eta == 'alpha', + is_in_spec=v in self._specInst.get_spec_inputs, + is_disjunction=self._encode_input_range_as_disjunction + ) elif mn is not None: - rng = self.smlp_var(v) >= self.smlp_cnst(mn) - alpha_or_eta_form = self.smlp_and(alpha_or_eta_form, rng) + rng = Solver.smlp_var(v) >= Solver.smlp_cnst(mn) + alpha_or_eta_form = Solver.smlp_and(alpha_or_eta_form, rng) elif mx is not None: - rng = self.smlp_var(v) <= self.smlp_cnst(mx) - alpha_or_eta_form = self.smlp_and(alpha_or_eta_form, rng) + rng = Solver.smlp_var(v) <= Solver.smlp_cnst(mx) + alpha_or_eta_form = Solver.smlp_and(alpha_or_eta_form, rng) else: assert False + return alpha_or_eta_form # alph_expr is alpha constraint specified in command line. If it is not None @@ -2269,16 +2408,23 @@ def compute_global_alpha_formula(self, alph_expr, model_inputs): #alph_form = self.compute_input_ranges_formula_alpha(model_inputs) #alph_form = self.smlp_true if alph_expr is None: - alpha_expr = self._specInst.get_spec_alpha_global_expr + alph_expr = self._specInst.get_spec_alpha_global_expr if alph_expr is None: - return self.smlp_true + return Solver._instance.smlp_true else: alph_expr_vars = get_expression_variables(alph_expr) dont_care_vars = list_subtraction_set(alph_expr_vars, model_inputs) if len(dont_care_vars) > 0: raise Exception('Variables ' + str(dont_care_vars) + ' in input constraints (alpha) are not part of the model') - alph_glob = self.ast_expr_to_term(alph_expr) + + alph_glob = Solver.parse_ast(parser=self.ast_expr_to_term, expression=alph_expr) + # if self._ENABLE_PYSMT: + # if self._RETURN_PYSMT: + # return self.parser.parse(alph_expr) + # else: + # print(self.parser.parse(alph_expr)) + return alph_glob #self._smlpTermsInst.smlp_and(alph_form, alph_glob) # The argument model_inps_outps is the union of model input and output varaiables. @@ -2290,18 +2436,18 @@ def compute_global_alpha_formula(self, alph_expr, model_inputs): # and are not dropped during data processing (see function SmlpData._prepare_data_for_modeling). def compute_beta_formula(self, beta_expr, model_inps_outps): if beta_expr is None: - return self.smlp_true + return Solver._instance.smlp_true else: beta_expr_vars = get_expression_variables(beta_expr) dont_care_vars = list_subtraction_set(beta_expr_vars, model_inps_outps) if len(dont_care_vars) > 0: raise Exception('Variables ' + str(dont_care_vars) + ' in optimization constraints (beta) are not part of the model') - return self.ast_expr_to_term(beta_expr) + return Solver.parse_ast(parser=self.ast_expr_to_term, expression=beta_expr) def compute_eta_formula(self, eta_expr, model_inputs): if eta_expr is None: - return self.smlp_true + return Solver._instance.smlp_true else: # eta_expr can only contain knobs (control inputs), not free inputs or outputs (responses) eta_expr_vars = get_expression_variables(eta_expr) @@ -2309,7 +2455,13 @@ def compute_eta_formula(self, eta_expr, model_inputs): if len(dont_care_vars) > 0: raise Exception('Variables ' + str(dont_care_vars) + ' in knob constraints (eta) are not part of the model') - return self.ast_expr_to_term(eta_expr) + # if self._ENABLE_PYSMT: + # if self._RETURN_PYSMT: + # return self.parser.parse(eta_expr) + # else: + # print(self.parser.parse(eta_expr)) + + return Solver.parse_ast(parser=self.ast_expr_to_term, expression=eta_expr) def var_domain(self, var, spec_domain_dict): interval = spec_domain_dict[var][self._SPEC_DOMAIN_INTERVAL_TAG]; #self._specInst.get_spec_interval_tag @@ -2349,19 +2501,27 @@ def create_model_exploration_base_components(self, syst_expr_dict:dict, algo, mo else: raise Exception('Data bounds file cannot be loaded') self._smlp_terms_logger.info('Parsing the SPEC: End') - + # get variable domains dictionary; certain sanity checks are performrd within this function. spec_domain_dict = self._specInst.get_spec_domain_dict; #print('spec_domain_dict', spec_domain_dict) - + + # self.verifier.initialize(variable_ranges=spec_domain_dict) + + + # contraints on features used as control variables and on the responses - alph_ranges = self.compute_input_ranges_formula_alpha_eta('alpha', feat_names); #print('alph_ranges') - alph_global = self.compute_global_alpha_formula(alph_expr, feat_names); #print('alph_global') - alpha = self.smlp_and(alph_ranges, alph_global); #print('alpha') - beta = self.compute_beta_formula(beta_expr, feat_names+resp_names); #print('beta') - eta_ranges = self.compute_input_ranges_formula_alpha_eta('eta', feat_names); #print('eta_ranges') - eta_grids = self.compute_grid_range_formulae_eta(); #print('eta_grids') - eta_global = self.compute_eta_formula(eta_expr, feat_names); #print('eta_global', eta_global) - eta = self.smlp_and_multi([eta_ranges, eta_grids, eta_global]); #print('eta', eta) + alph_ranges = self.compute_input_ranges_formula_alpha_eta('alpha', feat_names, + spec_domain_dict); # print('alph_ranges') + alph_global = self.compute_global_alpha_formula(alph_expr, feat_names); # print('alph_global') + alpha = Solver.smlp_and(alph_ranges, alph_global) + + beta = self.compute_beta_formula(beta_expr, feat_names + resp_names); # print('beta') + eta_ranges = self.compute_input_ranges_formula_alpha_eta('eta', feat_names, + spec_domain_dict); # print('eta_ranges') + eta_grids = self.compute_grid_range_formulae_eta() + eta_global = self.compute_eta_formula(eta_expr, feat_names); # print('eta_global', eta_global) + + eta = Solver.smlp_and_multi([eta_ranges, eta_grids, eta_global]) self._smlp_terms_logger.info('Alpha global constraints: ' + str(alph_global)) self._smlp_terms_logger.info('Alpha ranges constraints: ' + str(alph_ranges)) @@ -2548,30 +2708,34 @@ def create_model_exploration_instance_from_smlp_components(self, domain, model_f eq_form = self.smlp_eq(self.smlp_var(resp_name), resp_term); #print('eq_form', eq_form, type(eq_form)) base_solver.add(eq_form) return base_solver - + # wrapper function on solver.check to measure runtime and return status in a convenient way - def smlp_solver_check(self, solver, call_name:str, lemma_precision:int=0): + def smlp_solver_check(self, solver, call_name:str, lemma_precision:int=0, equations=None, temp=False): + if equations: + print('FORM2 FORMULA', equations) approx_lemmas = lemma_precision > 0 start = time.time() #print('solver chack start', flush=True) - res = solver.check() + res, witness = solver.check(temp=temp) #print('solver chack end', flush=True) end = time.time() - if isinstance(res, smlp.unknown): + status = Solver.convert_results_to_string(res) + + if status == "unknown": #print('smlp_unknown', smlp.unknown) - status = 'unknown' sat_model = {} - elif isinstance(res, smlp.sat): + elif status == "sat": #print('smlp_sat', smlp.sat) - status = 'sat' - sat_model = self.witness_term_to_const(res.model, approximate=False, precision=None) + witness = res.model if isinstance(solver, Form2_Solver) else witness["witness"] + sat_model = self.witness_term_to_const(witness, approximate=False, precision=None) if approx_lemmas: - sat_model_approx = self.approximate_witness_term(res.model, lemma_precision) + sat_model_approx = self.approximate_witness_term(witness, lemma_precision) + # return TextToPysmtParser.SAT #print('res.model', res.model, 'sat_model', sat_model) - elif isinstance(res, smlp.unsat): + elif status == "unsat": #print('smlp_unsat', smlp.unsat) - status = 'unsat' sat_model = {} + # return TextToPysmtParser.UNSAT else: raise Exception('Unexpected solver result ' + str(res)) @@ -2643,16 +2807,16 @@ def smlp_solver_check(self, solver, call_name:str, lemma_precision:int=0): #print('res.mode;', res.model, 'assignment', assignment, 'assignment_approx', assignment_approx); #return res, assignment_approx #print('exit smlp_solver_check', flush=True) - return res + return status, witness def solver_status_sat(self, res): - return isinstance(res, smlp.sat) + return res == "SAT" def solver_status_unsat(self, res): - return isinstance(res, smlp.unsat) + return res == "UNSAT" def solver_status_unknown(self, res): - return isinstance(res, smlp.unknown) + return res == "UNKNOWN" # we return value assignmenets to interface (input, knob, output) variables defined in the Spec file # (and not values assigned to any other variables that might be defined additionally as part of solver domain, @@ -2692,22 +2856,50 @@ def get_solver_resps_model(self, res): def check_alpha_eta_consistency(self, domain:smlp.domain, model_full_term_dict:dict, alpha:smlp.form2, eta:smlp.form2, solver_logic:str): #print('create solver: model', model_full_term_dict, flush=True) - solver = self.create_model_exploration_instance_from_smlp_components( - domain, model_full_term_dict, False, solver_logic) - #print('add alpha', alpha, flush=True) - solver.add(alpha); #print('alpha', alpha, flush=True) - solver.add(eta); #print('eta', eta) - #print('create check', flush=True) - #res = solver.check(); print('res', res, flush=True) - res = self.smlp_solver_check(solver, 'interface_consistency' if model_full_term_dict is None else 'model_consistency') + # if not self._RETURN_PYSMT: + # solver = self.create_model_exploration_instance_from_smlp_components( + # domain, model_full_term_dict, False, solver_logic) + # #print('add alpha', alpha, flush=True) + # solver.add(alpha); #print('alpha', alpha, flush=True) + # solver.add(eta); #print('eta', eta) + # #print('create check', flush=True) + # #res = solver.check(); print('res', res, flush=True) + # res = self.smlp_solver_check(solver, 'interface_consistency' if model_full_term_dict is None else 'model_consistency', equations={'alpha':alpha, 'eta':eta}) + # else: + # self.verifier.reset() + # self.verifier.apply_restrictions(alpha) + # self.verifier.apply_restrictions(eta) + # print('PYSMT FORMULA',{'alpha':alpha, 'eta':eta}) + # res, witness = self.verifier.solve() + + solver = Solver.create_solver( + create_solver=self.create_model_exploration_instance_from_smlp_components, + domain=domain, + model_full_term_dict=model_full_term_dict, + incremental=False, + solver_logic=solver_logic + ) + solver.add_formula(alpha) + solver.add_formula(eta) + + res,_ = self.smlp_solver_check(solver, 'interface_consistency' if model_full_term_dict is None else 'model_consistency', equations={'alpha':alpha, 'eta':eta}) + consistency_type = 'Input and knob' if model_full_term_dict is None else 'Model' - if isinstance(res, smlp.sat): + if res == "sat": self._smlp_terms_logger.info(consistency_type + ' interface constraints are consistent') interface_consistent = True - elif isinstance(res, smlp.unsat): + elif res == "unsat": self._smlp_terms_logger.info(consistency_type + ' interface constraints are inconsistent') interface_consistent = False else: raise Exception('alpha and eta cosnsistency check failed to complete') return interface_consistent + def add_integer_constraints(self): + for symbol, values in self._specInst.get_spec_domain_dict.items(): + if values['range'] == 'int': + ranges = values['interval'] + integer_formula = self.parser.create_integer_disjunction(f'{symbol}_unscaled', (ranges[0], ranges[-1])) + self.verifier.add_permanent_constraint(integer_formula) + + diff --git a/src/smlp_py/smtlib/__init__.py b/src/smlp_py/smtlib/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/src/smlp_py/smtlib/parser.py b/src/smlp_py/smtlib/parser.py new file mode 100755 index 00000000..3f407ae3 --- /dev/null +++ b/src/smlp_py/smtlib/parser.py @@ -0,0 +1,96 @@ +import ast +from pysmt.shortcuts import Symbol, And, Or, Not, Implies, Iff, Ite, Equals, Plus, Minus, Times, Div, Pow, Bool, TRUE, FALSE, Int, Real +from pysmt.typing import BOOL, REAL, INT + +pysmt_types = { + "int": INT, + "real": REAL, + "bool": BOOL +} + +class TextToPysmtParser(object): + def __init__(self): + self.symbols = {} + self._ast_operators_map = { + ast.Add: Plus, # Addition + ast.Sub: Minus, # Subtraction + ast.Mult: Times, # Multiplication + ast.Div: Div, # Division + ast.Pow: Pow, # Exponentiation + ast.BitXor: Iff, # Bitwise XOR (interpreted as logical Iff) + + ast.USub: Minus, # Unary subtraction (negation) + + ast.Eq: Equals, # Equality + ast.NotEq: Not, # Not equal + ast.Lt: lambda l, r: l < r, # Less than + ast.LtE: lambda l, r: l <= r, # Less than or equal to + ast.Gt: lambda l, r: l > r, # Greater than + ast.GtE: lambda l, r: l >= r, # Greater than or equal to + + ast.And: And, # Logical AND + ast.Or: Or, # Logical OR + ast.Not: Not, # Logical NOT + + ast.IfExp: Ite # If expression + } + + def add_symbol(self, name, symbol_type): + assert symbol_type in pysmt_types.keys() + self.symbols[name] = Symbol(name, pysmt_types[symbol_type]) + + def parse(self, expr): + assert isinstance(expr, str) + symbol_list = self.symbols + + def eval_(node): + if isinstance(node, ast.Num): + return Real(node.n) if isinstance(node.n, float) else Int(node.n) + elif isinstance(node, ast.BinOp): + return self._ast_operators_map[type(node.op)](eval_(node.left), eval_(node.right)) + elif isinstance(node, ast.UnaryOp): + return self._ast_operators_map[type(node.op)](eval_(node.operand)) + elif isinstance(node, ast.Name): + return symbol_list[node.id] + elif isinstance(node, ast.BoolOp): + res_boolop = self._ast_operators_map[type(node.op)](eval_(node.values[0]), eval_(node.values[1])) + for value in node.values[2:]: + res_boolop = self._ast_operators_map[type(node.op)](res_boolop, eval_(value)) + return res_boolop + elif isinstance(node, ast.Compare): + left = eval_(node.left) + first_comparator = eval_(node.comparators[0]) + result = self._ast_operators_map[type(node.ops[0])](left, first_comparator) + for op, comparator in zip(node.ops[1:], node.comparators[1:]): + left = eval_(comparator) + result = And(result, self._ast_operators_map[type(op)](left, eval_(comparator))) + return result + elif isinstance(node, ast.IfExp): + return self._ast_operators_map[ast.IfExp](eval_(node.test), eval_(node.body), eval_(node.orelse)) + elif isinstance(node, ast.Constant): + if node.value is True: + return TRUE() + elif node.value is False: + return FALSE() + elif isinstance(node.value, int): + return Int(node.value) + elif isinstance(node.value, float): + return Real(node.value) + else: + return node.value + else: + raise TypeError(f"Unexpected node type {type(node)}") + + return eval_(ast.parse(expr, mode='eval').body) + + +if __name__ == "__main__": + + parser = TextToPysmtParser() + parser.add_symbol('x1', 'int') + parser.add_symbol('x2', 'real') + parser.add_symbol('p2', 'real') + + formula = parser.parse('p2<5.0 and x1==10 and x2<12.0') + print(formula) + diff --git a/src/smlp_py/smtlib/smt_to_pysmt.py b/src/smlp_py/smtlib/smt_to_pysmt.py new file mode 100755 index 00000000..bc412854 --- /dev/null +++ b/src/smlp_py/smtlib/smt_to_pysmt.py @@ -0,0 +1,190 @@ +import re + +from pysmt.smtlib.parser import SmtLibParser +from pysmt.shortcuts import Symbol, simplify, get_env +from pysmt.typing import REAL, INT, BOOL +from io import StringIO +from pysmt.rewritings import CNFizer + +from pysmt.shortcuts import Symbol, And, LE, Real, Ite, Or, Not, LT, Equals, Plus, Minus, Times, Div +from pysmt.typing import REAL +from sympy import sympify + + +def smtlib_to_pysmt(smt_query, var_types): + """ + Converts an SMT-LIB query string to a PySMT formula. + + Parameters: + smt_query (str): The SMT-LIB query string. + var_types (dict): A dictionary mapping variable names to their types (REAL, INT, BOOL). + + Returns: + pysmt.shortcuts.FNode: The PySMT formula. + """ + # Initialize the SMT-LIB parser + parser = SmtLibParser() + + # Build the declarations for the variables + declarations = [] + for var, vtype in var_types.items(): + if vtype == 'REAL': + declarations.append(f"(declare-fun {var} () Real)") + Symbol(var, REAL) + elif vtype == 'INT': + declarations.append(f"(declare-fun {var} () Int)") + Symbol(var, INT) + elif vtype == 'BOOL': + declarations.append(f"(declare-fun {var} () Bool)") + Symbol(var, BOOL) + else: + raise ValueError(f"Unsupported variable type: {vtype}") + + # Join the declarations with the original SMT-LIB query + smt_query_with_declarations = "\n".join(declarations) + f"\n(assert {smt_query})" + + # Parse the SMT-LIB query + script = parser.get_script(StringIO(smt_query_with_declarations)) + + # Extract the formula from the script + formula = script.get_last_formula() + + # Simplify the parsed formula + simplified_formula = simplify(formula) + + return simplified_formula + + +def convert_fractions_to_floats(formula: str) -> str: + # Regular expression to find fractions in the format 'numerator/denominator' + fraction_pattern = re.compile(r'(\d+/\d+)') + + def fraction_to_float(match): + fraction = match.group() + numerator, denominator = map(int, fraction.split('/')) + return str(numerator / denominator) + + # Substitute all fractions in the formula with their float equivalents + formula_with_floats = fraction_pattern.sub(fraction_to_float, formula) + + return formula_with_floats + + +def convert_ternary_to_logic(formula: str) -> str: + # Regular expression to find ternary statements in the format 'condition ? true_expr : false_expr' + ternary_pattern = re.compile(r'\(([^()]+)\?\(([^()]+)\):\(([^()]+)\)\)') + + while ternary_pattern.search(formula): + formula = ternary_pattern.sub( + lambda match: f'(({match.group(1)}) & ({match.group(2)}) | (~({match.group(1)}) & ({match.group(3)})))', + formula) + + return formula + +def pysmt_convert_fractions_to_floats(term): + if term.is_constant() and term.is_real_constant(): + value = term.constant_value() + return Real(float(value)) + elif term.is_symbol(): + return term + elif term.is_plus() or term.is_minus() or term.is_times() or term.is_div(): + if term.node_type() == 12: + return Plus(*[pysmt_convert_fractions_to_floats(arg) for arg in term.args()]) + elif term.node_type() == 13: + return Minus(*[pysmt_convert_fractions_to_floats(arg) for arg in term.args()]) + elif term.node_type() == 14: + return Times(*[pysmt_convert_fractions_to_floats(arg) for arg in term.args()]) + elif term.node_type() == 15: + return Div(*[pysmt_convert_fractions_to_floats(arg) for arg in term.args()]) + return term + +def recursively_convert_ite(term): + if term.is_ite(): + condition = term.arg(0) + true_branch = recursively_convert_ite(term.arg(1)) + false_branch = recursively_convert_ite(term.arg(2)) + return Or(And(condition, true_branch), And(Not(condition), false_branch)) + elif term.is_and(): + return And(*[recursively_convert_ite(arg) for arg in term.args()]) + elif term.is_or(): + return Or(*[recursively_convert_ite(arg) for arg in term.args()]) + elif term.is_not(): + return Not(recursively_convert_ite(term.arg(0))) + elif term.is_le() or term.is_lt() or term.is_equals(): + left = recursively_convert_ite(term.arg(0)) + right = recursively_convert_ite(term.arg(1)) + if term.node_type() == 16: + return LE(left, right) + elif term.node_type() == 17: + return LT(left, right) + elif term.node_type() == 18: + return Equals(left, right) + else: + return pysmt_convert_fractions_to_floats(term) + +# node types: +# 12: + +# 13: - +# 14: / +# 15: * +# 16: <= +# 17: < +# 18: == +# 19: ITE + + +# Example usage +if __name__ == "__main__": + + + + # Define the SMT-LIB query as a string + # smt_query = "(let ((|:0| (* (/ 281474976710656 2944425288877159) (- y1 (/ 1080863910568919 4503599627370496))))) (let ((|:1| (* (/ 281474976710656 2559564553220679) (- (* (/ 1 2) (+ y1 y2)) (/ 1170935903116329 1125899906842624))))) (and (>= (ite (< |:0| |:1|) |:0| |:1|) 1) (and (>= y1 4) (>= y2 8)))))" + smt_query = "(and (and true (and (>= x1 0) (<= x1 10))) (and (>= x2 (- 1)) (<= x2 1)))" + # Define variable types + var_types = { + 'y1': 'REAL', + 'y2': 'REAL', + 'p1': 'REAL', + 'p2': 'REAL', + 'x1': 'REAL', + 'x2': 'REAL' + } + + # Convert SMT-LIB to PySMT + pysmt_formula = smtlib_to_pysmt(smt_query, var_types) + # + # pysmt_formula = recursively_convert_ite(pysmt_formula) + # + # print(pysmt_formula.serialize()) + + # pysmt_formula = pysmt_formula.serialize() + # # Print the PySMT formula + # print("Converted PySMT Formula:") + # print(pysmt_formula) + # + # pysmt_formula = convert_fractions_to_floats(pysmt_formula) + # print("Removed fractions:") + # print(pysmt_formula) + # + # pysmt_formula = recursively_convert_ite(sympify(pysmt_formula)) + # # pysmt_formula = convert_ternary_to_logic(pysmt_formula) + # print("Removed ITE:") + # print(pysmt_formula) + + + cnfizer = CNFizer() + cnf_formula = cnfizer.convert(pysmt_formula) + print("CNF PySMT Formula:") + print(cnf_formula) + + # Example: Add an additional constraint to the formula + # p1 = Symbol('p1', ) + # additional_constraint = p1 != 3 + # combined_formula = simplify(pysmt_formula & additional_constraint) + + # print("Combined Formula with Additional Constraint:") + # print + + +##################################################################################### diff --git a/src/smlp_py/smtlib/text_to_sympy.py b/src/smlp_py/smtlib/text_to_sympy.py new file mode 100755 index 00000000..fdfa4ac0 --- /dev/null +++ b/src/smlp_py/smtlib/text_to_sympy.py @@ -0,0 +1,765 @@ +import re + +import gmpy2 +import z3 +from pysmt import * +from sympy.logic.boolalg import And, Or, Not +from pysmt.shortcuts import Symbol, And, Or, Not, Implies, Iff, Ite, Equals, Plus, Minus, Times, Div, Pow, Bool, TRUE, \ + FALSE, Int, Real, simplify, LT, LE, GT, GE, ToReal +from pysmt.shortcuts import Or, Equals +from pysmt.fnode import FNode +from pysmt.typing import BOOL, REAL, INT +from pysmt.rewritings import CNFizer +from pysmt.walkers import IdentityDagWalker, DagWalker +import ast +import smlp +from pysmt.smtlib.script import smtlibscript_from_formula +from io import StringIO + +from typing import List, Dict, Optional, Tuple + +from z3 import Tactic, Goal + +pysmt_types = { + "int": INT, + "real": REAL, + "bool": BOOL +} + + +class Equation: + + def __init__(self, variable, operator: str, scalar: float): + self.variable = variable + self.operator = operator + self.scalar = scalar + self._eq = [str(variable),operator,str(scalar)] + + def __str__(self): + return "{0} {1} {2}".format(self.variable,self.operator,self.scalar) + + def __eq__(self, o: object) -> bool: + return str(self) == str(o) + + def lhs(self): + return self.variable + + def rhs(self): + return self.scalar + + def op(self): + return self.operator + + def __hash__(self) -> int: + return hash(self.variable) * hash(self.operator) * int(self.scalar) + +class InequalityChecker(DagWalker): + def __init__(self, env=None): + DagWalker.__init__(self, env=env) + self.is_inequality = False + self.contains_and_or = False + + def walk_and(self, formula, args, **kwargs): + self.contains_and_or = True + return formula + + def walk_or(self, formula, args, **kwargs): + self.contains_and_or = True + return formula + + def walk_le(self, formula, args, **kwargs): + self.is_inequality = True + return formula + + def walk_lt(self, formula, args, **kwargs): + self.is_inequality = True + return formula + + def walk_ge(self, formula, args, **kwargs): + self.is_inequality = True + return formula + + def walk_gt(self, formula, args, **kwargs): + self.is_inequality = True + return formula + +def check_inequality(formula): + checker = InequalityChecker() + checker.walk(formula) + return checker.is_inequality and not checker.contains_and_or + +class TextToPysmtParser(object): + SAT = "SAT" + UNSAT = "UNSAT" + types = pysmt_types + real = Real + true = TRUE + false = FALSE + _instance = None + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super(TextToPysmtParser, cls).__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + self.symbols = {} + self.inputs = [] + self.outputs = [] + self._ast_operators_map = { + ast.Add: Plus, # Addition + ast.Sub: Minus, # Subtraction + ast.Mult: Times, # Multiplication + ast.Div: self._div_op, # Division + ast.Pow: Pow, # Exponentiation + ast.BitXor: Iff, # Bitwise XOR (interpreted as logical Iff) + + ast.USub: lambda l: -l, # Unary subtraction (negation) + + ast.Eq: Equals, # Equality + ast.NotEq: Not, # Not equal + ast.Lt: lambda l, r: l < r, # Less than + ast.LtE: lambda l, r: l <= r, # Less than or equal to + ast.Gt: lambda l, r: l > r, # Greater than + ast.GtE: lambda l, r: l >= r, # Greater than or equal to + + ast.And: And, # Logical AND + ast.Or: Or, # Logical OR + ast.Not: Not, # Logical NOT + + ast.IfExp: Ite, # If expression + ast.Call: And, + 'If': Ite, + 'And': And, + 'Not': Not, + 'Or': Or + } + + def _div_op(self, left, right): + # Ensure both operands are real numbers for division + left = ToReal(left) if not left.is_real_constant() else left + right = ToReal(right) if not right.is_real_constant() else right + return Div(left, right) + + + @staticmethod + def and_(*expressions): + return And(*expressions) + + @staticmethod + def or_(*expressions): + return Or(*expressions) + + @staticmethod + def eq_(*expressions): + return Equals(*expressions) + + @staticmethod + def ite_(*expressions): + return Ite(*expressions) + + @staticmethod + def to_cnf(formula): + cnfizer = CNFizer() + cnf_formula = cnfizer.convert(formula) + return cnf_formula + + @staticmethod + def conjunction_to_disjunction(formula): + if formula.is_and(): + negated_terms = [Not(arg) for arg in formula.args()] + disjunction = Or(negated_terms) + return simplify(Not(disjunction)) + else: + raise ValueError("Input formula is not a conjunction") + + def is_comparison(self, node: FNode) -> bool: + return node.is_le() or node.is_lt() or node.is_equals() + + def create_integer_disjunction(self, variable, values): + variable = self.get_symbol(variable) + if not variable: + return None + + lower, upper = values + value_range = range(lower, upper + 1) + + return Or(*(Equals(variable, Real(val)) for val in value_range)) + + + def split_disjunctions(self, formula: FNode) -> list: + if formula.is_or(): + comparisons = [arg for arg in formula.args() if self.is_comparison(arg)] + if len(comparisons) == len(formula.args()): + return comparisons + elif self.is_comparison(formula): + return [formula] + else: + raise ValueError("Input formula is not a valid disjunction of comparisons") + return [] + + def opposite_comparator(self, comparator): + # sympy only uses LE and LT + # GE and GT are described using LE and LT and reversing the order of the symbol and number + if comparator == "<=": + return ">=" + elif comparator == "<": + return ">" + else: + return comparator + + def decide_comparator(self, formula): + node_type = formula.node_type() + if node_type == 16: + return "<=" + elif node_type == 17: + return "<" + elif node_type == 18: + return "=" + else: + return None + + def apply_comparator(self, comparator, a, b): + match comparator: + case '<=': + return a <= b + case '<': + return a < b + case "=": + return a == b + case _: + return None + + + def extract_coefficient(self, symbol): + coeff = [] + # possible formats + # 1) x-5 + # 2) a*x - 5 + for arg in symbol.args(): + if arg.is_constant(): + coeff.insert(0, arg) + elif arg.is_symbol(): + coeff.append(arg) + else: + pass + + return coeff + + def extract_components(self, comparison: FNode, need_simplification=False): + # if need_simplification: + # smtlib = self.extract_smtlib(comparison) + # comparison = self.handle_ite_formula(smtlib, handle_ite=False) + + left = comparison.arg(0) + right = comparison.arg(1) + + if not right.is_constant() and not left.is_constant(): + raise ValueError("The right-hand side of the formula must be a constant") + + comparator = self.decide_comparator(comparison) + + terms_subformula = left if right.is_constant() else right + + terms = [] + + def traverse(node): + if node.is_times(): + coeff, var = node.args() + if coeff.is_constant() and var.is_symbol(): + terms.append((float(coeff.constant_value()), var)) + elif var.is_constant() and coeff.is_symbol(): + terms.append((float(var.constant_value()), coeff)) + else: + raise ValueError("Invalid term structure in linear inequality") + elif node.is_plus(): + for arg in node.args(): + traverse(arg) + elif node.is_minus(): + left, right = node.args() + traverse(left) + if right.is_times(): + coeff, var = right.args() + if coeff.is_constant() and var.is_symbol(): + terms.append((-float(coeff.constant_value()), var)) + elif var.is_constant() and coeff.is_symbol(): + terms.append((-float(var.constant_value()), coeff)) + else: + raise ValueError("Invalid term structure in linear inequality") + else: + raise ValueError("Invalid term structure in linear inequality") + elif node.is_symbol(): + terms.append((1.0, node)) + elif node.is_constant(): + terms.append((node.constant_value(), Real(0))) + else: + raise ValueError("Unsupported node type in linear inequality") + + traverse(terms_subformula) + + if right.is_constant(): + scalar = float(right.constant_value()) + return terms, comparator, scalar + else: + scalar = float(left.constant_value()) + return terms, self.opposite_comparator(comparator), scalar + + def process_formula(self, formula: FNode): + components = [] + if formula.is_and(): + for arg in formula.args(): + components.extend(self.process_formula(arg)) + elif formula.is_or(): + print("Disjunction found, storing components:") + for arg in formula.args(): + if arg.is_and(): + components.extend(self.process_formula(arg)) + else: + components.append(arg) + elif self.is_comparison(formula): + components.append(formula) + else: + print("Other formula type encountered.") + + return components + + + def propagate_negation(self, formula): + """ + Apply negation to a formula and propagate the negation inside without leaving any negations in the formula. + """ + formula = self.simplify(formula) + if formula.is_not(): + return self.propagate_negation(formula.arg(0)) # Remove double negation if exists + + elif formula.is_and(): + # Apply De Morgan's law: not (A and B) -> (not A) or (not B) + return Or([self.propagate_negation(Not(arg)) for arg in formula.args()]) + + elif formula.is_or(): + # Apply De Morgan's law: not (A or B) -> (not A) and (not B) + return And([self.propagate_negation(Not(arg)) for arg in formula.args()]) + + elif formula.is_equals(): + # not (A = B) -> A != B + A, B = formula.args() + return And(LT(A, B), LT(B, A)) + + elif formula.is_lt(): + # not (A < B) -> A >= B + A, B = formula.args() + return LE(B,A) + + elif formula.is_le(): + # not (A <= B) -> A > B + A, B = formula.args() + return LT(B, A) + + elif formula.is_plus() or formula.is_times(): + # Propagate negation inside arithmetic operations + return formula + + elif formula.is_symbol() or formula.is_constant(): + # Apply negation directly to literals + return Not(formula) + + else: + raise NotImplementedError(f"Negation propagation not implemented for formula type: {formula}") + + def simplify(self, expression): + return simplify(expression) + + def cast_number(self, symbol_type, number): + if symbol_type == REAL: + return Real(number) + elif symbol_type == INT: + return Int(number) + + def init_variables(self, symbols: List[Tuple[str, str, bool]]) -> None: + for input_var in symbols: + name, type, is_input = input_var + unscaled_name = f"{name}_unscaled" + scaled_name = f"{name}_scaled" + # TODO: i replaced the type variable with real, make sure that's ok + self.add_symbol(name, 'real', is_input=is_input, nn_type=type) + self.add_symbol(unscaled_name, 'real', is_input=is_input, nn_type=type) + self.add_symbol(scaled_name, 'real', is_input=is_input, nn_type=type) + + def add_symbol(self, name, symbol_type, is_input=True, nn_type='real'): + assert symbol_type.lower() in pysmt_types.keys() + self.symbols[name] = Symbol(name, pysmt_types[symbol_type]) + + if name.find("_unscaled") == -1: + store = self.inputs if is_input else self.outputs + store.append((name, nn_type)) + + def get_symbol(self, name): + # assert name in self.symbols.keys() + if name not in self.symbols.keys(): + self.symbols[name] = Symbol(name, pysmt_types['real']) + + return self.symbols[name] + + def remove_first_and_last_line(self, text): + # Split the text into a list of lines + lines = text.split('\n') + + # Remove the first and last lines + if len(lines) > 1: + lines = lines[1:-2] + else: + # If there's only one line or no line, return an empty string + lines = [] + + # Join the remaining lines back into a single string + return '\n'.join(lines) + + def extract_smtlib(self, formula): + script = smtlibscript_from_formula(formula) + outstream = StringIO() + script.serialize(outstream) + output = outstream.getvalue() + return self.remove_first_and_last_line(output) + + def z3_simplify(self, formula): + if isinstance(formula, str): + smlp_str = formula + else: + smlp_str = self.extract_smtlib(formula) + smlp_parsed = z3.parse_smt2_string(smlp_str) + smlp_simplified = z3.simplify(smlp_parsed[0]) + ex = self.parse(str(smlp_simplified).replace('\n', '')) + if ex.is_not(): + ex = self.propagate_negation(ex) + + return ex + + def handle_ite_formula(self, formula, is_form2=False, handle_ite=True): + # smlp_str = self.extract_smtlib(formula) if not isinstance(formula, str) else formula + # smlp_str = f""" + # (declare-fun y1 () Real) + # (declare-fun y2 () Real) + # (assert {formula}) + # """ if not isinstance(formula, str) else formula + + if is_form2: + smlp_str = f""" + (declare-fun y1 () Real) + (declare-fun y2 () Real) + (assert {formula}) + """ + elif isinstance(formula, str): + smlp_str = formula + else: + smlp_str = self.extract_smtlib(formula) + + smlp_parsed = z3.parse_smt2_string(smlp_str) + # if flag: + # goal = Goal() + # goal.add(smlp_parsed) + # t = Tactic('tseitin-cnf') + # smlp_parsed = t(goal)[0] + + smlp_simplified = z3.simplify(smlp_parsed[0]) + ex = self.parse(str(smlp_simplified).replace('\n','')) + if ex.is_not(): + ex = self.propagate_negation(ex) + # ex = parser.replace_constants_with_floats_and_evaluate(ex) + if handle_ite: + marabou_formula = self.convert_ite_to_conjunctions_disjunctions(ex) + else: + marabou_formula = ex + return marabou_formula + + def replace_constants_with_floats_and_evaluate(self, formula: FNode) -> FNode: + def traverse(node: FNode) -> FNode: + if node.is_plus(): + left, right = node.args() + new_left = traverse(left) + new_right = traverse(right) + if new_left.is_constant() and new_right.is_constant(): + return Real(new_left.constant_value() + new_right.constant_value()) + return Plus(new_left, new_right) + elif node.is_minus(): + left, right = node.args() + new_left = traverse(left) + new_right = traverse(right) + if new_left.is_constant() and new_right.is_constant(): + return Real(new_left.constant_value() - new_right.constant_value()) + return Minus(new_left, new_right) + elif node.is_times(): + left, right = node.args() + new_left = traverse(left) + new_right = traverse(right) + if new_left.is_constant() and new_right.is_constant(): + return Real(new_left.constant_value() * new_right.constant_value()) + return Times(new_left, new_right) + elif node.is_div(): + left, right = node.args() + new_left = traverse(left) + new_right = traverse(right) + if new_left.is_constant() and new_right.is_constant(): + return Real(new_left.constant_value() / new_right.constant_value()) + return Div(new_left, new_right) + elif node.is_ite(): + condition, true_branch, false_branch = node.args() + new_condition = traverse(condition) + new_true_branch = traverse(true_branch) + new_false_branch = traverse(false_branch) + return Ite(new_condition, new_true_branch, new_false_branch) + elif node.is_le(): + left, right = node.args() + new_left = traverse(left) + new_right = traverse(right) + return LE(new_left, new_right) + elif node.is_lt(): + left, right = node.args() + new_left = traverse(left) + new_right = traverse(right) + return LT(new_left, new_right) + elif node.is_and(): + new_args = [traverse(arg) for arg in node.args()] + return And(new_args) + elif node.is_or(): + new_args = [traverse(arg) for arg in node.args()] + return Or(new_args) + elif node.is_constant(): + if isinstance(node.constant_value(), gmpy2.mpq): + return Real(float(node.constant_value())) + elif node.is_int_constant(): + return Real(float(node.constant_value())) + elif node.is_real_constant(): + return Real(node.constant_value()) + else: + return node + + return traverse(formula) + + def parse(self, expr): + assert isinstance(expr, str) + symbol_list = self.symbols + + def eval_(node): + if isinstance(node, ast.Num): + # return Real(node.n) if isinstance(node.n, float) else Int(node.n) + return Real(float(node.n)) + elif isinstance(node, ast.BinOp): + left = eval_(node.left) + right = eval_(node.right) + if left.is_constant() and right.is_constant(): + if isinstance(node.op, ast.Mult): + return Real(float(left.constant_value() * right.constant_value())) + elif isinstance(node.op, ast.Div): + return Real(float(left.constant_value() / right.constant_value())) + return self._ast_operators_map[type(node.op)](left, right) + elif isinstance(node, ast.UnaryOp): + operand = eval_(node.operand) + if operand.is_constant() and isinstance(node.op, ast.USub): + return Real(-operand.constant_value()) + return self._ast_operators_map[type(node.op)](operand) + elif isinstance(node, ast.Name): + return symbol_list[node.id] + elif isinstance(node, ast.BoolOp): + res_boolop = self._ast_operators_map[type(node.op)](eval_(node.values[0]), eval_(node.values[1])) + for value in node.values[2:]: + res_boolop = self._ast_operators_map[type(node.op)](res_boolop, eval_(value)) + return res_boolop + elif isinstance(node, ast.Compare): + left = eval_(node.left) + first_comparator = eval_(node.comparators[0]) + result = self._ast_operators_map[type(node.ops[0])](left, first_comparator) + for op, comparator in zip(node.ops[1:], node.comparators[1:]): + left = eval_(comparator) + result = And(result, self._ast_operators_map[type(op)](left, eval_(comparator))) + return result + elif isinstance(node, ast.Call): + func = node.func.id + args = [eval_(arg) for arg in node.args] + if func in self._ast_operators_map: + return self._ast_operators_map[func](*args) + else: + raise ValueError(f"Unsupported function call: {func}") + + elif isinstance(node, ast.IfExp): + return self._ast_operators_map[ast.IfExp](eval_(node.test), eval_(node.body), eval_(node.orelse)) + elif isinstance(node, ast.Constant): + if node.value is True: + return TRUE() + elif node.value is False: + return FALSE() + elif isinstance(node.value, int): + return Int(node.value) + elif isinstance(node.value, float): + return Real(node.value) + else: + return node.value + else: + raise TypeError(f"Unexpected node type {type(node)}") + + return eval_(ast.parse(expr, mode='eval').body) + + def convert_ite_to_conjunctions_disjunctions(self, formula): + def traverse(node, from_ite=False, value=0, position=None, comparator=None): + if from_ite: + condition, true_branch, false_branch = node.args() + true_branch = traverse(true_branch) + false_branch = traverse(false_branch) + condition = traverse(condition) + not_condition = traverse(Not(condition)) + + + true_branch = self.apply_comparator(comparator, true_branch, value) if position == "right" else self.apply_comparator(comparator, value, true_branch) + false_branch = self.apply_comparator(comparator, false_branch, value) if position == "right" else self.apply_comparator(comparator, value, false_branch) + + true_branch = self.z3_simplify(true_branch) + false_branch = self.z3_simplify(false_branch) + return Or( + And(condition, true_branch), + And(not_condition, false_branch) + ) + elif node.is_and(): + new_args = [traverse(arg) for arg in node.args()] + return And(new_args) + elif node.is_or(): + new_args = [traverse(arg) for arg in node.args()] + return Or(new_args) + elif node.is_not(): + return self.propagate_negation(node) + elif self.is_comparison(node): + left, right = node.args() + comparator = self.decide_comparator(node) + if left.is_constant() and right.is_ite(): + return traverse(right, from_ite=True, value=left, position='left', comparator=comparator) + elif right.is_constant() and left.is_ite(): + return traverse(left, from_ite=True, value=right, position='right', comparator=comparator) + else: + return node + else: + return node + + return traverse(formula) + +if __name__ == "__main__": + + parser = TextToPysmtParser() + parser.add_symbol('x1', 'int') + parser.add_symbol('x2', 'real') + parser.add_symbol('p2', 'real') + parser.add_symbol('p1', 'real') + parser.add_symbol('y1', 'real') + parser.add_symbol('y2', 'real') + + # Example usage + y1 = parser.get_symbol('y1') + y2 = parser.get_symbol('y1') + # Original formula: not(y1 >= 4.0 and y2 >= 8.0) + original_formula = Not(And(LE(Real(4.0), y1), LE(Real(8.0), y2))) + + negated_formula = parser.propagate_negation(original_formula) + + print(f"Original formula: {original_formula}") + print(f"Negated formula with propagated negation: {negated_formula}") + + def separate_conjunctions_and_disjunctions(formula): + conjunctions = [] + disjunctions = [] + + def traverse(node, source=None): + if node.is_and(): + # conjunctions.extend(node.args()) + for arg in node.args(): + traverse(arg, conjunctions) + elif node.is_or(): + disjunctions.extend(node.args()) + # Or(*[recursively_convert_ite(arg) for arg in term.args()]) + # for arg in node.args(): + # traverse(arg, disjunctions) + elif node.is_le() or node.is_lt() or node.is_equals(): + source.append(node) + source.append(node) + else: + # Leaf nodes (symbols, literals, etc.) are not conjunctions or disjunctions + pass + + traverse(formula) + return conjunctions, disjunctions + + + x = Symbol('x', REAL) + y = Symbol('y', REAL) + + # formula = And(LT(x, Real(10.0)), Or(LT(y, Real(10.0)), LT(y, Real(10.0)))) + # formula = And(LT(x, Real(10.0)), LT(y, Real(10.0))) + formula = ((-1 <= y) & (0.0 <= x) & (y <= 1) & (x <= 10.0)) + conjunctions, disjunctions = separate_conjunctions_and_disjunctions(formula) + + # Define symbols + + # Create a formula + formula = And(LT(x, Real(10.0)), Or(LT(y, Real(10.0)), LT(y, Real(10.0)))) + + # Apply external negation + negated_formula = Not(formula) + + # Simplify the formula + simplified_formula = simplify(negated_formula) + + print(simplified_formula) + + + formula = parser.parse("p1==4 or (p1==8 and p2 > 3)") + # formula = parser.parse('p2<5.0 and x1==10 and x2<12.0') + x = Symbol('x', REAL) + y = Symbol('y', REAL) + a = Symbol('a', INT) + b = Symbol('b', INT) + + form=parser.and_((x <= Real(5.0)), # x < 5.0 + Equals(y, Real(10.0)), + Equals(a, Int(3)), # a < 3 + Equals(b, Int(4))) # b = 4) + print(formula) + print(form) + ts = parser.conjunction_to_disjunction(form) + print(ts) + + x2 = Symbol('x2', REAL) + x3 = Symbol('x3', REAL) + + disjunction = Or(LT(x2, Real(12.0)), GE(x3, Real(5.0))) + + comparisons_array = parser.split_disjunctions(disjunction) + + for comparison in comparisons_array: + print(comparison) + + x = Symbol('x', REAL) + y = Symbol('y', REAL) + z = Symbol('z', REAL) + + # Example formula with conjunctions and disjunctions + # formula = And(Or(LT(x, Real(10.0)), GT(y, Real(5.0))), LE(z, Real(3.0))) + formula = And(LT(x, Real(10.0)), GE(y, Real(4.0))) + t = parser.to_cnf(formula) + + + def frozenset_to_formula(cnf): + clauses = [] + for clause in cnf: + literals = [] + for literal in clause: + if literal.is_not(): + literals.append(Not(literal.arg(0))) + else: + literals.append(literal) + clauses.append(Or(literals)) + return And(clauses) + + + # Reconstruct the formula + t = frozenset_to_formula(t) + t = parser.simplify(t) + # Process the formula + res = parser.process_formula(formula) + print(res) \ No newline at end of file diff --git a/src/smlp_py/solver.py b/src/smlp_py/solver.py new file mode 100644 index 00000000..70100f20 --- /dev/null +++ b/src/smlp_py/solver.py @@ -0,0 +1,17 @@ +import functools +import types +from abc import ABC, abstractmethod +from enum import Enum + +import pysmt +import smlp +from pysmt.shortcuts import Real + +from src.smlp_py.NN_verifiers.verifiers import MarabouVerifier +from src.smlp_py.smtlib.text_to_sympy import TextToPysmtParser +import operator as op +from pysmt.shortcuts import Symbol, And + + + + diff --git a/src/smlp_py/solvers/abstract_solver.py b/src/smlp_py/solvers/abstract_solver.py new file mode 100644 index 00000000..363ccfaa --- /dev/null +++ b/src/smlp_py/solvers/abstract_solver.py @@ -0,0 +1,115 @@ +import functools +from abc import ABC, abstractmethod +import smlp + +USE_CACHE = False + + +def conditional_cache(func): + """Custom decorator to conditionally apply @functools.cache.""" + if USE_CACHE: + # Apply caching + return functools.cache(func) + else: + # Return the original function without caching + return func + + +class ClassProperty: + def __init__(self, fget): + self.fget = fget + + def __get__(self, instance, owner): + return self.fget(owner) + + +class AbstractSolver(ABC): + + @abstractmethod + def create_query(self, *args, **kwargs): + pass + + @abstractmethod + def create_query_and_beta(self, *args, **kwargs): + pass + + @abstractmethod + def substitute_objective_with_witness(self, *args, **kwargs): + pass + + @abstractmethod + def generate_rad_term(self, *args, **kwargs): + pass + + @abstractmethod + def create_theta_form(self, *args, **kwargs): + pass + + @abstractmethod + def get_rad_term(self, *args, **kwargs): + pass + + @abstractmethod + def create_alpha_or_eta_form(self, *args, **kwargs): + pass + + @abstractmethod + def parse_ast(self, *args, **kwargs): + pass + + @abstractmethod + def create_solver(self, *args, **kwargs): + pass + + @abstractmethod + def add_formula(self, *args, **kwargs): + pass + + @abstractmethod + def check(self, *args, **kwargs): + pass + + @abstractmethod + def create_counter_example(self, *args, **kwargs): + pass + + @abstractmethod + def substitute(self, *args, **kwargs): + pass + + @abstractmethod + def handle_ite_formula(self, *args, **kwargs): + pass + + @abstractmethod + def calculate_eta_F_t(self, *args, **kwargs): + pass + + @abstractmethod + def apply_theta(self, *args, **kwargs): + pass + + def get_witness(self, *args, **kwargs): + result = kwargs["result"] + witness = kwargs["witness"] + interface = kwargs["interface"] + + condition = result == "sat" + + if condition: + reduced_model = dict((k, v) for k, v in witness.items() if k in interface) + return reduced_model + else: + return None + + def convert_results_to_string(self, res): + if isinstance(res, smlp.sat): + return "sat" + elif isinstance(res, smlp.unsat): + return "unsat" + elif isinstance(res, smlp.unknown): + return "unknown" + elif type(res) == str: + return res.lower() + else: + raise Exception("Unsupported result format") diff --git a/src/smlp_py/solvers/marabou/operations.py b/src/smlp_py/solvers/marabou/operations.py new file mode 100644 index 00000000..d00e646c --- /dev/null +++ b/src/smlp_py/solvers/marabou/operations.py @@ -0,0 +1,64 @@ +import pysmt.shortcuts +import smlp +from pysmt.fnode import FNode +from src.smlp_py.solvers.abstract_solver import ClassProperty, conditional_cache + + +class PYSMTOperations: + + @ClassProperty + def smlp_true(cls): + return pysmt.shortcuts.TRUE() + + @ClassProperty + def smlp_false(cls): + return pysmt.shortcuts.FALSE() + + @ClassProperty + def smlp_real(cls): + return pysmt.shortcuts.Real + + @ClassProperty + def smlp_integer(cls): + return pysmt.shortcuts.Int + + @conditional_cache # @functools.cache + def smlp_cnst(cls, const): + if isinstance(const, FNode): + return const + elif isinstance(const, str): + const = float(const) + return pysmt.shortcuts.Real(const) + + # logical not (logic negation) + @conditional_cache # @functools.cache + def smlp_not(cls, form: FNode): + return pysmt.shortcuts.Not(form) + + # logical and (conjunction) + @conditional_cache # @functools.cache + def smlp_and(cls, form1: FNode, form2: FNode): + return pysmt.shortcuts.And(form1, form2) # form1 & form2 + + def smlp_and_multi(cls, form_list: list[FNode]): + return pysmt.shortcuts.And(*form_list) + + # logical or (disjunction) + @conditional_cache # @functools.cache + def smlp_or(cls, form1: FNode, form2: FNode): + return pysmt.shortcuts.Or(form1, form2) + + def smlp_or_multi(cls, form_list: list[FNode]): + return pysmt.shortcuts.Or(*form_list) + + def smlp_eq(self, term1: smlp.term2, term2: smlp.term2): + return pysmt.shortcuts.Equals(term1, term2) + + def smlp_q(self, const): + return pysmt.shortcuts.Real(const) + + def smlp_mult(self, *args): + return pysmt.shortcuts.Times(*args) + + def smlp_ite(self, *args): + return pysmt.shortcuts.Ite(*args) diff --git a/src/smlp_py/solvers/marabou/solver.py b/src/smlp_py/solvers/marabou/solver.py new file mode 100644 index 00000000..aec03889 --- /dev/null +++ b/src/smlp_py/solvers/marabou/solver.py @@ -0,0 +1,190 @@ +import pysmt + +from src.smlp_py.NN_verifiers.verifiers import MarabouVerifier +from src.smlp_py.smtlib.text_to_sympy import TextToPysmtParser +from src.smlp_py.solvers.abstract_solver import AbstractSolver, ClassProperty +from src.smlp_py.solvers.marabou.operations import PYSMTOperations +from pysmt.shortcuts import Real +from memory_profiler import profile + +class Pysmt_Solver(AbstractSolver, PYSMTOperations): + verifier = None + temp_solver = None + + def __init__(self, specs,data_bounds_file, model_file_prefix): + super().__init__() + self.specs = specs + self.data_bounds_file = data_bounds_file + self.model_file_prefix = model_file_prefix + self.create_verifier() + + def create_verifier(self): + symbols = [] + feat_names, resp_names, spec_domain_dict = self.specs + + for feature in feat_names: + type = spec_domain_dict[feature]['range'] + symbols.append((feature, type, True)) + + for response in resp_names: + type = spec_domain_dict[response]['range'] + symbols.append((response, type, False)) + + parser = TextToPysmtParser() + parser.init_variables(symbols=symbols) + + self.verifier = MarabouVerifier(parser=parser, data_bounds_file=self.data_bounds_file, model_file_prefix=self.model_file_prefix) + self.verifier.initialize(spec_domain_dict) + + @ClassProperty + def smlp_true(self): + return pysmt.shortcuts.TRUE() + + def smlp_var(self, var): + return self.verifier.parser.get_symbol(var) + + def create_query(self, query_form=None): + self.verifier.parser.handle_ite_formula(query_form, is_form2=True) + + def create_query_and_beta(self, query, beta): + return self.verifier.parser.and_(query, beta) + + def substitute_objective_with_witness(self, *args, **kwargs): + stable_witness_terms = kwargs["stable_witness_terms"] + objv_term = kwargs["objv_term"] + + substitution = {} + for symbol, value in stable_witness_terms.items(): + symbol = self.verifier.parser.get_symbol(symbol) + substitution[symbol] = Real(float(value)) + # Apply the substitution + return self.verifier.parser.simplify(objv_term.substitute(substitution)) + + def generate_rad_term(self, **kwargs): + rad = kwargs["rad"] + return float(rad) + + def get_rad_term(self, **kwargs): + rad = kwargs["rad"] + return float(rad) + + def create_theta_form(self, **kwargs): + witness = kwargs["witness"] + var = kwargs["var"] + rad_term = kwargs["rad_term"] + theta_form = kwargs["theta_form"] + + rad_term = float(rad_term) + value = float(witness) + PYSMT_var = self.verifier.parser.get_symbol(var) + type = pysmt.shortcuts.Int if str(PYSMT_var.get_type()) == "Int" else Real + calc_type = int if str(PYSMT_var.get_type()) == "Int" else float + lower = calc_type(value - rad_term) + lower = type(lower) + upper = calc_type(value + rad_term) + upper = type(upper) + theta_form = self.verifier.parser.and_(theta_form, PYSMT_var >= lower, PYSMT_var <= upper) + return theta_form + + def create_alpha_or_eta_form(self, **kwargs): + alpha_or_eta_form = kwargs["alpha_or_eta_form"] + mx = kwargs["mx"] + mn = kwargs["mn"] + v = kwargs["v"] + + symbol_v = self.smlp_var(v) + form = self.smlp_and(symbol_v >= mn, symbol_v <= mx) + return self.simplify(self.smlp_and(alpha_or_eta_form, form)) + + def simplify(self, expression): + return self.verifier.parser.simplify(expression) + + def z3_simplify(self, expression): + return self.verifier.parser.simplify(expression) + + def parse(self, expression): + return self.verifier.parser.parse(expression) + + def GE(self, *args): + return args[0] >= args[1] + + def parse_ast(self, *args, **kwargs): + expression = kwargs['expression'] + return self.parse(expression) + + def create_solver(self, *args, **kwargs): + temp = kwargs.get('temp', False) + if temp: + self.temp_solver = MarabouVerifier(parser=self.verifier.parser, + variable_ranges=self.verifier.variable_ranges, + is_temp=True) + + else: + self.verifier.reset() + return self + + def add_formula(self, formula, **kwargs): + need_simplification = kwargs.get("need_simplification", False) + self.verifier.apply_restrictions(formula, need_simplification=need_simplification) + + def check(self, *args, **kwargs): + temp = kwargs.get("temp", False) + if temp: + result = self.temp_solver.solve() + self.temp_solver = None + return result + else: + return self.verifier.solve() + + def generate_theta(self, *args, **kwargs): + pass + + @profile + def create_counter_example(self, *args, **kwargs): + formulas = kwargs["formulas"] + query = kwargs["query"] + + self.temp_solver = MarabouVerifier( parser=self.verifier.parser, + variable_ranges=self.verifier.variable_ranges, + is_temp=True) + for formula in formulas: + self.temp_solver.apply_restrictions(formula) + + negation = self.temp_solver.parser.propagate_negation(query) + # z3_equiv = self.temp_solver.parser.handle_ite_formula(negation, handle_ite=False) + self.temp_solver.apply_restrictions(negation, need_simplification=True) + return self + + def substitute(self, *args, **kwargs): + var = kwargs["var"] + substitutions = kwargs["substitutions"] + + for x in list(substitutions.keys()): + temp = substitutions[x] + del substitutions[x] + substitutions[self.smlp_var(x)] = temp + + return self.simplify(var.substitute(substitutions)) + + def calculate_eta_F_t(self, *args, **kwargs): + eta = kwargs["eta"] + term = kwargs["term"] + val = kwargs["val"] + + return self.smlp_and(eta, term > self.smlp_cnst(val)) + + + def handle_ite_formula(self, *args, **kwargs): + formula = kwargs["formula"] + + return self.verifier.parser.handle_ite_formula(formula, is_form2=False) + + def apply_theta(self, *args, **kwargs): + formula = kwargs["formula"] + solver = kwargs["solver"] + + theta_negation = self.verifier.parser.propagate_negation(formula) + # self._modelTermsInst.verifier.add_permanent_constraint(theta_negation) + solver.verifier.apply_restrictions(theta_negation) + print("PYSMT THETA ADDED ", theta_negation) + diff --git a/src/smlp_py/solvers/universal_solver.py b/src/smlp_py/solvers/universal_solver.py new file mode 100644 index 00000000..4887f65e --- /dev/null +++ b/src/smlp_py/solvers/universal_solver.py @@ -0,0 +1,75 @@ +from enum import Enum +import types +from src.smlp_py.solvers.z3.solver import Form2_Solver +from src.smlp_py.solvers.marabou.solver import Pysmt_Solver + + +class Solver: + class Version(Enum): + FORM2 = 0 + PYSMT = 1 + + _instance = None + version = None + + def __new__(cls, *args, **kwargs): + version = kwargs["version"] + if isinstance(version, cls.Version): + cls.version = version + else: + raise ValueError("Must be a valid version") + + if cls._instance is None and isinstance(cls.version, cls.Version): + if cls.version == cls.Version.PYSMT: + specs = kwargs["specs"] + data_bounds_file = kwargs["data_bounds_file"] + model_file_prefix = kwargs["model_file_prefix"] + cls._instance = Pysmt_Solver(specs, data_bounds_file, model_file_prefix) + else: + cls._instance = Form2_Solver() + cls._map_instance_methods() + return cls._instance + + @classmethod + def _map_instance_methods(cls): + """Automatically maps all methods from the instance to the SingletonFactory class.""" + for base_class in cls._instance.__class__.__mro__: + for name, method in base_class.__dict__.items(): + if isinstance(method, types.FunctionType): + if not hasattr(cls, name): + setattr(cls, name, cls._create_delegator(name)) + + @classmethod + def _create_delegator(cls, method_name): + """Create a method that delegates the call to the _instance.""" + def delegator(*args, **kwargs): + return getattr(cls._instance, method_name)(*args, **kwargs) + return delegator + + + + + # + # @classmethod + # def _map_instance_properties(cls): + # """Automatically maps all properties from the instance to the Solver class.""" + # for name, attribute in cls._instance.__class__.__dict__.items(): + # if isinstance(attribute, property): + # # Map property to Solver class + # if not hasattr(cls, name): + # setattr(cls, name, cls._create_property_delegator(name)) + # + # @classmethod + # def _create_property_delegator(cls, property_name): + # """Create a property that delegates access to the _instance.""" + # def getter(self): + # return getattr(self._instance, property_name) + # + # def setter(self, value): + # setattr(self._instance, property_name, value) + # + # def deleter(self): + # delattr(self._instance, property_name) + # + # # Return a property with the mapped getter, setter, and deleter + # return property(getter, setter, deleter) diff --git a/src/smlp_py/solvers/z3/operations.py b/src/smlp_py/solvers/z3/operations.py new file mode 100644 index 00000000..195bfb4e --- /dev/null +++ b/src/smlp_py/solvers/z3/operations.py @@ -0,0 +1,191 @@ +import smlp +import operator as op +from src.smlp_py.solvers.abstract_solver import conditional_cache + + +class SMLPOperations: + @property + @conditional_cache # @functools.cache + def smlp_true(self): + return smlp.true + + @property + @conditional_cache # @functools.cache + def smlp_false(self): + return smlp.false + + @property + @conditional_cache # @functools.cache + def smlp_real(self): + return smlp.Real + + @property + @conditional_cache # @functools.cache + def smlp_integer(self): + return smlp.Integer + + @conditional_cache # @functools.cache + def smlp_var(self, var): + return smlp.Var(var) + + @conditional_cache # @functools.cache + def smlp_cnst(self, const): + return smlp.Cnst(const) + + # rationals + @conditional_cache # @functools.cache + def smlp_q(self, const): + return smlp.Q(const) + + # reals + @conditional_cache # @functools.cache + def smlp_r(self, const): + return smlp.R(const) + + # logical not (logic negation) + @conditional_cache # @functools.cache + def smlp_not(self, form: smlp.form2): + # res1 = ~form + res2 = op.inv(form) + # assert res1 == res2 + return res2 # ~form + + # logical and (conjunction) + @conditional_cache # @functools.cache + def smlp_and(self, form1: smlp.form2, form2: smlp.form2): + ''' test 83 gets stuck with this simplification + if form1 == smlp.true: + return form2 + if form2 == smlp.true: + return form1 + ''' + res1 = op.and_(form1, form2) + # res2 = form1 & form2 + # print('res1', res1, type(res1)); print('res2', res2, type(res2)) + # assert res1 == res2 + return res1 # form1 & form2 + + # conjunction of possibly more than two formulas + # @functools.cache -- error: unhashable type: 'list' + def smlp_and_multi(self, form_list: list[smlp.form2]): + res = self.smlp_true + ''' + for i, form in enumerate(form_list): + res = form if i == 0 else self.smlp_and(res, form) + ''' + for form in form_list: + res = form if res is self.smlp_true else self.smlp_and(res, form) + return res + + # logical or (disjunction) + @conditional_cache # @functools.cache + def smlp_or(self, form1: smlp.form2, form2: smlp.form2): + res1 = op.or_(form1, form2) + # res2 = form1 | form2 + # assert res1 == res2 + return res1 # form1 | form2 + + # disjunction of possibly more than two formulas + # @functools.cache -- error: unhashable type: 'list' + def smlp_or_multi(self, form_list: list[smlp.form2]): + res = self.smlp_false + ''' + for i, form in enumerate(form_list): + res = form if i == 0 else self.smlp_or(res, form) + ''' + for form in form_list: + res = form if res is self.smlp_false else self.smlp_or(res, form) + return res + + # logical implication + @conditional_cache # @functools.cache + def smlp_implies(self, form1: smlp.form2, form2: smlp.form2): + return self.smlp_or(self.smlp_not(form1), form2) + + # addition + @conditional_cache # @functools.cache + def smlp_add(self, term1: smlp.term2, term2: smlp.term2): + return op.add(term1, term2) + + # sum of possibly more than two formulas + # @functools.cache -- error: unhashable type: 'list' + def smlp_add_multi(self, term_list: list[smlp.term2]): + for i, term in enumerate(term_list): + res = term if i == 0 else self.smlp_add(res, term) + return res + + # subtraction + @conditional_cache # @conditional_cache #@functools.cache + def smlp_sub(self, term1: smlp.term2, term2: smlp.term2): + return op.sub(term1, term2) + + # multiplication + @conditional_cache # @functools.cache + def smlp_mult(self, term1: smlp.term2, term2: smlp.term2): + return op.mul(term1, term2) + + # TODO: !!! check that term2 does not evaluate to term 0 ??? + + # Do this before calling smlp_div, whenver possible? + @conditional_cache # @functools.cache + def smlp_div(self, term1: smlp.term2, term2: smlp.term2): + # return self.smlp_mult(self.smlp_cnst(self.smlp_q(1)) / term2, term1) + return self.smlp_mult(op.truediv(self.smlp_cnst(self.smlp_q(1)), term2), term1) + + @conditional_cache # @functools.cache + def smlp_pow(self, term1: smlp.term2, term2: smlp.term2): + return op.pow(term1, term2) + + # equality + @conditional_cache # @functools.cache + def smlp_eq(self, term1: smlp.term2, term2: smlp.term2): + res1 = op.eq(term1, term2) + # res2 = term1 == term2; print('res1', res1, 'res2', res2) + # assert res1 == res2 + return res1 + + # operator != (not equal) + @conditional_cache # @functools.cache + def smlp_ne(self, term1: smlp.term2, term2: smlp.term2): + res1 = op.ne(term1, term2) + # res2 = term1 != term2; print('res1', res1, 'res2', res2) + # assert res1 == res2 + return res1 + + # operator < + @conditional_cache # @functools.cache + def smlp_lt(self, term1: smlp.term2, term2: smlp.term2): + return op.lt(term1, term2) + + # operator <= + + @conditional_cache # @functools.cache + def smlp_le(self, term1: smlp.term2, term2: smlp.term2): + return op.le(term1, term2) + + # operator > + @conditional_cache # @functools.cache + def smlp_gt(self, term1: smlp.term2, term2: smlp.term2): + return op.gt(term1, term2) + + # operator >= + + @conditional_cache # @functools.cache + def smlp_ge(self, term1: smlp.term2, term2: smlp.term2): + return op.ge(term1, term2) + + # if-thne-else operation + @conditional_cache # @functools.cache + def smlp_ite(self, form: smlp.form2, term1: smlp.term2, term2: smlp.term2): + return smlp.Ite(form, term1, term2) + + # this function performs substitution of variables in term2: + # it substitutes occurrences of the keys in subst_dict with respective values, in term2 term. + # @functools.cache + def smlp_subst(self, term: smlp.term2, subst_dict: dict): + return smlp.subst(term, subst_dict) + + # simplifies a ground term to the respective constant; takes als a s + # @functools.cache + def smlp_cnst_fold(self, term: smlp.term2, subst_dict: dict): + return smlp.cnst_fold(term, subst_dict) diff --git a/src/smlp_py/solvers/z3/solver.py b/src/smlp_py/solvers/z3/solver.py new file mode 100644 index 00000000..90e23d58 --- /dev/null +++ b/src/smlp_py/solvers/z3/solver.py @@ -0,0 +1,143 @@ +import smlp +from src.smlp_py.solvers.abstract_solver import AbstractSolver +from src.smlp_py.solvers.z3.operations import SMLPOperations +from memory_profiler import profile + +class Form2_Solver(AbstractSolver, SMLPOperations): + verifier = None + + def __init__(self): + super().__init__() + # self.verifier = verifier + + @property + def smlp_true(self): + return smlp.true + + def create_query(self, query_form=None): + return query_form + + def create_query_and_beta(self, query, beta): + return self.smlp_and(query, beta) + + def substitute_objective_with_witness(self, *args, **kwargs): + stable_witness_terms = kwargs["stable_witness_terms"] + objv_term = kwargs["objv_term"] + + return smlp.subst(objv_term, stable_witness_terms) + + def generate_rad_term(self, *args, **kwargs): + rad = kwargs["rad"] + delta_rel = kwargs["delta_rel"] + var_term = kwargs["var_term"] + candidate = kwargs["candidate"] + + rad_term = self.smlp_cnst(rad) + if delta_rel is not None: # radius for a lemma -- cex holds values of candidate counter-example + rad_term = rad_term * abs(var_term) + else: # radius for excluding a candidate -- cex holds values of the candidate + rad_term = rad_term * abs(candidate) + + return rad_term + + def create_theta_form(self, *args, **kwargs): + theta_form = kwargs["theta_form"] + witness = kwargs["witness"] + var_term = kwargs["var_term"] + rad_term = kwargs["rad_term"] + + return self.smlp_and(theta_form, ((abs(var_term - witness)) <= rad_term)) + + def get_rad_term(self, *args, **kwargs): + rad = kwargs["rad"] + return self.smlp_cnst(rad) + + def create_alpha_or_eta_form(self, *args, **kwargs): + alpha_or_eta_form = kwargs["alpha_or_eta_form"] + is_in_spec = kwargs["is_in_spec"] + is_disjunction = kwargs["is_disjunction"] + is_alpha = kwargs["is_alpha"] + mx = kwargs["mx"] + mn = kwargs["mn"] + v = kwargs["v"] + + if is_disjunction and is_alpha and is_in_spec: + rng = self.smlp_or_multi([self.smlp_eq(self.smlp_var(v), self.smlp_cnst(i)) for i in range(mn, mx + 1)]) + else: + rng = self.smlp_and(self.smlp_var(v) >= self.smlp_cnst(mn), self.smlp_var(v) <= self.smlp_cnst(mx)) + + return self.smlp_and(alpha_or_eta_form, rng) + + def GE(self, *args): + return args[0] >= args[1] + + def parse_ast(self, *args, **kwargs): + expression = kwargs['expression'] + parser = kwargs['parser'] + return parser(expression) + + def create_solver(self, *args, **kwargs): + create_solver = kwargs["create_solver"] + domain = kwargs["domain"] + model_full_term_dict = kwargs["model_full_term_dict"] + incremental = kwargs["incremental"] + solver_logic = kwargs["solver_logic"] + + self.verifier = create_solver(domain, model_full_term_dict, incremental, solver_logic) + return self + + def simplify(self, *args, **kwargs): + formula = kwargs["formula"] + + return formula + + def z3_simplify(self, expression): + return expression + + + def add_formula(self,formula, **kwargs): + self.verifier.add(formula) + + def check(self, *args, **kwargs): + return self.verifier.check(), None + + def generate_theta(self, *args, **kwargs): + pass + + @profile + def create_counter_example(self, *args, **kwargs): + formulas = kwargs["formulas"] + query = kwargs["query"] + + self.create_solver(*args, **kwargs) + for formula in formulas: + self.add_formula(formula) + + self.add_formula(self.smlp_not(query)) + return self + + def substitute(self, *args, **kwargs): + var = kwargs["var"] + substitutions = kwargs["substitutions"] + + return self.smlp_cnst_fold(var, substitutions) + + def calculate_eta_F_t(self, *args, **kwargs): + eta = kwargs["eta"] + term = kwargs["term"] + val = kwargs["val"] + + return self.smlp_and(eta, term > self.smlp_cnst(val)) + + def handle_ite_formula(self, *args, **kwargs): + formula = kwargs["formula"] + + return formula + + def apply_theta(self, *args, **kwargs): + formula = kwargs["formula"] + solver = kwargs["solver"] + + solver.add_formula(self.smlp_not(formula)) + + diff --git a/src/smlp_py/train_keras.py b/src/smlp_py/train_keras.py index 12d79351..cad1a32e 100644 --- a/src/smlp_py/train_keras.py +++ b/src/smlp_py/train_keras.py @@ -38,8 +38,8 @@ def __init__(self): # hyper parameter defaults self._DEF_LAYERS_SPEC = '2,1' - self._DEF_EPOCHS = 2000 - self._DEF_BATCH_SIZE = 200 + self._DEF_EPOCHS = 200 + self._DEF_BATCH_SIZE = 10 self._DEF_OPTIMIZER = 'adam' # options: 'rmsprop', 'adam', 'sgd', 'adagrad', 'nadam' self._DEF_LEARNING_RATE = 0.001 self._HID_ACTIVATION = 'relu' diff --git a/src/smlp_py/vnnlib/__init__.py b/src/smlp_py/vnnlib/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/src/smlp_py/vnnlib/vnnlib_parser.py b/src/smlp_py/vnnlib/vnnlib_parser.py new file mode 100755 index 00000000..7da7f3e8 --- /dev/null +++ b/src/smlp_py/vnnlib/vnnlib_parser.py @@ -0,0 +1,356 @@ + +import re + +import sympy +from sympy import * +from sympy.logic.boolalg import And, Or, Not +from sympy.core.relational import Relational +from sympy.core.numbers import * +import ast +import smlp +import pysmt + +def variable_init_string(variable): + return f"(declare-const {variable} Real)\n" + + +def replace_let_vars(expression): + pattern = r"\(let \(\((\|:\d+\|) \((.+?)\)\)\) (.*)\)" + while 'let' in expression: + match = re.search(pattern, expression) + if not match: + break + + var_name, var_expr, rest_expression = match.groups() + + rest_expression = rest_expression.replace(var_name, f"({var_expr})") + + expression = rest_expression + + return expression + + +class VnnLibParser(object): + + def __init__(self, inputs=[], outputs=[]): + from maraboupy import Marabou, MarabouCore + + self.variables = {} + self.file_name = "query.vnnlib" + self.inputs = inputs + self.outputs = outputs + self.final_output = "" + self.sympy_expr = "" + self.symbols = None + self.symbol_dict = None + self.global_variable_constraints = {} + self.formula_list = [] + self.initialize() + + self.text_parser = TextToSympyParser(self.symbol_dict) + + def initialize(self): + self.initialize_variables(self.inputs, "X") + self.initialize_variables(self.outputs, "Y") + + symb = "" + for symbol in self.inputs: + symb += f"{symbol} " + + for symbol in self.outputs: + symb += f"{symbol} " + + symb = symb[:-1] + self.symbols = symbols(symb) + self.symbol_dict = {str(sym): sym for sym in self.symbols} + self.symbol_dict.update({ + 'And': And, + 'Or': Or, + 'Not': Not + }) + + + def replace_names(self): + for key, variable in self.variables.items(): + self.final_output = self.final_output.replace(key, variable) + + def initialize_variables(self, data, variable_name): + for i, arg in enumerate(data): + variable = f"{variable_name}_{i}" + self.variables[arg] = variable + self.final_output += variable_init_string(variable) + + def finalize(self): + self.replace_names() + with open(self.file_name, 'w') as file: + file.write(self.final_output) + self.call_marabou() + + + def add(self, expression): + expression = replace_let_vars(expression) + self.final_output += f"(assert \n ({expression}) \n )\n" + + def and_(self, exp1, exp2): + return sympy.logic.boolalg.And(exp1, exp2) + + def and_multi_(self, forms): + res = True + for form in forms: + res = form if res is True else self.and_(res, form) + return res + def or_(self, exp1, exp2): + return sympy.logic.boolalg.Or(exp1, exp2) + + def equal(self, var, num): + return sympy.And(var <= num, var >= num) + + def get_symbol(self, symbol): + return self.symbol_dict.get(symbol) + + def save(self, formula): + self.formula_list.append(formula) + + def not_(self, expression): + expression = replace_let_vars(expression) + + def sympy_and(self, exp1, exp2): + return f"And({exp1}, {exp2})" + + def add_global_constaints(self, variable, expression): + # check if the constraint exists in our list + assert variable in self.symbol_dict + + self.global_variable_constraints[variable] = expression + + + + def call_marabou(self): + onnx_file = "/home/ntinouldinho/Desktop/smlp/src/test.onnx" + property_filename = f"/home/ntinouldinho/Desktop/smlp/src/{self.file_name}" + + network = Marabou.read_onnx(onnx_file) + network.saveQuery("./query.txt") + ipq = Marabou.load_query("./query.txt") + MarabouCore.loadProperty(ipq, property_filename) + exitCode_ipq, vals_ipq, _ = Marabou.solve_query(ipq, propertyFilename=property_filename, filename="res.log") + + + + +class SymbolicExpressionHandler: + def __init__(self, variables): + self.symbols = symbols(variables) + self.symbol_dict = {str(sym): sym for sym in self.symbols} + self.symbol_dict.update({ + 'And': And, + 'Or': Or, + 'Not': Not + }) + + def parse_expression(self, expression): + return parse_expr(expression, local_dict=self.symbol_dict) + + def sympy_to_vnnlib(self, expr): + + def parse_expr(e): + # Handle logical operators + if isinstance(e, And): + return "(and {})".format(' '.join(parse_expr(arg) for arg in e.args)) + elif isinstance(e, Or): + return "(or {})".format(' '.join(parse_expr(arg) for arg in e.args)) + elif isinstance(e, Not): + return "(not {})".format(parse_expr(e.args[0])) + + elif isinstance(e, Relational): + left = parse_expr(e.lhs) + right = parse_expr(e.rhs) + op = e.rel_op + return "({} {} {})".format(op, left, right) + + elif isinstance(e, Abs): + return "(abs {})".format(parse_expr(e.args[0])) + + elif isinstance(e, (Symbol, Integer, Float, Zero, One)): + return str(e) + + else: + raise TypeError("Unsupported type: {}".format(type(e))) + + # Start the parsing + return parse_expr(expr) + + +class TextToSympyParser(object): + def __init__(self, map): + self.symbols = map + self._ast_operators_map = { + ast.Add: sympy.Add, # Addition + ast.Sub: sympy.Mul, # Subtraction (handles through Mul with -1) + ast.Mult: sympy.Mul, # Multiplication + ast.Div: sympy.div, # Division (true division, use sympy.Mul with sympy.Pow for reciprocal) + ast.Pow: sympy.Pow, # Exponentiation + ast.BitXor: sympy.Xor, # Bitwise XOR + + ast.USub: sympy.Mul, # Unary subtraction (negation, effectively multiplying by -1) + + ast.Eq: sympy.Eq, # Equality + ast.NotEq: sympy.Ne, # Not equal + ast.Lt: sympy.Lt, # Less than + ast.LtE: sympy.Le, # Less than or equal to + ast.Gt: sympy.Gt, # Greater than + ast.GtE: sympy.Ge, # Greater than or equal to + + ast.And: sympy.And, # Logical AND + ast.Or: sympy.Or, # Logical OR + ast.Not: sympy.Not, # Logical NOT + + ast.IfExp: sympy.ITE # If expression + } + + def parse(self, expr): + # print('evaluating AST expression ====', expr) + assert isinstance(expr, str) + symbol_list = self.symbols + + # recursion + def eval_(node): + if isinstance(node, ast.Num): # + # print('node Num', node.n, type(node.n)) + return sympy.Float(node.n) + elif isinstance(node, ast.BinOp): # + # print('node BinOp', node.op, type(node.op)) + if type(node.op) not in [ast.Div, ast.Pow]: + return self._ast_operators_map[type(node.op)](eval_(node.left), eval_(node.right)) + elif type(node.op) == ast.Div: + if type(node.right) == ast.Constant: + if node.right.n == 0: + raise Exception('Division by 0 in parsed expression ' + expr) + elif not isinstance(node.right.n, int): + raise Exception( + 'Division in parsed expression is only supported for integer constants; got ' + expr) + else: + # print('node.right.n', node.right.n, type(node.right.n)) + return self._ast_operators_map[ast.Mult](smlp.Cnst(smlp.Q(1) / smlp.Q(node.right.n)), + eval_(node.left)) + else: + raise Exception('Opreator ' + str(self._ast_operators_map[type(node.op)]) + + ' with non-constant demominator within ' + str( + expr) + ' is not supported in ast_expr_to_term') + elif type(node.op) == ast.Pow: + if type(node.right) == ast.Constant: + if type(node.right.n) == int: + # print('node.right.n', node.right.n) + if node.right.n == 0: + return sympy.Float(1) + elif node.right.n > 0: + left_term = res_pow = eval_(node.left) + for i in range(1, node.right.n): + res_pow = sympy.Mul(res_pow, left_term) + # print('res_pow', res_pow) + return res_pow + raise Exception('Opreator ' + str(self._ast_operators_map[type(node.op)]) + + ' with non-constant or negative exponent within ' + + str(expr) + 'is not supported in ast_expr_to_term') + else: + raise Exception('Implementation error in function ast_expr_to_term') + elif isinstance(node, ast.UnaryOp): # e.g., -1 + # print('unary op', node.op, type(node.op)); + return self._ast_operators_map[type(node.op)](eval_(node.operand)) + elif isinstance(node, ast.Name): # variable + # print('node Var', node.id, type(node.id)) + return symbol_list[node.id] + elif isinstance(node, ast.BoolOp): + res_boolop = self._ast_operators_map[type(node.op)](eval_(node.values[0]), eval_(node.values[1])) + if len(node.values) > 2: + for i in range(2, len(node.values)): + res_boolop = self._ast_operators_map[type(node.op)](res_boolop, eval_(node.values[i])) + # print('res_boolop', res_boolop) + return res_boolop + elif isinstance(node, ast.Compare): + # print('node Compare', node.ops, type(node.ops), 'left', node.left, 'comp', node.comparators); + # print('len ops', len(node.ops), 'len comparators', len(node.comparators)) + assert len(node.ops) == len(node.comparators) + left_term_0 = eval_(node.left) + right_term_0 = eval_(node.comparators[0]) + if type(node.ops[0]) == ast.Eq: + # if x==10 then x<=10 and x>=10 + if type(left_term_0) == sympy.Symbol and type(right_term_0) == sympy.Float: + res_comp = sympy.And(left_term_0 <= right_term_0, left_term_0 >= right_term_0) + else: + res_comp = self._ast_operators_map[type(node.ops[0])](left_term_0, + right_term_0); # print('res_comp_0', res_comp) + if len(node.ops) > 1: + # print('enum', list(range(1, len(node.ops)))) + left_term_i = right_term_0 + for i in range(1, len(node.ops)): + right_term_i = eval_(node.comparators[i]) + # print('i', i, 'left', left_term_i, 'right', right_term_i) + res_comp_i = self._ast_operators_map[type(node.ops[i])](left_term_i, right_term_i) + res_comp = sympy.And(res_comp, res_comp_i) # self._ast_operators_map[type(node.op.And)] + # for the next iteration (if any): + left_term_i = right_term_i + # print('res_comp', res_comp) + return res_comp + elif isinstance(node, ast.List): + # print('node List', 'elts', node.elts, type(node.elts), 'expr_context', node.expr_context); + raise Exception('Parsing expressions with lists is not supported') + elif isinstance(node, ast.Constant): + if node.n == True: + return True + if node.n == False: + return False + raise Exception('Unsupported comstant ' + str(node.n) + ' in funtion ast_expr_to_term') + elif isinstance(node, ast.IfExp): + res_test = eval_(node.test) + res_body = eval_(node.body) + res_orelse = eval_(node.orelse) + # res_ifexp = smlp.Ite(res_test, res_body, res_orelse) + res_ifexp = self._ast_operators_map[ast.IfExp](res_test, res_body, res_orelse) + # print('res_ifexp',res_ifexp) + return res_ifexp + else: + print('Unexpected node type ' + str(type(node))) + # print('node type', type(node)) + raise TypeError(node) + + return eval_(ast.parse(expr, mode='eval').body) + + +if __name__ == "__main__": + # variable_names = 'x1 x2 p1 p2 y1 y2' + # handler = SymbolicExpressionHandler(variable_names) + # expr_string = "Or(And(x1<=0, y1>9), And(x2<7, y2>10))" + # parsed_expr = handler.parse_expression(expr_string) + # b = Not(parsed_expr) + # b = simplify_logic(b) + # print(b) + # + from sympy import symbols, parse_expr, And + from sympy.parsing.sympy_parser import parse_expr, standard_transformations + + # Define symbols + p2, x1, x2 = symbols('p2 x1 x2') + + # Define the expression string + expression_str = "(p2 < 5) & (x1 == 10) & (x2 < 12)" + + # Standard transformations + transformations = standard_transformations + + # Parse the expression + expr = parse_expr(expression_str, transformations=transformations, local_dict={'p2': p2, 'x1': x1, 'x2': x2}) + + # Print the parsed expression + print(expr) + + + + + + + + + + +