diff --git a/src/cpu/pred/BranchPredictor.py b/src/cpu/pred/BranchPredictor.py index 297b7a92b0..3d4d59d259 100644 --- a/src/cpu/pred/BranchPredictor.py +++ b/src/cpu/pred/BranchPredictor.py @@ -1063,16 +1063,32 @@ class BTBTAGE(TimedBaseBTBPredictor): numDelay = 2 usingMbtbBaseEiterTage = Param.Bool(True, "Whether using MBTB basetable either TAGE ") -class MicroTAGE(BTBTAGE): - """A smaller TAGE predictor configuration to assist uBTB""" - enableSC = Param.Bool(False, "Enable SC or not") # TODO: BTBTAGE doesn't support SC - numPredictors = 1 - tableSizes = [512] - TTagBitSizes = [16] - TTagPcShifts = [1] - - histLengths = [16] - numDelay = 0 +class MicroTAGE(TimedBaseBTBPredictor): + """Micro-sized BTB TAGE predictor used alongside uBTB""" + type = 'MicroTAGE' + cxx_class = 'gem5::branch_prediction::btb_pred::MicroTAGE' + cxx_header = "cpu/pred/btb/microtage.hh" + + needMoreHistories = Param.Bool(True, "MicroTAGE needs more histories") + updateOnRead = Param.Bool(False,"Enable update on read, no need to save tage meta in FTQ") + # Keep vector parameters consistent with numPredictors to avoid constructor asserts. 
+ numPredictors = Param.Unsigned(1, "Number of TAGE predictors") + tableSizes = VectorParam.Unsigned([512],"the TAGE T0~Tn length") + TTagBitSizes = VectorParam.Unsigned([16] ,"the T0~Tn entry's tag bit size") + TTagPcShifts = VectorParam.Unsigned([1] ,"when the T0~Tn entry's tag generating, PC right shift") + blockSize = Param.Unsigned(32,"tage index function uses 32B aligned block address") + + histLengths = VectorParam.Unsigned([18],"the BTB TAGE T0~Tn history length") + maxHistLen = Param.Unsigned(970,"The length of history passed from DBP") + numTablesToAlloc = Param.Unsigned(1,"The number of table to allocated each time") + numWays = Param.Unsigned(2, "Number of ways per set") + baseTableSize = Param.Unsigned(256,"Base table size") + maxBranchPositions = Param.Unsigned(32,"Maximum branch positions per 64-byte block") + useAltOnNaSize = Param.Unsigned(128,"Size of the useAltOnNa table") + useAltOnNaWidth = Param.Unsigned(7,"Width of the useAltOnNa table") + numBanks = Param.Unsigned(4,"Number of banks for bank conflict simulation") + enableBankConflict = Param.Bool(False,"Enable bank conflict simulation") + numDelay = Param.Unsigned(0,"Prediction latency in cycles") class BTBITTAGE(TimedBaseBTBPredictor): type = 'BTBITTAGE' @@ -1162,7 +1178,7 @@ class DecoupledBPUWithBTB(BranchPredictor): numStages = Param.Unsigned(4, "Maximum number of stages in the pipeline") ubtb = Param.UBTB(UBTB(), "UBTB predictor") abtb = Param.AheadBTB(AheadBTB(), "ABTB predictor") - microtage = Param.BTBTAGE(MicroTAGE(), "MicroTAGE predictor to assist uBTB") + microtage = Param.MicroTAGE(MicroTAGE(), "MicroTAGE predictor to assist uBTB") mbtb = Param.MBTB(MBTB(), "MBTB predictor") tage = Param.BTBTAGE(BTBTAGE(), "TAGE predictor") ittage = Param.BTBITTAGE(BTBITTAGE(), "ITTAGE predictor") diff --git a/src/cpu/pred/SConscript b/src/cpu/pred/SConscript index 75895b79ad..7abd4380c8 100644 --- a/src/cpu/pred/SConscript +++ b/src/cpu/pred/SConscript @@ -49,7 +49,8 @@ 
SimObject('BranchPredictor.py', sim_objects=[ 'DecoupledStreamBPU', 'DefaultFTB', 'DecoupledBPUWithFTB', 'TimedBaseFTBPredictor', 'FTBTAGE', 'FTBRAS', 'FTBuRAS', 'FTBITTAGE', 'AheadBTB', 'MBTB', 'UBTB', 'DecoupledBPUWithBTB', - 'TimedBaseBTBPredictor', 'BTBRAS', 'BTBTAGE', 'BTBITTAGE', 'BTBMGSC'], enums=["BpType"]) + 'TimedBaseBTBPredictor', 'BTBRAS', 'BTBTAGE', 'MicroTAGE', + 'BTBITTAGE', 'BTBMGSC'], enums=["BpType"]) DebugFlag('Indirect') Source('bpred_unit.cc') @@ -101,6 +102,7 @@ Source('btb/mbtb.cc') Source('btb/timed_base_pred.cc') Source('btb/fetch_target_queue.cc') Source('btb/btb_tage.cc') +Source('btb/microtage.cc') Source('btb/btb_ittage.cc') Source('btb/btb_mgsc.cc') Source('btb/folded_hist.cc') diff --git a/src/cpu/pred/btb/decoupled_bpred.hh b/src/cpu/pred/btb/decoupled_bpred.hh index e5c5dda8be..3f9ea28ddc 100644 --- a/src/cpu/pred/btb/decoupled_bpred.hh +++ b/src/cpu/pred/btb/decoupled_bpred.hh @@ -15,15 +15,16 @@ #include "cpu/o3/dyn_inst_ptr.hh" #include "cpu/pred/bpred_unit.hh" #include "cpu/pred/btb/abtb.hh" -#include "cpu/pred/btb/mbtb.hh" #include "cpu/pred/btb/btb_ittage.hh" +#include "cpu/pred/btb/btb_mgsc.hh" #include "cpu/pred/btb/btb_tage.hh" #include "cpu/pred/btb/btb_ubtb.hh" -#include "cpu/pred/btb/btb_mgsc.hh" #include "cpu/pred/btb/fetch_target_queue.hh" #include "cpu/pred/btb/jump_ahead_predictor.hh" #include "cpu/pred/btb/loop_buffer.hh" #include "cpu/pred/btb/loop_predictor.hh" +#include "cpu/pred/btb/mbtb.hh" +#include "cpu/pred/btb/microtage.hh" #include "cpu/pred/btb/ras.hh" #include "cpu/pred/general_arch_db.hh" @@ -103,7 +104,7 @@ class DecoupledBPUWithBTB : public BPredUnit UBTB *ubtb{}; AheadBTB *abtb{}; MBTB *mbtb{}; - BTBTAGE *microtage{}; + MicroTAGE *microtage{}; BTBTAGE *tage{}; BTBITTAGE *ittage{}; BTBMGSC *mgsc{}; diff --git a/src/cpu/pred/btb/microtage.cc b/src/cpu/pred/btb/microtage.cc new file mode 100644 index 0000000000..6618a3aed6 --- /dev/null +++ b/src/cpu/pred/btb/microtage.cc @@ -0,0 +1,1098 @@ +#include 
"cpu/pred/btb/microtage.hh" + +#include +#include +#include + +#ifdef UNIT_TEST +// Define debug flags for unit testing +namespace gem5 { +namespace debug { + bool TAGEUseful = true; + bool TAGEHistory = true; +} +} +#endif + +#ifndef UNIT_TEST +#include "base/debug_helper.hh" +#include "base/intmath.hh" +#include "base/trace.hh" +#include "base/types.hh" +#include "cpu/o3/dyn_inst.hh" +#include "debug/TAGE.hh" + +#endif +namespace gem5 { + +namespace branch_prediction { + +namespace btb_pred{ + +#ifdef UNIT_TEST +namespace test { +#endif + +#ifdef UNIT_TEST +// Test constructor for unit testing mode +MicroTAGE::MicroTAGE(unsigned numPredictors, unsigned numWays, unsigned tableSize, unsigned numBanks) + : TimedBaseBTBPredictor(), + numPredictors(numPredictors), + numWays(numWays), + maxBranchPositions(32), + updateOnRead(false), + numBanks(numBanks), + bankIdWidth(ceilLog2(numBanks)), + blockWidth(floorLog2(blockSize)), + bankBaseShift(instShiftAmt), + indexShift(bankBaseShift + ceilLog2(numBanks)), + enableBankConflict(false), + lastPredBankId(0), + predBankValid(false) +{ + setNumDelay(1); + + // Initialize with default parameters for testing + tableSizes.resize(numPredictors, tableSize); + tableTagBits.resize(numPredictors, 8); + tablePcShifts.resize(numPredictors, 1); + histLengths.resize(numPredictors); + for (unsigned i = 0; i < numPredictors; ++i) { + histLengths[i] = (i + 1) * 4; + } + maxHistLen = histLengths[numPredictors-1]; + numTablesToAlloc = 1; +#else +// Constructor: Initialize TAGE predictor with given parameters +MicroTAGE::MicroTAGE(const Params& p): +TimedBaseBTBPredictor(p), +numPredictors(p.numPredictors), +tableSizes(p.tableSizes), +tableTagBits(p.TTagBitSizes), +tablePcShifts(p.TTagPcShifts), +histLengths(p.histLengths), +maxHistLen(p.maxHistLen), +numWays(p.numWays), +maxBranchPositions(p.maxBranchPositions), +numTablesToAlloc(p.numTablesToAlloc), +updateOnRead(p.updateOnRead), +numBanks(p.numBanks), +bankIdWidth(ceilLog2(p.numBanks)), 
+blockWidth(p.blockSize ? floorLog2(p.blockSize) : 0), +bankBaseShift(instShiftAmt), // strip instruction alignment bits before indexing +indexShift(bankBaseShift + ceilLog2(p.numBanks)), +enableBankConflict(p.enableBankConflict), +lastPredBankId(0), +predBankValid(false), +tageStats(this, p.numPredictors, p.numBanks) +{ + this->needMoreHistories = p.needMoreHistories; + + // Warn if updateOnRead is disabled (bank simulation works better with it enabled) + if (!p.updateOnRead) { + warn("MicroTAGE: Bank simulation works better with updateOnRead=true"); + } +#endif + tageTable.resize(numPredictors); + tableIndexBits.resize(numPredictors); + tableIndexMasks.resize(numPredictors); + tableTagBits.resize(numPredictors); + tableTagMasks.resize(numPredictors); + // Ensure PC shift vector has entries for all predictors (fallback default = 1) + if (tablePcShifts.size() < numPredictors) { + tablePcShifts.resize(numPredictors, 1); + } + + // Initialize base table for fallback predictions + for (unsigned int i = 0; i < numPredictors; ++i) { + //initialize ittage predictor + assert(tableSizes.size() >= numPredictors); + tageTable[i].resize(tableSizes[i]); + for (unsigned int j = 0; j < tableSizes[i]; ++j) { + tageTable[i][j].resize(numWays); + } + + tableIndexBits[i] = ceilLog2(tableSizes[i]); + tableIndexMasks[i].resize(tableIndexBits[i], true); + + assert(histLengths.size() >= numPredictors); + + assert(tableTagBits.size() >= numPredictors); + tableTagMasks[i].resize(tableTagBits[i], true); + + assert(tablePcShifts.size() >= numPredictors); + + tagFoldedHist.push_back(PathFoldedHist((int)histLengths[i], (int)tableTagBits[i], 16)); + altTagFoldedHist.push_back(PathFoldedHist((int)histLengths[i], (int)tableTagBits[i]-1, 16)); + indexFoldedHist.push_back(PathFoldedHist((int)histLengths[i], (int)tableIndexBits[i], 16)); + } + usefulResetCnt = 0; + +#ifndef UNIT_TEST + hasDB = true; + dbName = std::string("microtage"); +#endif +} + +MicroTAGE::~MicroTAGE() +{ +} + +// Set up 
tracing for debugging +void +MicroTAGE::setTrace() +{ +#ifndef UNIT_TEST + if (enableDB) { + std::vector> fields_vec = { + std::make_pair("startPC", UINT64), + std::make_pair("branchPC", UINT64), + std::make_pair("wayIdx", UINT64), + std::make_pair("mainFound", UINT64), + std::make_pair("mainCounter", UINT64), + std::make_pair("mainUseful", UINT64), + std::make_pair("mainTable", UINT64), + std::make_pair("mainIndex", UINT64), + std::make_pair("altFound", UINT64), + std::make_pair("altCounter", UINT64), + std::make_pair("altUseful", UINT64), + std::make_pair("altTable", UINT64), + std::make_pair("altIndex", UINT64), + std::make_pair("useAlt", UINT64), + std::make_pair("predTaken", UINT64), + std::make_pair("actualTaken", UINT64), + std::make_pair("allocSuccess", UINT64), + std::make_pair("allocTable", UINT64), + std::make_pair("allocIndex", UINT64), + std::make_pair("allocWay", UINT64), + std::make_pair("history", TEXT), + std::make_pair("indexFoldedHist", UINT64), + }; + tageMissTrace = _db->addAndGetTrace("TAGEMISSTRACE", fields_vec); + tageMissTrace->init_table(); + } +#endif +} + +void +MicroTAGE::tick() {} + +void +MicroTAGE::tickStart() {} + +/** + * @brief Generate prediction for a single BTB entry by searching TAGE tables + * + * @param btb_entry The BTB entry to generate prediction for + * @param startPC The starting PC address for calculating indices and tags + * @param predMeta Optional prediction metadata; if provided, use snapshot for index/tag + * calculation (update path); if nullptr, use current folded history (prediction path) + * @return TagePrediction containing main and alternative predictions + */ +MicroTAGE::TagePrediction +MicroTAGE::generateSinglePrediction(const BTBEntry &btb_entry, + const Addr &startPC, + std::shared_ptr predMeta) { + DPRINTF(TAGE, "generateSinglePrediction for btbEntry: %#lx\n", btb_entry.pc); + + bool provided = false; + TageTableInfo main_info; + + // Search from highest to lowest table for matches + // Calculate branch 
position within the block (like RTL's cfiPosition) + unsigned position = getBranchIndexInBlock(btb_entry.pc, startPC); + + for (int i = numPredictors - 1; i >= 0; --i) { + // Calculate index and tag: use snapshot if provided, otherwise use current folded history + // Tag includes position XOR (like RTL: tag = tempTag ^ cfiPosition) + Addr index = predMeta ? getTageIndex(startPC, i, predMeta->indexFoldedHist[i].get()) + : getTageIndex(startPC, i); + Addr tag = predMeta ? getTageTag(startPC, i, + predMeta->tagFoldedHist[i].get(),predMeta->altTagFoldedHist[i].get(), position) + : getTageTag(startPC, i, tagFoldedHist[i].get(),altTagFoldedHist[i].get(), position); + + bool match = false; // for each table, only one way can be matched + TageEntry matching_entry; + unsigned matching_way = 0; + + // Search all ways for a matching entry + for (unsigned way = 0; way < numWays; way++) { + auto &entry = tageTable[i][index][way]; + // entry valid, tag match (position already encoded in tag, no need to check pc) + if (entry.valid && tag == entry.tag) { + matching_entry = entry; + matching_way = way; + match = true; + + // Do not use LRU; keep logic simple and align with CBP-style replacement + + DPRINTF(TAGE, "hit table %d[%lu][%u]: valid %d, tag %lu, ctr %d, useful %d, btb_pc %#lx, pos %u\n", + i, index, way, entry.valid, entry.tag, entry.counter, entry.useful, btb_entry.pc, position); + break; // only one way can be matched, aviod multi hit, TODO: RTL how to do this? + } + } + + if (match) { + if (!provided) { + // First match becomes main prediction + main_info = TageTableInfo(true, matching_entry, i, index, tag, matching_way); + provided = true; + } + } else { + DPRINTF(TAGE, "miss table %d[%lu] for tag %lu (with pos %u), btb_pc %#lx\n", + i, index, tag, position, btb_entry.pc); + } + } + + // Generate final prediction + bool main_taken = main_info.taken(); + bool base_pred = btb_entry.ctr >= 0; + + bool taken = provided ? 
main_taken : base_pred; + + DPRINTF(TAGE, "tage predict %#lx taken %d\n", btb_entry.pc, taken); + DPRINTF(TAGE, "tage main prvided %d ? main_taken %d : base_taken %d\n", provided, main_taken, base_pred); + + return TagePrediction(btb_entry.pc, main_info, provided, taken, base_pred); +} + +/** + * @brief Look up predictions in TAGE tables for a stream of instructions + * + * @param startPC The starting PC address for the instruction stream + * @param btbEntries Vector of BTB entries to make predictions for + * @return Map of branch PC addresses to their predicted outcomes + */ +void +MicroTAGE::lookupHelper(const Addr &startPC, const std::vector &btbEntries, CondTakens& results) +{ + DPRINTF(TAGE, "lookupHelper startAddr: %#lx\n", startPC); + + // Process each BTB entry to make predictions + for (auto &btb_entry : btbEntries) { + // Only predict for valid conditional branches + if (btb_entry.isCond && btb_entry.valid) { + auto pred = generateSinglePrediction(btb_entry, startPC); + meta->preds[btb_entry.pc] = pred; + tageStats.updateStatsWithTagePrediction(pred, true); + results.push_back({btb_entry.pc, pred.taken || btb_entry.alwaysTaken}); + } + } +} + +void +MicroTAGE::dryRunCycle(Addr startPC) { + // No operation in dry run cycle for MicroTAGE + // Record prediction bank for next tick's conflict detection + lastPredBankId = getBankId(startPC); + predBankValid = true; + + return; +} + +/** + * @brief Makes predictions for a stream of instructions using TAGE predictor + * + * This function is called during the prediction stage and: + * 1. Uses lookupHelper to get predictions for all BTB entries + * 2. Stores predictions in the stage prediction structure + * 3. 
Handles multiple prediction stages with different delays + * + * @param startPC Starting PC of the instruction stream + * @param history Current branch history + * @param stagePreds Vector of predictions for different pipeline stages + */ +void +MicroTAGE::putPCHistory(Addr startPC, const bitset &history, std::vector &stagePreds) { + // Record prediction bank for next tick's conflict detection + lastPredBankId = getBankId(startPC); + predBankValid = true; + +#ifndef UNIT_TEST + // Record prediction access per bank + tageStats.predAccessPerBank[lastPredBankId]++; +#endif + + DPRINTF(TAGE, "putPCHistory startAddr: %#lx, bank: %u\n", + startPC, lastPredBankId); + + // IMPORTANT: when this function is called, + // btb entries should already be in stagePreds + // get prediction and save it + + // Clear old prediction metadata and save current history state + meta = std::make_shared(); + meta->tagFoldedHist = tagFoldedHist; + meta->altTagFoldedHist = altTagFoldedHist; + meta->indexFoldedHist = indexFoldedHist; + meta->history = history; + + for (int s = getDelay(); s < stagePreds.size(); s++) { + // TODO: only lookup once for one btb entry in different stages + auto &stage_pred = stagePreds[s]; + stage_pred.condTakens.clear(); + lookupHelper(startPC, stage_pred.btbEntries, stage_pred.condTakens); + } + +} + +std::shared_ptr +MicroTAGE::getPredictionMeta() { + return meta; +} + +/** + * @brief Prepare BTB entries for update by filtering and processing + * + * @param stream The fetch stream containing update information + * @return Vector of BTB entries that need to be updated + */ +std::vector +MicroTAGE::prepareUpdateEntries(const FetchStream &stream) { + auto all_entries = stream.updateBTBEntries; + + // Add potential new BTB entry if it's a btb miss during prediction + if (!stream.updateIsOldEntry) { + BTBEntry potential_new_entry = stream.updateNewBTBEntry; + bool new_entry_taken = stream.exeTaken && stream.getControlPC() == potential_new_entry.pc; + if 
(!new_entry_taken) { + potential_new_entry.alwaysTaken = false; + } + all_entries.push_back(potential_new_entry); + } + + // Filter: only keep conditional branches that are not always taken + if (getResolvedUpdate()) { + auto remove_it = std::remove_if(all_entries.begin(), all_entries.end(), + [](const BTBEntry &e) { return !(e.isCond && !e.alwaysTaken && e.resolved); }); + all_entries.erase(remove_it, all_entries.end()); + } else { + auto remove_it = std::remove_if(all_entries.begin(), all_entries.end(), + [](const BTBEntry &e) { return !(e.isCond && !e.alwaysTaken); }); + all_entries.erase(remove_it, all_entries.end()); + } + + return all_entries; +} + +/** + * @brief Update predictor state for a single entry + * + * @param entry The BTB entry being updated + * @param actual_taken The actual outcome of the branch + * @param pred The prediction made for this entry + * @param stream The fetch stream containing update information + * @return true if need to allocate new entry + */ +bool +MicroTAGE::updatePredictorStateAndCheckAllocation(const BTBEntry &entry, + bool actual_taken, + const TagePrediction &pred, + const FetchStream &stream) { + tageStats.updateStatsWithTagePrediction(pred, false); + + auto &main_info = pred.mainInfo; + bool used_base = !pred.mainprovided; + // Use base table instead of entry.ctr for fallback prediction + Addr startPC = stream.getRealStartPC(); + bool base_taken = entry.ctr >= 0; + + // Update use_alt_on_na when provider is weak (0 or -1) + if (main_info.found) { + bool main_weak = (main_info.entry.counter == 0 || main_info.entry.counter == -1); + if (main_weak) { + tageStats.updateProviderNa++; + bool base_correct = (base_taken == actual_taken); + //updateCounter(base_correct, useAltOnNaWidth, useAlt[uidx]); + tageStats.updateUseAltOnNaUpdated++; + if (base_correct) { + tageStats.updateUseAltOnNaCorrect++; + } else { + tageStats.updateUseAltOnNaWrong++; + } + } + } + + // Update main prediction provider + if (main_info.found) { + 
DPRINTF(TAGE, "prediction provided by table %d, idx %lu, way %u, updating corresponding entry\n", + main_info.table, main_info.index, main_info.way); + + auto &way = tageTable[main_info.table][main_info.index][main_info.way]; + + // Update prediction counter + updateCounter(actual_taken, 3, way.counter); + + // Update useful bit based on several conditions + bool main_is_correct = main_info.taken() == actual_taken; + bool base_is_correct_and_strong = + (base_taken == actual_taken) && + (abs(2 * entry.ctr + 1) == 5); + + // a. Special reset (humility mechanism) + if (base_is_correct_and_strong && main_is_correct) { + way.useful = 0; + DPRINTF(TAGEUseful, "useful bit reset to 0 due to humility rule\n"); + } else if (main_info.taken() != base_taken) { + // b. Original logic to set useful bit high + if (main_is_correct) { + way.useful = 1; + } + } + + // c. Reset u on counter sign flip (becomes weak) + if (way.counter == 0 || way.counter == -1) { + way.useful = 0; + DPRINTF(TAGEUseful, "useful bit reset to 0 due to weak counter\n"); + } + DPRINTF(TAGE, "useful bit is now %d\n", way.useful); + + // No LRU maintenance + + if (!main_is_correct) { + tageStats.updateUtageWrong++; + } + } + + + // Update statistics + if (used_base) { + bool base_correct = base_taken == actual_taken; + if (base_correct) { + tageStats.updateUseAltCorrect++; + } else { + tageStats.updateUseAltWrong++; + tageStats.updateUtageWrong++; + } + if (main_info.found && main_info.taken() != base_taken) { + tageStats.updateAltDiffers++; + } + } + + // Check if misprediction occurred + bool this_fb_mispred = stream.squashType == SquashType::SQUASH_CTRL && + stream.squashPC == entry.pc; + // No allocation if no misprediction + if (!this_fb_mispred) { + return false; + } + + // All other cases: allocate longer history table + return true; +} + +/** + * @brief Handle allocation of new entries + * + * @param startPC The starting PC address + * @param entry The BTB entry being updated + * @param actual_taken 
The actual outcome of the branch + * @param start_table The starting table for allocation + * @param meta The metadata of the predictor + * @return true if allocation is successful + */ +bool +MicroTAGE::handleNewEntryAllocation(const Addr &startPC, + const BTBEntry &entry, + bool actual_taken, + unsigned start_table, + std::shared_ptr meta, + uint64_t &allocated_table, + uint64_t &allocated_index, + uint64_t &allocated_way) { + // Simple set-associative allocation (no LFSR, no per-way table gating): + // - For each table from start_table upward, check the set at computed index. + // - Prefer invalid ways; else choose any way with useful==0 and weak counter. + // - If none, apply a one-step age penalty to a strong, not-useful way (no allocation). + + // Calculate branch position within the block (like RTL's cfiPosition) + unsigned position = getBranchIndexInBlock(entry.pc, startPC); + + for (unsigned ti = start_table; ti < numPredictors; ++ti) { + Addr newIndex = getTageIndex(startPC, ti, meta->indexFoldedHist[ti].get()); + Addr newTag = getTageTag(startPC, ti, + meta->tagFoldedHist[ti].get(), meta->altTagFoldedHist[ti].get(), position); + + auto &set = tageTable[ti][newIndex]; + + // Allocate into invalid way or not-useful and weak way + for (unsigned way = 0; way < numWays; ++way) { + auto &cand = set[way]; + const bool weakish = std::abs(cand.counter * 2 + 1) <= 3; // -3,-2,-1,0,1,2 + if (!cand.valid || (!cand.useful && weakish)) { + short newCounter = actual_taken ? 0 : -1; + DPRINTF(TAGE, "allocating entry in table %d[%lu][%u], tag %lu (with pos %u), counter %d, pc %#lx\n", + ti, newIndex, way, newTag, position, newCounter, entry.pc); + cand = TageEntry(newTag, newCounter, entry.pc); // u = 0 default + tageStats.updateAllocSuccess++; + allocated_table = ti; + allocated_index = newIndex; + allocated_way = way; + usefulResetCnt = usefulResetCnt <= 0 ? 
0 : usefulResetCnt - 1; + return true; + } + } + + // 3) Apply age penalty to one strong, not-useful way to make it replacable later + for (unsigned way = 0; way < numWays; ++way) { + auto &cand = set[way]; + const bool weakish = std::abs(cand.counter * 2 + 1) <= 3; + if (!cand.useful && !weakish) { + if (cand.counter > 0) cand.counter--; else cand.counter++; + DPRINTF(TAGE, "age penalty applied on table %d[%lu][%u], new ctr %d\n", + ti, newIndex, way, cand.counter); + break; // one penalty per table per update + } + } + + tageStats.updateAllocFailure++; + usefulResetCnt++; + } + + if (usefulResetCnt >= 256) { + usefulResetCnt = 0; + tageStats.updateResetU++; + DPRINTF(TAGE, "reset useful bit of all entries\n"); + for (auto &table : tageTable) { + for (auto &set : table) { + for (auto &way : set) { + way.useful = false; + } + } + } + } + + DPRINTF(TAGE, "no eligible way found for allocation starting from table %d\n", start_table); + tageStats.updateAllocFailureNoValidTable++; + return false; +} + +/** + * @brief Probe resolved update for bank conflicts without mutating state. + * Returns false if the update cannot proceed due to a bank conflict. 
+ */ +bool +MicroTAGE::canResolveUpdate(const FetchStream &stream) { + Addr startAddr = stream.getRealStartPC(); + unsigned updateBank = getBankId(startAddr); + +#ifndef UNIT_TEST + // Record attempted update access per bank (even if it conflicts) + tageStats.updateAccessPerBank[updateBank]++; +#endif + + if (enableBankConflict && predBankValid && updateBank == lastPredBankId) { + tageStats.updateBankConflict++; + tageStats.updateDeferredDueToConflict++; +#ifndef UNIT_TEST + tageStats.updateBankConflictPerBank[updateBank]++; +#endif + DPRINTF(TAGE, "Bank conflict detected: update bank %u conflicts with prediction bank %u, " + "deferring this update (will retry after blocking prediction)\n", + updateBank, lastPredBankId); + predBankValid = false; + return false; + } + + return true; +} + +/** + * @brief Perform resolved update after probe success. + */ +void +MicroTAGE::doResolveUpdate(const FetchStream &stream) { + if (enableBankConflict && predBankValid) { + // Prediction consumed; clear bank tag for next cycle + predBankValid = false; + } + update(stream); +} + +/** + * @brief Updates the TAGE predictor state based on actual branch execution results + * + * @param stream The fetch stream containing branch execution information + */ +void +MicroTAGE::update(const FetchStream &stream) { + Addr startAddr = stream.getRealStartPC(); + unsigned updateBank = getBankId(startAddr); + + DPRINTF(TAGE, "update startAddr: %#lx, bank: %u\n", startAddr, updateBank); + + // ========== Normal Update Logic ========== + // Prepare BTB entries to update + auto entries_to_update = prepareUpdateEntries(stream); + + // Get prediction metadata snapshot and bind to member for helpers + auto predMeta = std::static_pointer_cast(stream.predMetas[getComponentIdx()]); + if (!predMeta) { + DPRINTF(TAGE, "update: no prediction meta, skip\n"); + return; + } + + bool utage_hit = false; + // Process each BTB entry + for (auto &btb_entry : entries_to_update) { + bool actual_taken = stream.exeTaken 
&& stream.exeBranchInfo == btb_entry; + TagePrediction recomputed; + if (updateOnRead) { // if update on read is enabled, re-read providers using snapshot + // Re-read providers using snapshot (do not rely on prediction-time main/alt) + recomputed = generateSinglePrediction(btb_entry, startAddr, predMeta); + } else { // otherwise, use the prediction from the prediction-time main/alt + recomputed = predMeta->preds[btb_entry.pc]; + } + if (recomputed.mainprovided) { + utage_hit = true; + } + // Update predictor state and check if need to allocate new entry + bool need_allocate = updatePredictorStateAndCheckAllocation(btb_entry, actual_taken, recomputed, stream); + + // Handle new entry allocation if needed + bool alloc_success = false; + uint64_t allocated_table = 0; + uint64_t allocated_index = 0; + uint64_t allocated_way = 0; + if (need_allocate) { + + // Handle allocation of new entries + uint start_table = 0; + auto &main_info = recomputed.mainInfo; + if (main_info.found) { + start_table = main_info.table + 1; // start from the table after the main prediction table + } + alloc_success = handleNewEntryAllocation(startAddr, btb_entry, actual_taken, + start_table, predMeta, allocated_table, allocated_index, allocated_way); + } + +#ifndef UNIT_TEST + // if (enableDB) { + // TageMissTrace t; + // std::string history_str; + // boost::dynamic_bitset<> history_low50 = predMeta->history; + // if (history_low50.size() > 50) { + // history_low50.resize(50); // get the lower 50 bits of history + // } + // boost::to_string(history_low50, history_str); + // auto main_info = recomputed.mainInfo; + // t.set(startAddr, btb_entry.pc, main_info.way, + // main_info.found, main_info.entry.counter, main_info.entry.useful, + // main_info.table, main_info.index, + // recomputed.useAlt, recomputed.taken, actual_taken, alloc_success, + // allocated_table, allocated_index, allocated_way, + // history_str, predMeta->indexFoldedHist[main_info.table].get()); + // 
tageMissTrace->write_record(t); + // } +#endif + } + if (utage_hit){ + tageStats.updateUtageHit++;//for RTL align pred Accuracy + } + checkUtageUpdateMisspred(stream); + DPRINTF(TAGE, "end update\n"); +} + +void +MicroTAGE::checkUtageUpdateMisspred(const FetchStream &stream) { + auto predMeta = std::static_pointer_cast(stream.predMetas[getComponentIdx()]); + // use for microtage updatemispred counting + // sort microtage predictions by pc to find the first taken branch + std::vector> lastPreds; + lastPreds.reserve(predMeta->preds.size()); + for (auto &kv : predMeta->preds) { + lastPreds.emplace_back(kv.first, kv.second); + } + std::sort(lastPreds.begin(), lastPreds.end(), + [](const std::pair &a, + const std::pair &b) { + return a.first < b.first; + }); + Addr first_taken_pc = 0; + for (auto &entry_info : lastPreds) { + if (entry_info.second.taken) { + first_taken_pc = entry_info.first; + break; + } + } + bool fallthrough_mispred = (first_taken_pc == 0 && stream.exeTaken) || + (first_taken_pc != 0 && !stream.exeTaken); + bool branch_mispred = stream.exeTaken && first_taken_pc != stream.exeBranchInfo.pc; + if (fallthrough_mispred || branch_mispred) { + tageStats.updateMispred++; + } +} + +// Update prediction counter with saturation +void +MicroTAGE::updateCounter(bool taken, unsigned width, short &counter) { + int max = (1 << (width-1)) - 1; + int min = -(1 << (width-1)); + if (taken) { + satIncrement(max, counter); + } else { + satDecrement(min, counter); + } +} + +// Calculate TAGE tag with folded history - optimized version using bitwise operations +Addr +MicroTAGE::getTageTag(Addr pc, int t, uint64_t foldedHist, uint64_t altFoldedHist, Addr position) +{ + // Create mask for tableTagBits[t] to limit result size + Addr mask = (1ULL << tableTagBits[t]) - 1; + + // Extract lower bits of PC directly (remove instruction alignment bits) + Addr pcBits = (pc >> bankBaseShift) & mask; + + // Extract and prepare folded history bits + Addr foldedBits = foldedHist & mask; + 
+ // Extract alt tag bits and shift left by 1 + Addr altTagBits = (altFoldedHist << 1) & mask; + + // XOR all components together, including position (like RTL) + return pcBits ^ foldedBits ^ position ^ altTagBits; +} + +Addr +MicroTAGE::getTageIndex(Addr pc, int t, uint64_t foldedHist) +{ + // Create mask for tableIndexBits[t] to limit result size + Addr mask = (1ULL << tableIndexBits[t]) - 1; + + const unsigned pcShift = enableBankConflict ? indexShift : bankBaseShift; + Addr pcBits = (pc >> pcShift) & mask; + Addr foldedBits = foldedHist & mask; + + return pcBits ^ foldedBits; +} + +Addr +MicroTAGE::getTageIndex(Addr pc, int t) +{ + return getTageIndex(pc, t, indexFoldedHist[t].get()); +} + +bool +MicroTAGE::matchTag(Addr expected, Addr found) +{ + return expected == found; +} + +bool +MicroTAGE::satIncrement(int max, short &counter) +{ + if (counter < max) { + ++counter; + } + return counter == max; +} + +bool +MicroTAGE::satDecrement(int min, short &counter) +{ + if (counter > min) { + --counter; + } + return counter == min; +} + +unsigned +MicroTAGE::getBranchIndexInBlock(Addr branchPC, Addr startPC) { + // Calculate branch position within the fetch block (0 .. 
maxBranchPositions-1) + const Addr alignedPC = startPC & ~(blockSize - 1); + + unsigned position = 0; + if (branchPC >= alignedPC) { + const Addr byteOffset = branchPC - alignedPC; + position = byteOffset >> instShiftAmt; + } else { + warn_once("MicroTAGE: branch %#lx precedes block start %#lx; treating as offset 0", + branchPC, startPC); + } + + if (position >= maxBranchPositions) { + warn_once("MicroTAGE: branch %#lx exceeds block [%#lx, %#lx) (blockSize=%lu, instShift=%u, maxPositions=%u); clamping index", + branchPC, alignedPC, + alignedPC + blockSize, + static_cast(blockSize), instShiftAmt, maxBranchPositions); + position %= maxBranchPositions; + } + + return position; +} + +unsigned +MicroTAGE::getBankId(Addr pc) const +{ + // Extract bank ID bits after removing instruction alignment + return (pc >> bankBaseShift) & ((1 << bankIdWidth) - 1); +} + +/** + * @brief Updates branch history for speculative execution + * + * This function updates three types of folded histories: + * - Tag folded history: Used for tag computation + * - Alternative tag folded history: Used for alternative tag computation + * - Index folded history: Used for table index computation + * + * @param history The current branch history + * @param shamt The number of bits to shift + * @param taken Whether the branch was taken + */ +void +MicroTAGE::doUpdateHist(const boost::dynamic_bitset<> &history, bool taken, Addr pc, Addr target) +{ + if (debug::TAGEHistory) { // if debug flag is off, do not use to_string since it's too slow + std::string buf; + boost::to_string(history, buf); + DPRINTF(TAGEHistory, "in doUpdateHist, taken %d, pc %#lx, history %s\n", taken, pc, buf.c_str()); + } + if (!taken) { + DPRINTF(TAGEHistory, "not updating folded history, since FB not taken\n"); + return; + } + + for (int t = 0; t < numPredictors; t++) { + for (int type = 0; type < 3; type++) { + auto &foldedHist = type == 0 ? indexFoldedHist[t] : type == 1 ? 
tagFoldedHist[t] : altTagFoldedHist[t]; + // since we have folded path history, we can put arbitrary shamt here, and it wouldn't make a difference + foldedHist.update(history, 2, taken, pc, target); + DPRINTF(TAGEHistory, "t: %d, type: %d, foldedHist _folded 0x%lx\n", t, type, foldedHist.get()); + } + } +} + +/** + * @brief Updates branch history for speculative execution + * + * This function updates the branch history for speculative execution + * based on the provided history and prediction information. + * + * It first retrieves the history information from the prediction metadata + * and then calls the doUpdateHist function to update the folded histories. + * + * @param history The current branch history + * @param pred The prediction metadata containing history information + */ +void +MicroTAGE::specUpdatePHist(const boost::dynamic_bitset<> &history, FullBTBPrediction &pred) +{ + auto [pc, target, taken] = pred.getPHistInfo(); + doUpdateHist(history, taken, pc, target); +} + +/** + * @brief Recovers branch history state after a misprediction + * + * This function: + * 1. Restores the folded histories from the saved metadata + * 2. Updates the histories with the correct branch outcome + * 3. 
Ensures predictor state is consistent after recovery + * + * @param history The branch history to recover to + * @param entry The fetch stream entry containing recovery information + * @param shamt Number of bits to shift in history update + * @param cond_taken The actual branch outcome + */ +void +MicroTAGE::recoverPHist(const boost::dynamic_bitset<> &history, + const FetchStream &entry, int shamt, bool cond_taken) +{ + std::shared_ptr<TageMeta> predMeta = std::static_pointer_cast<TageMeta>(entry.predMetas[getComponentIdx()]); + for (int i = 0; i < numPredictors; i++) { + tagFoldedHist[i].recover(predMeta->tagFoldedHist[i]); + indexFoldedHist[i].recover(predMeta->indexFoldedHist[i]); + altTagFoldedHist[i].recover(predMeta->altTagFoldedHist[i]); + } + doUpdateHist(history, cond_taken, entry.getControlPC(), entry.getTakenTarget()); +} + +// Check folded history after speculative update and recovery +void +MicroTAGE::checkFoldedHist(const boost::dynamic_bitset<> &hist, const char * when) +{ + DPRINTF(TAGE, "checking folded history when %s\n", when); + if (debug::TAGEHistory) { + std::string hist_str; + boost::to_string(hist, hist_str); + DPRINTF(TAGEHistory, "history:\t%s\n", hist_str.c_str()); + } + for (int t = 0; t < numPredictors; t++) { + for (int type = 0; type < 3; type++) { + std::string buf2, buf3; + auto &foldedHist = type == 0 ? indexFoldedHist[t] : type == 1 ? 
tagFoldedHist[t] : altTagFoldedHist[t]; + foldedHist.check(hist); + } + } +} + +#ifndef UNIT_TEST +// Constructor for TAGE statistics +MicroTAGE::TageStats::TageStats(statistics::Group* parent, int numPredictors, int numBanks): + statistics::Group(parent), + ADD_STAT(predNoHitUseBim, statistics::units::Count::get(), "use bimodal when no hit on prediction"), + ADD_STAT(predUseAlt, statistics::units::Count::get(), "use alt on prediction"), + ADD_STAT(updateNoHitUseBim, statistics::units::Count::get(), "use bimodal when no hit on update"), + ADD_STAT(updateUseAlt, statistics::units::Count::get(), "use alt on update"), + ADD_STAT(updateUseAltCorrect, statistics::units::Count::get(), "use alt on update and correct"), + ADD_STAT(updateUseAltWrong, statistics::units::Count::get(), "use alt on update and wrong"), + ADD_STAT(updateAltDiffers, statistics::units::Count::get(), "alt differs on update"), + ADD_STAT(updateUseAltOnNaUpdated, statistics::units::Count::get(), "use alt on na ctr updated when update"), + ADD_STAT(updateProviderNa, statistics::units::Count::get(), "provider weak when update"), + ADD_STAT(updateUseNaCorrect, statistics::units::Count::get(), "use na on update and correct"), + ADD_STAT(updateUseNaWrong, statistics::units::Count::get(), "use na on update and wrong"), + ADD_STAT(updateUseAltOnNaCorrect, statistics::units::Count::get(), "use alt on na correct when update"), + ADD_STAT(updateUseAltOnNaWrong, statistics::units::Count::get(), "use alt on na wrong when update"), + ADD_STAT(updateAllocFailure, statistics::units::Count::get(), "alloc failure when update"), + ADD_STAT(updateAllocFailureNoValidTable, statistics::units::Count::get(), "alloc failure no valid table when update"), + ADD_STAT(updateAllocSuccess, statistics::units::Count::get(), "alloc success when update"), + ADD_STAT(updateMispred, statistics::units::Count::get(), "mispred when update"), + ADD_STAT(updateResetU, statistics::units::Count::get(), "reset u when update"), + + 
ADD_STAT(updateUtageHit, statistics::units::Count::get(), "number of updates where utage provided the main prediction"), + ADD_STAT(updateUtageWrong, statistics::units::Count::get(), "number of updates where utage prediction was wrong"), + + ADD_STAT(updateBankConflict, statistics::units::Count::get(), "number of bank conflicts detected"), + ADD_STAT(updateDeferredDueToConflict, statistics::units::Count::get(), "number of updates deferred due to bank conflict (retried later)"), + ADD_STAT(updateBankConflictPerBank, statistics::units::Count::get(), "bank conflicts per bank"), + ADD_STAT(updateAccessPerBank, statistics::units::Count::get(), "update accesses per bank"), + ADD_STAT(predAccessPerBank, statistics::units::Count::get(), "prediction accesses per bank"), + ADD_STAT(predTableHits, statistics::units::Count::get(), "hit of each tage table on prediction"), + ADD_STAT(updateTableHits, statistics::units::Count::get(), "hit of each tage table on update"), + ADD_STAT(updateTableMispreds, statistics::units::Count::get(), "mispreds of each table when update"), + + ADD_STAT(condPredwrong, statistics::units::Count::get(), "number of conditional branch mispredictions committed"), + ADD_STAT(condMissTakens, statistics::units::Count::get(), "number of conditional branch mispredictions committed with no prediction"), + ADD_STAT(condCorrect, statistics::units::Count::get(), "number of conditional branch correct predictions committed"), + ADD_STAT(condMissNoTakens, statistics::units::Count::get(), "number of conditional branch correct predictions committed with no prediction"), + ADD_STAT(predHit, statistics::units::Count::get(), "number of conditional branch predictions that hit"), + ADD_STAT(predMiss, statistics::units::Count::get(), "number of conditional branch predictions that miss") +{ + predTableHits.init(0, numPredictors-1, 1); + updateTableHits.init(0, numPredictors-1, 1); + updateTableMispreds.init(numPredictors); + + // Initialize per-bank statistics vectors + 
updateBankConflictPerBank.init(numBanks); + updateAccessPerBank.init(numBanks); + predAccessPerBank.init(numBanks); +} +#endif + +// Update statistics based on TAGE prediction +void +MicroTAGE::TageStats::updateStatsWithTagePrediction(const TagePrediction &pred, bool when_pred) +{ + bool hit = pred.mainInfo.found; + unsigned hit_table = pred.mainInfo.table; + bool useAlt = pred.mainprovided ? false : true; + if (when_pred) { + if (hit) { +#ifndef UNIT_TEST + predTableHits.sample(hit_table, 1); +#endif + } else { + predNoHitUseBim++; + } + if (!hit || useAlt) { + predUseAlt++; + } + } else { + if (hit) { +#ifndef UNIT_TEST + updateTableHits.sample(hit_table, 1); +#endif + } else { + updateNoHitUseBim++; + } + if (!hit || useAlt) { + updateUseAlt++; + } + } +} + +// Update LRU counters for a set +void +MicroTAGE::updateLRU(int table, Addr index, unsigned way) +{ + // Increment LRU counters for all entries in the set + for (unsigned i = 0; i < numWays; i++) { + if (i != way && tageTable[table][index][i].valid) { + tageTable[table][index][i].lruCounter++; + } + } + // Reset LRU counter for the accessed entry + tageTable[table][index][way].lruCounter = 0; +} + +// Find the LRU victim in a set +unsigned +MicroTAGE::getLRUVictim(int table, Addr index) +{ + unsigned victim = 0; + unsigned maxLRU = 0; + + // Find the entry with the highest LRU counter + for (unsigned i = 0; i < numWays; i++) { + if (!tageTable[table][index][i].valid) { + return i; // Use invalid entry if available + } + if (tageTable[table][index][i].lruCounter > maxLRU) { + maxLRU = tageTable[table][index][i].lruCounter; + victim = i; + } + } + return victim; +} + +#ifndef UNIT_TEST +void +MicroTAGE::commitBranch(const FetchStream &stream, const DynInstPtr &inst) +{ + if (!inst->isCondCtrl()) { + // tage only deals with conditional branches + return; + } + auto meta = std::static_pointer_cast<TageMeta>(stream.predMetas[getComponentIdx()]); + auto pc = inst->pcState().instAddr(); + auto it = meta->preds.find(pc); + 
bool pred_taken = false; + bool pred_hit = false; + if (it != meta->preds.end()) { + pred_taken = it->second.taken; + pred_hit = true; + } + bool this_cond_taken = stream.exeTaken && stream.exeBranchInfo.pc == pc; + bool predcorrect = (pred_taken == this_cond_taken); + if (!predcorrect) { + tageStats.condPredwrong++; + if (!pred_hit) { + tageStats.condMissTakens++; + } + }else{ + tageStats.condCorrect++; + if (!pred_hit) { + tageStats.condMissNoTakens++; + } + } + + if (pred_hit) { + tageStats.predHit++; + } else { + tageStats.predMiss++; + } +} +#endif + +#ifdef UNIT_TEST +} // namespace test +#endif + +} // namespace btb_pred + +} // namespace branch_prediction + +} // namespace gem5 diff --git a/src/cpu/pred/btb/microtage.hh b/src/cpu/pred/btb/microtage.hh new file mode 100644 index 0000000000..ee4daed62d --- /dev/null +++ b/src/cpu/pred/btb/microtage.hh @@ -0,0 +1,439 @@ +#ifndef __CPU_PRED_BTB_MICROTAGE_HH__ +#define __CPU_PRED_BTB_MICROTAGE_HH__ + +#include +#include +#include +#include +#include + +#include "base/sat_counter.hh" +#include "base/types.hh" +#include "cpu/inst_seq.hh" +#include "cpu/pred/btb/folded_hist.hh" +#include "cpu/pred/btb/stream_struct.hh" +#include "cpu/pred/btb/timed_base_pred.hh" + +// Conditional includes based on build mode +#ifdef UNIT_TEST + #include "cpu/pred/btb/test/test_dprintf.hh" +#else + #include "debug/DecoupleBP.hh" + #include "debug/TAGEUseful.hh" + #include "debug/TAGEHistory.hh" + #include "params/MicroTAGE.hh" + #include "sim/sim_object.hh" +#endif + +namespace gem5 +{ + +namespace branch_prediction +{ + +namespace btb_pred +{ + +// Conditional namespace wrapper for testing +#ifdef UNIT_TEST +namespace test { +#endif + +class MicroTAGE : public TimedBaseBTBPredictor +{ + using defer = std::shared_ptr; + using bitset = boost::dynamic_bitset<>; + public: +#ifdef UNIT_TEST + // Test constructor + MicroTAGE(unsigned numPredictors = 4, unsigned numWays = 2, unsigned tableSize = 1024, unsigned numBanks = 4); +#else + // 
Production constructor + typedef MicroTAGEParams Params; +#endif + + // Represents a single entry in the TAGE prediction table + struct TageEntry + { + public: + bool valid; // Whether this entry is valid + Addr tag; // Tag for matching + short counter; // Prediction counter (-4 to 3), 3bits, 0 and -1 are weak + bool useful; // 1-bit usefulness counter; true means useful + Addr pc; // branch pc, like branch position, for btb entry pc check + unsigned lruCounter; // Counter for LRU replacement policy + + TageEntry() : valid(false), tag(0), counter(0), useful(false), pc(0), lruCounter(0) {} + + TageEntry(Addr tag, short counter, Addr pc) : + valid(true), tag(tag), counter(counter), useful(false), pc(pc), lruCounter(0) {} + bool taken() const { + return counter >= 0; + } + }; + + // Contains information about a TAGE table lookup + struct TageTableInfo + { + public: + bool found; // Whether a matching entry was found + TageEntry entry; // The matching entry + unsigned table; // Which table this entry was found in + Addr index; // Index in the table + Addr tag; // Tag that was matched + unsigned way; // Which way this entry was found in + TageTableInfo() : found(false), table(0), index(0), tag(0), way(0) {} + TageTableInfo(bool found, TageEntry entry, unsigned table, Addr index, Addr tag, unsigned way) : + found(found), entry(entry), table(table), index(index), tag(tag), way(way) {} + bool taken() const { + return entry.taken(); + } + }; + + // Contains the complete prediction result + struct TagePrediction + { + public: + Addr btb_pc; // btb entry pc, same as tage entry pc + TageTableInfo mainInfo; // Main prediction info + //TageTableInfo altInfo; // Alternative prediction info + bool mainprovided; // Whether to use alternative prediction, true if main is weak or no main prediction + bool taken; // Final prediction outcome + bool basePred; // Alternative prediction = alt_provided ? 
alt_taken : base_taken; + + TagePrediction() : btb_pc(0), mainprovided(false), taken(false), basePred(false) {} + TagePrediction(Addr btb_pc, TageTableInfo mainInfo, + bool mainprovided, bool taken, bool basePred) : + btb_pc(btb_pc), mainInfo(mainInfo), + mainprovided(mainprovided), taken(taken), basePred(basePred) {} + }; + + +#ifndef UNIT_TEST + MicroTAGE(const Params& p); +#endif + ~MicroTAGE(); + + void tickStart() override; + + void tick() override; + void dryRunCycle(Addr startAddr) override; + // Make predictions for a stream of instructions and record in stage preds + void putPCHistory(Addr startAddr, + const boost::dynamic_bitset<> &history, + std::vector &stagePreds) override; + + std::shared_ptr getPredictionMeta() override; + + // speculative update 3 folded history, according history and pred.taken + // the other specUpdateHist methods are left blank + void specUpdatePHist(const boost::dynamic_bitset<> &history, FullBTBPrediction &pred) override; + + // Recover 3 folded history after a misprediction, then update 3 folded history according to history and pred.taken + // the other recoverHist methods are left blank + void recoverPHist(const boost::dynamic_bitset<> &history, + const FetchStream &entry,int shamt, bool cond_taken) override; + +#ifdef UNIT_TEST + // API compatibility wrappers for testing + void specUpdateHist(const boost::dynamic_bitset<> &history, FullBTBPrediction &pred) override + { + specUpdatePHist(history, pred); + } + + void recoverHist(const boost::dynamic_bitset<> &history, const FetchStream &entry, int shamt, + bool cond_taken) override + { + recoverPHist(history, entry, shamt, cond_taken); + } +#endif + + // Update predictor state based on actual branch outcomes + void update(const FetchStream &entry) override; + bool canResolveUpdate(const FetchStream &entry) override; + void doResolveUpdate(const FetchStream &entry) override; + +#ifndef UNIT_TEST + void commitBranch(const FetchStream &stream, const DynInstPtr &inst) override; 
+#endif + + void setTrace() override; + + // check folded hists after speculative update and recover + void checkFoldedHist(const bitset &history, const char *when); + +#ifndef UNIT_TEST + private: +#endif + + // Look up predictions in TAGE tables for a stream of instructions + void lookupHelper(const Addr &startPC, const std::vector &btbEntries, CondTakens& results); + + // Calculate TAGE index for a given PC and table + Addr getTageIndex(Addr pc, int table); + + // Calculate TAGE index with folded history (uint64_t version for performance) + Addr getTageIndex(Addr pc, int table, uint64_t foldedHist); + + // Calculate TAGE tag for a given PC and table + // position: branch position within the block (xored into tag like RTL) + Addr getTageTag(Addr pc, int table, Addr position = 0); + + // Calculate TAGE tag with folded history (uint64_t version for performance) + // position: branch position within the block (xored into tag like RTL) + Addr getTageTag(Addr pc, int table, uint64_t foldedHist, uint64_t altFoldedHist, Addr position = 0); + + // Get offset within a block for a given PC + Addr getOffset(Addr pc) { + return (pc & (blockSize - 1)) >> 1; + } + + // Get branch index within a prediction block + unsigned getBranchIndexInBlock(Addr branchPC, Addr startPC); + + // Get bank ID from PC (after removing instruction alignment bits) + // Extract bits [bankBaseShift + bankIdWidth - 1 : bankBaseShift] + unsigned getBankId(Addr pc) const; + + // Update branch history + void doUpdateHist(const bitset &history, bool taken, Addr pc, Addr target); + + // Number of TAGE predictor tables + const unsigned numPredictors; + + // Size of each prediction table + std::vector tableSizes; + + // Number of bits used for indexing each table + std::vector tableIndexBits; + + // Masks for table indexing + std::vector tableIndexMasks; + + // Number of bits used for tags in each table + std::vector tableTagBits; + + // Masks for tag matching + std::vector tableTagMasks; + + // PC shift 
amounts for each table + std::vector tablePcShifts; + + // History lengths for each table + std::vector histLengths; + + // Folded history for tag calculation + std::vector tagFoldedHist; + + // Folded history for alternative tag calculation + std::vector altTagFoldedHist; + + // Folded history for index calculation + std::vector indexFoldedHist; + + // Linear feedback shift register for allocation + LFSR64 allocLFSR; + + // Maximum history length, not used + unsigned maxHistLen; + + // Number of ways for set associative design + const unsigned numWays; + + // The actual TAGE prediction tables (table x index x way) + std::vector>> tageTable; + + const unsigned maxBranchPositions; // Maximum branch positions per 64-byte block + + // useful bit reset counter, when cnt >= 256, reset useful bit of all entries + int usefulResetCnt{0}; + + // Check if a tag matches + bool matchTag(Addr expected, Addr found); + + // Set tag bits for a given table + void setTag(Addr &dest, Addr src, int table); + + // Number of tables to allocate on misprediction + unsigned numTablesToAlloc; + + // Instruction shift amount + unsigned instShiftAmt {1}; + + // use for microtage updatemispred counting + void checkUtageUpdateMisspred(const FetchStream &stream); + + // Update prediction counter with saturation + void updateCounter(bool taken, unsigned width, short &counter); + + // Increment counter with saturation + bool satIncrement(int max, short &counter); + + // Decrement counter with saturation + bool satDecrement(int min, short &counter); + + // Cache for TAGE indices + std::vector tageIndex; + + // Cache for TAGE tags + std::vector tageTag; + + // Whether to update on read + bool updateOnRead; + + // ========== Bank Configuration ========== + // Bank mechanism to simulate hardware bank conflicts + // When prediction and update access the same bank in one cycle, update is dropped + const unsigned numBanks; // Number of banks (e.g., 4) + const unsigned bankIdWidth; // log2(numBanks), 
computed in constructor + const unsigned blockWidth; // floorLog2(blockSize), e.g., 5 for 32B blocks + const unsigned bankBaseShift; // Bits removed before bank selection (default: instShiftAmt) + const unsigned indexShift; // bankBaseShift + bankIdWidth when banking enabled + bool enableBankConflict; // Enable/disable bank conflict simulation + + // Track last prediction bank for conflict detection + unsigned lastPredBankId; // Bank ID of last prediction + bool predBankValid; // Whether lastPredBankId is valid + +#ifdef UNIT_TEST + typedef uint64_t Scalar; +#else + typedef statistics::Scalar Scalar; +#endif + + // Statistics for TAGE predictor +#ifdef UNIT_TEST + struct TageStats + { +#else + struct TageStats : public statistics::Group + { +#endif + Scalar predNoHitUseBim; + Scalar predUseAlt; + Scalar updateNoHitUseBim; + Scalar updateUseAlt; + Scalar updateUseAltCorrect; + Scalar updateUseAltWrong; + Scalar updateAltDiffers; + Scalar updateUseAltOnNaUpdated; + Scalar updateProviderNa; + Scalar updateUseNaCorrect; + Scalar updateUseNaWrong; + Scalar updateUseAltOnNaCorrect; + Scalar updateUseAltOnNaWrong; + Scalar updateAllocFailure; + Scalar updateAllocFailureNoValidTable; + Scalar updateAllocSuccess; + Scalar updateMispred; + Scalar updateResetU; + + Scalar updateUtageHit; + Scalar updateUtageWrong; + + // Bank conflict statistics + Scalar updateBankConflict; // Number of bank conflicts detected + Scalar updateDeferredDueToConflict; // Number of updates deferred due to bank conflict (retried later) + +#ifndef UNIT_TEST + // Fine-grained per-bank statistics + statistics::Vector updateBankConflictPerBank; // Conflicts per bank + statistics::Vector updateAccessPerBank; // Update accesses per bank + statistics::Vector predAccessPerBank; // Prediction accesses per bank + + statistics::Distribution predTableHits; + statistics::Distribution updateTableHits; + + statistics::Vector updateTableMispreds; +#endif + + Scalar condPredwrong; + Scalar condMissTakens; + Scalar 
condCorrect; + Scalar condMissNoTakens; + Scalar predHit; + Scalar predMiss; + + int bankIdx; + int numPredictors; + int numBanks; + +#ifndef UNIT_TEST + TageStats(statistics::Group* parent, int numPredictors, int numBanks); +#endif + void updateStatsWithTagePrediction(const TagePrediction &pred, bool when_pred); + } ; + + TageStats tageStats; + +#ifndef UNIT_TEST + TraceManager *tageMissTrace; +#endif + +public: + + // Recover folded history after misprediction + void recoverFoldedHist(const bitset& history); + +public: + + + // Metadata for TAGE prediction + typedef struct TageMeta + { + std::unordered_map preds; + std::vector tagFoldedHist; + std::vector indexFoldedHist; + std::vector altTagFoldedHist; + bitset history; // for viewing + TageMeta() {} + } TageMeta; + +private: + + // Helper method to generate prediction for a single BTB entry + // If predMeta is provided, use snapshot folded history for index/tag calculation (update path) + // If predMeta is nullptr, use current folded history (prediction path) + TagePrediction generateSinglePrediction(const BTBEntry &btb_entry, + const Addr &startPC, + const std::shared_ptr predMeta = nullptr); + + // Helper method to prepare BTB entries for update + std::vector prepareUpdateEntries(const FetchStream &stream); + + // Helper method to update predictor state for a single entry + bool updatePredictorStateAndCheckAllocation(const BTBEntry &entry, + bool actual_taken, + const TagePrediction &pred, + const FetchStream &stream); + + // Helper method to handle new entry allocation + bool handleNewEntryAllocation(const Addr &startPC, + const BTBEntry &entry, + bool actual_taken, + unsigned main_table, + std::shared_ptr meta, + uint64_t &allocated_table, + uint64_t &allocated_index, + uint64_t &allocated_way); + + + // Helper methods for LRU management + void updateLRU(int table, Addr index, unsigned way); + unsigned getLRUVictim(int table, Addr index); + + std::shared_ptr meta; +}; + +// Close conditional namespace 
wrapper for testing +#ifdef UNIT_TEST +} // namespace test +#endif + +} // namespace btb_pred + +} // namespace branch_prediction + +} // namespace gem5 + +#endif // __CPU_PRED_BTB_MICROTAGE_HH__