diff --git a/include/souper/Extractor/Solver.h b/include/souper/Extractor/Solver.h index 3eff5543a..7e73cecb4 100644 --- a/include/souper/Extractor/Solver.h +++ b/include/souper/Extractor/Solver.h @@ -30,11 +30,11 @@ class Solver { public: virtual ~Solver(); virtual std::error_code - infer(const BlockPCs &BPCs, const std::vector &PCs, - Inst *LHS, Inst *&RHS, InstContext &IC) = 0; + infer(const BlockPCs &BPCs, const std::vector &PCs, + Inst *LHS, std::vector &RHSs, InstContext &IC) = 0; virtual std::error_code isValid(InstContext &IC, const BlockPCs &BPCs, - const std::vector &PCs, + const std::vector &PCs, InstMapping Mapping, bool &IsValid, std::vector> *Model) = 0; virtual std::string getName() = 0; diff --git a/include/souper/Infer/ExhaustiveSynthesis.h b/include/souper/Infer/ExhaustiveSynthesis.h index 5c184741b..bc2a2262f 100644 --- a/include/souper/Infer/ExhaustiveSynthesis.h +++ b/include/souper/Infer/ExhaustiveSynthesis.h @@ -30,7 +30,7 @@ class ExhaustiveSynthesis { std::error_code synthesize(SMTLIBSolver *SMTSolver, const BlockPCs &BPCs, const std::vector &PCs, - Inst *TargetLHS, Inst *&RHS, + Inst *TargetLHS, std::vector &RHS, InstContext &IC, unsigned Timeout); }; diff --git a/lib/Extractor/Solver.cpp b/lib/Extractor/Solver.cpp index ef786089d..36a88099f 100644 --- a/lib/Extractor/Solver.cpp +++ b/lib/Extractor/Solver.cpp @@ -75,7 +75,8 @@ class BaseSolver : public Solver { std::error_code infer(const BlockPCs &BPCs, const std::vector &PCs, - Inst *LHS, Inst *&RHS, InstContext &IC) override { + Inst *LHS, std::vector &RHSs, + InstContext &IC) override { std::error_code EC; /* @@ -103,7 +104,7 @@ class BaseSolver : public Solver { if (EC) return EC; if (!IsSat) { - RHS = I; + RHSs.emplace_back(I); return EC; } } @@ -142,7 +143,7 @@ class BaseSolver : public Solver { if (EC) return EC; if (!IsSat) { - RHS = Const; + RHSs.emplace_back(Const); return EC; } } @@ -188,7 +189,7 @@ class BaseSolver : public Solver { if (EC) return EC; if (!SmallQueryIsSat) { - RHS = I; + RHSs.emplace_back(I); break; } } @@ -214,18 +215,21 @@ class BaseSolver : public Solver { if(SMTSolver->supportsModels()) { if (EnableExhaustiveSynthesis) { ExhaustiveSynthesis ES; - EC = ES.synthesize(SMTSolver.get(), BPCs, PCs, LHS, RHS, IC, Timeout); - if (EC || RHS) + EC = ES.synthesize(SMTSolver.get(), BPCs, PCs, LHS, RHSs, + IC, Timeout); + if (EC || !RHSs.empty()) return EC; } else if (InferInsts) { InstSynthesis IS; + Inst *RHS; EC = IS.synthesize(SMTSolver.get(), BPCs, PCs, LHS, RHS, IC, Timeout); + RHSs.emplace_back(RHS); if (EC || RHS) return EC; } } - RHS = 0; + RHSs.clear(); return EC; } @@ -281,16 +285,17 @@ class MemCachingSolver : public Solver { std::error_code infer(const BlockPCs &BPCs, const std::vector &PCs, - Inst *LHS, Inst *&RHS, InstContext &IC) override { + Inst *LHS, std::vector &RHSs, + InstContext &IC) override { ReplacementContext Context; std::string Repl = GetReplacementLHSString(BPCs, PCs, LHS, Context); const auto &ent = InferCache.find(Repl); if (ent == InferCache.end()) { ++MemMissesInfer; - std::error_code EC = UnderlyingSolver->infer(BPCs, PCs, LHS, RHS, IC); + std::error_code EC = UnderlyingSolver->infer(BPCs, PCs, LHS, RHSs, IC); std::string RHSStr; - if (!EC && RHS) { - RHSStr = GetReplacementRHSString(RHS, Context); + if (!EC && !RHSs.empty()) { + RHSStr = GetReplacementRHSString(RHSs.front(), Context); } InferCache.emplace(Repl, std::make_pair(EC, RHSStr)); return EC; @@ -299,12 +304,12 @@ class MemCachingSolver : public Solver { std::string ES; StringRef S = ent->second.second; if (S == "") { - RHS = 0; + RHSs.clear(); } else { ParsedReplacement R = ParseReplacementRHS(IC, "", S, Context, ES); if (ES != "") return std::make_error_code(std::errc::protocol_error); - RHS = R.Mapping.RHS; + RHSs.emplace_back(R.Mapping.RHS); } return ent->second.first; } @@ -351,7 +356,8 @@ class ExternalCachingSolver : public Solver { std::error_code infer(const BlockPCs &BPCs, const std::vector &PCs, - Inst *LHS, Inst *&RHS, InstContext &IC) override { + Inst *LHS, std::vector &RHSs, + InstContext &IC) override { ReplacementContext Context; std::string LHSStr = GetReplacementLHSString(BPCs, PCs, LHS, Context); if (LHSStr.length() > MaxLHSSize) @@ -360,26 +366,26 @@ class ExternalCachingSolver : public Solver { if (KV->hGet(LHSStr, "result", S)) { ++ExternalHits; if (S == "") { - RHS = 0; + RHSs.clear(); } else { std::string ES; ParsedReplacement R = ParseReplacementRHS(IC, "", S, Context, ES); if (ES != "") return std::make_error_code(std::errc::protocol_error); - RHS = R.Mapping.RHS; + RHSs.emplace_back(R.Mapping.RHS); } return std::error_code(); } else { ++ExternalMisses; if (NoInfer) { - RHS = 0; + RHSs.clear(); KV->hSet(LHSStr, "result", ""); return std::error_code(); } - std::error_code EC = UnderlyingSolver->infer(BPCs, PCs, LHS, RHS, IC); + std::error_code EC = UnderlyingSolver->infer(BPCs, PCs, LHS, RHSs, IC); std::string RHSStr; - if (!EC && RHS) { - RHSStr = GetReplacementRHSString(RHS, Context); + if (!EC && !RHSs.empty()) { + RHSStr = GetReplacementRHSString(RHSs.front(), Context); } KV->hSet(LHSStr, "result", RHSStr); return EC; diff --git a/lib/Infer/ExhaustiveSynthesis.cpp b/lib/Infer/ExhaustiveSynthesis.cpp index 09a260a1a..9eb962207 100644 --- a/lib/Infer/ExhaustiveSynthesis.cpp +++ b/lib/Infer/ExhaustiveSynthesis.cpp @@ -22,6 +22,7 @@ static const unsigned MaxTries = 30; static const unsigned MaxInputSpecializationTries = 2; static const unsigned MaxLHSCands = 15; +static const unsigned MaxRHS = 1; using namespace souper; using namespace llvm; @@ -361,7 +362,7 @@ std::error_code ExhaustiveSynthesis::synthesize(SMTLIBSolver *SMTSolver, const BlockPCs &BPCs, const std::vector &PCs, - Inst *LHS, Inst *&RHS, + Inst *LHS, std::vector &RHSs, InstContext &IC, unsigned Timeout) { std::vector Inputs; findCands(LHS, Inputs, /*WidthMustMatch=*/false, /*FilterVars=*/false, MaxLHSCands); @@ -581,8 +582,9 @@ ExhaustiveSynthesis::synthesize(SMTLIBSolver *SMTSolver, llvm::errs() << "second query is UNSAT-- works for all values of this constant\n"; llvm::errs() << Tries << " tries were made for synthesizing constants\n"; } - RHS = I2; - return EC; + RHSs.emplace_back(I2); + if (RHSs.size() >= MaxRHS) + return EC; } } if (DebugLevel > 2) diff --git a/lib/Pass/Pass.cpp b/lib/Pass/Pass.cpp index 1792d99fb..ace5a62b5 100644 --- a/lib/Pass/Pass.cpp +++ b/lib/Pass/Pass.cpp @@ -399,9 +399,10 @@ struct SouperPass : public ModulePass { Changed = true; continue; } + std::vector RHSs; if (std::error_code EC = S->infer(Cand.BPCs, Cand.PCs, Cand.Mapping.LHS, - Cand.Mapping.RHS, IC)) { + RHSs, IC)) { if (EC == std::errc::timed_out || EC == std::errc::value_too_large) { continue; @@ -409,6 +410,7 @@ struct SouperPass : public ModulePass { report_fatal_error("Unable to query solver: " + EC.message() + "\n"); } } + Cand.Mapping.RHS = RHSs.empty() ? 0 : RHSs.front(); if (!Cand.Mapping.RHS) continue; diff --git a/lib/Tool/CandidateMapUtils.cpp b/lib/Tool/CandidateMapUtils.cpp index 060466ed6..92b8e5c0e 100644 --- a/lib/Tool/CandidateMapUtils.cpp +++ b/lib/Tool/CandidateMapUtils.cpp @@ -83,16 +83,16 @@ bool SolveCandidateMap(llvm::raw_ostream &OS, CandidateMap &M, Cand.PCs, Cand.Mapping.LHS, Context), HField, 1); } - Inst *RHS; + std::vector RHSs; if (std::error_code EC = - S->infer(Cand.BPCs, Cand.PCs, Cand.Mapping.LHS, RHS, IC)) { + S->infer(Cand.BPCs, Cand.PCs, Cand.Mapping.LHS, RHSs, IC)) { llvm::errs() << "Unable to query solver: " << EC.message() << '\n'; return false; } - if (RHS) { + if (!RHSs.empty()) { OS << '\n'; OS << "; Static profile " << Profile[I] << '\n'; - Cand.Mapping.RHS = RHS; + Cand.Mapping.RHS = RHSs.front(); Cand.printFunction(OS); Cand.print(OS); } @@ -121,14 +121,14 @@ bool CheckCandidateMap(llvm::Module &Mod, CandidateMap &M, Solver *S, bool OK = true; for (auto &Cand : M) { - Inst *RHS; + std::vector RHSs; if (std::error_code EC = - S->infer(Cand.BPCs, Cand.PCs, Cand.Mapping.LHS, RHS, IC)) { + S->infer(Cand.BPCs, Cand.PCs, Cand.Mapping.LHS, RHSs, IC)) { llvm::errs() << "Unable to query solver: " << EC.message() << '\n'; return false; } - if (RHS) { - Cand.Mapping.RHS = RHS; + if (!RHSs.empty()) { + Cand.Mapping.RHS = RHSs.front(); if (Cand.Mapping.RHS->K != Inst::Const) { llvm::errs() << "found replacement:\n"; Cand.printFunction(llvm::errs()); diff --git a/tools/souper-check.cpp b/tools/souper-check.cpp index 984c6dfef..33135747c 100644 --- a/tools/souper-check.cpp +++ b/tools/souper-check.cpp @@ -83,12 +83,14 @@ int SolveInst(const MemoryBufferRef &MB, Solver *S) { OldCost = cost(Rep.Mapping.RHS); Rep.Mapping.RHS = 0; } + std::vector RHSs; if (std::error_code EC = S->infer(Rep.BPCs, Rep.PCs, Rep.Mapping.LHS, - Rep.Mapping.RHS, IC)) { + RHSs, IC)) { llvm::errs() << EC.message() << '\n'; Ret = 1; ++Error; } + Rep.Mapping.RHS = RHSs.empty() ? 0 : RHSs.front(); if (Rep.Mapping.RHS) { ++Success; if (ReInferRHS) {