diff --git a/examples/cpp/example.cpp b/examples/cpp/example.cpp index 34cf4eb0..f6a68ce7 100644 --- a/examples/cpp/example.cpp +++ b/examples/cpp/example.cpp @@ -467,7 +467,7 @@ int main(int argc, char **argv) { } auto result = model.finalize(); - if (isMaster) { + if (true) { std::cout << "\n[INFO] Final output is: " << std::endl; std::vector sent = tokenizer->batchDecode(result, batchSize); for (auto str : sent) { diff --git a/src/common/sequence.h b/src/common/sequence.h new file mode 100644 index 00000000..84049436 --- /dev/null +++ b/src/common/sequence.h @@ -0,0 +1,313 @@ +// Copyright (c) 2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================ +#pragma once + +#include +#include +#include + +/* + SequencePool + ┌──────┬──────┬──────┐ + │ │ │ ◄───┼──┬─ SequenceMeta + ├──────┼──────┼──────┤ │ + BatchInputs │ │ │ ◄───┼──┘ + │ └▲─┬─▲─┴──────┴──────┘ + │ │ │ └───────────────────────────────────┐ + ▼ ┌──┬──┬──┬──┐ │ │ ┌──┬──┬──┬──┬──┬──┬──┬──┬──┐ │ + Input ─►│ │ │ │ ├──┘ └─────►│ │ │ │ │ │ │ │ │ ├─┐ │ + └──┴──┴──┴──┘ └──┴──┴──┴──┴──┴──┴──┴──┴──┘ │ │ + InputQueue TaskWaitingQueue0 │ │ + ┌───────────────────────────────┘ │ + │ ┌──┬──┬──┬──┬──┬──┬──┬──┬──┐ │ + └─►│ │ │ │ │ │ │ │ │ ├───┘ + └──┴──┴──┴──┴──┴──┴──┴──┴──┘ + TaskWaitingQueue1 +*/ + +namespace xft { + +// The SequenceMeta is one sequence of batch inputs and includes the generated tokens. +class SequenceMeta { +public: + SequenceMeta(int32_t _sequenceID, int32_t _inputSeqLen, std::vector &_inputTokens) + : sequenceID(_sequenceID), inputSeqLen(_inputSeqLen), pastSeqLen(0), step(0) { + inputTokens.resize(_inputSeqLen); + inputTokens.assign(_inputTokens.begin(), _inputTokens.end()); + nextTokens.resize(_inputSeqLen); + setPastSeqLen(getPastSeqLen()); + } + + SequenceMeta(int32_t _sequenceID, int32_t _inputSeqLen) + : sequenceID(_sequenceID), inputSeqLen(_inputSeqLen), inputTokens(_inputSeqLen, 0), pastSeqLen(0), step(0) { + nextTokens.resize(_inputSeqLen); + } + + ~SequenceMeta() {} + + int32_t getSequenceID() const { return sequenceID; } + + // For first tokens + void stepForward() { + if (getStep() == 0) { + setPastSeqLen(inputTokens.size()); + setStep(getStep() + 1); + } + } + + // For next token + void stepForward(int32_t token) { + // addNextToken(token); + setPastSeqLen(getPastSeqLen() + 1); + setStep(getStep() + 1); + } + + // Get the input tokens in sequence + int32_t getInputSeqLen() const { return inputSeqLen; } + + const int32_t *getInputTokens() const { return inputTokens.data(); } + + int32_t getPastSeqLen() const { return pastSeqLen; } + + void setPastSeqLen(int32_t _pastSeqLen) { pastSeqLen = _pastSeqLen; } + + // For next tokens + void addNextToken(int32_t token) { + nextTokens.clear(); + nextTokens.push_back(token); + inputTokens.push_back(token); + } + + int32_t getLatestToken() const { return nextTokens.back(); } + + const int32_t *getTotalTokens() const { return getInputTokens(); } + + int32_t getStep() const { return step; } + + void setStep(int32_t _step) { step = _step; } + +private: + int32_t sequenceID; + int32_t inputSeqLen; + int32_t pastSeqLen; + std::vector inputTokens; // input tokens + next tokens + std::vector nextTokens; // next tokens + int32_t step; + +#ifdef PIPELINE_PARALLEL +public: + template + void allocBuffer(int32_t hiddenSize, void *_hiddenStates) { + hiddenStates = xft::alloc(sizeof(T) * getInputSeqLen() * hiddenSize); + memcpy(hiddenStates, _hiddenStates, sizeof(T) * getInputSeqLen() * hiddenSize); + } + + int32_t getHiddenStatesSize() const { return hiddenStatesSize; } + + void setHiddenStatesSize(int32_t _hiddenStatesSize) { hiddenStatesSize = _hiddenStatesSize; } + +private: + int32_t hiddenSize; + int64_t hiddenStatesSize; + void *hiddenStates; +#endif +}; + +// For beam searcher +class SequenceGroupMeta { +public: + SequenceGroupMeta(int32_t _num_beams, std::vector &seq) { + num_beams = _num_beams; + sequences = seq; + } + +private: + int32_t num_beams; + std::vector sequences; +}; + +// SequencePool +// ┌──────┬──────┬──────┐ +// │ │ │ ◄───┼──┬─ SequenceMeta +// ├──────┼──────┼──────┤ │ +// │ │ │ ◄───┼──┘ +// └──────┴──────┴──────┘ +class SequencePool { +public: + static SequencePool &getInstance() { + static SequencePool instance; + return instance; + } + + int32_t createSequenceID() { + int32_t id = globalSequenceID++; + if (id >= 10 * 1024) { + globalSequenceID = 0; + id = globalSequenceID++; + } + return id; + } + + SequenceMeta *createMeta(int32_t sequenceID, int32_t inputSeqLen, std::vector &inputTokens) { + auto *sequenceMeta = new SequenceMeta(sequenceID, inputSeqLen, inputTokens); + return sequenceMeta; + } + + SequenceMeta *createMeta(int32_t sequenceID, int32_t inputSeqLen) { + auto *sequenceMeta = new SequenceMeta(sequenceID, inputSeqLen); + return sequenceMeta; + } + + bool add(int32_t sequenceID, SequenceMeta *sequence, bool force = false) { + bool isSuccess = false; + if (force) { + auto it = hub.find(sequenceID); + if (it != hub.end()) { remove(it->first, true); } + + hub[sequenceID] = sequence; + isSuccess = true; + } else { + bool exist = has(sequenceID); + if (!exist) { + hub[sequenceID] = sequence; + isSuccess = true; + } + } + + return isSuccess; + } + + bool has(int32_t sequenceID) const { return hub.find(sequenceID) != hub.end(); } + + SequenceMeta *get(int32_t sequenceID) const { + auto it = hub.find(sequenceID); + if (it != hub.end()) { + return it->second; + } else { + return nullptr; + } + } + + bool remove(int32_t sequenceID, bool deep = false) { + bool isSuccess = false; + if (has(sequenceID)) { + if (deep == true) { + auto it = hub.find(sequenceID); + if (it != hub.end()) { delete it->second; } + } + isSuccess = hub.erase(sequenceID); + } + + return isSuccess; + } + + bool replace(int32_t sequenceID, SequenceMeta *newSequenceMeta) { + bool isSuccess = false; + auto it = hub.find(sequenceID); + if (it != hub.end()) { + remove(it->first, true); + hub[sequenceID] = newSequenceMeta; + isSuccess = true; + } + + return isSuccess; + } + + void clear() { + for (auto &it : hub) { + delete it.second; + } + hub.clear(); + globalSequenceID = 0; + } + +private: + SequencePool() {} + + int32_t globalSequenceID = 0; + std::unordered_map hub; +}; + +// Manage input sequenceMeta +class InputQueue { +public: + static InputQueue &getInstance() { + static InputQueue instance; + return instance; + } + + bool empty() { return queue.empty(); } + + SequenceMeta *pop() { + auto seq = queue.front(); + queue.pop(); + return seq; + } + + void push(SequenceMeta *seq) { queue.push(seq); } + + void clear() { + while (!queue.empty()) { + queue.pop(); + } + } + +private: + InputQueue() {} + + std::queue queue; +}; + +// Manage executive sequenceMeta +class TaskWaitingQueue { +public: + static TaskWaitingQueue &getInstance() { + static TaskWaitingQueue instance; + return instance; + } + + bool empty() { return queue.empty(); } + + int32_t size() { return queue.size(); } + + bool isFull() { + bool full = false; + if (this->size() >= Env::getInstance().getMaxRequestNum()) { full = true; } + return full; + } + + SequenceMeta *front() { return queue.front(); } + + SequenceMeta *pop() { + auto seq = queue.front(); + queue.pop(); + return seq; + } + + void push(SequenceMeta *seq) { queue.push(seq); } + + void clear() { + while (!queue.empty()) { + queue.pop(); + } + } + +private: + TaskWaitingQueue() {} + + std::queue queue; +}; + +} // namespace xft \ No newline at end of file diff --git a/src/common/transformer_ctx.h b/src/common/transformer_ctx.h index 3685baae..8c11815d 100644 --- a/src/common/transformer_ctx.h +++ b/src/common/transformer_ctx.h @@ -65,6 +65,10 @@ struct DecoderContext { // For custom usage int reserved1; +#ifdef PIPELINE_PARALLEL + int32_t sequenceID; +#endif + // Model structure configuration int vocabSize; int embeddingSize; @@ -319,4 +323,4 @@ struct DecoderContext { } ~DecoderContext() { free(this->rawBuffer); } -}; +}; \ No newline at end of file diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index ab289027..cbe06e82 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -31,6 +31,8 @@ #include "mlp_chatglm2.h" #include "mlp_standard.h" #include "model_factory.h" +#include "sequence.h" +#include "thread_util.h" #include "timeline.h" #include "transformer_ctx.h" #include "transpose_util.h" @@ -278,7 +280,7 @@ class CommonDecoder : public AbstractDecoder { int userSideBS = dims[0]; int beamSize = dims[1]; - int batchSize = (step == 0 ? userSideBS : userSideBS * beamSize); // as samples are duplicated at step 0 + int batchSize = (step == 0 ? userSideBS : userSideBS * beamSize); // as sequence are duplicated at step 0 int seqLen = dims[2]; int pastSeqLen = step == 0 ? 0 : this->accSeqLen; int inputSeqLen = seqLen; @@ -286,6 +288,7 @@ class CommonDecoder : public AbstractDecoder { // Prepare context DecoderContext *ctx = this->getContext(); ctx->resize(batchSize, seqLen, pastSeqLen); + int hiddenSize = ctx->hiddenSize; if (step == 0) { // Reset initial and accumulated sequence length at the first step @@ -314,7 +317,7 @@ class CommonDecoder : public AbstractDecoder { } AttnInT *embBuf = (AttnInT *)actBuffers->Data(); - MlpOutT *outBuf = (MlpOutT *)(embBuf + batchSize * inputSeqLen * ctx->hiddenSize); + MlpOutT *outBuf = (MlpOutT *)(embBuf + batchSize * inputSeqLen * hiddenSize); // Embedding this->embeddingForward(ids, embBuf, batchSize, inputSeqLen); @@ -324,9 +327,8 @@ class CommonDecoder : public AbstractDecoder { dbg.debugPrint("---- embedding.forward ----\n"); dbg.debugPrint("ids:\n"); dbg.dumpMatrix(ids, batchSize, inputSeqLen, inputSeqLen); - dbg.debugPrint( - "embBuf(rows: %d, cols: %d, stride: %d):\n", batchSize * inputSeqLen, ctx->hiddenSize, ctx->hiddenSize); - dbg.dumpMatrix(embBuf, batchSize * inputSeqLen, ctx->hiddenSize, ctx->hiddenSize); + dbg.debugPrint("embBuf(rows: %d, cols: %d, stride: %d):\n", batchSize * inputSeqLen, hiddenSize, hiddenSize); + dbg.dumpMatrix(embBuf, batchSize * inputSeqLen, hiddenSize, hiddenSize); #endif // Prepare attention mask @@ -337,19 +339,61 @@ class CommonDecoder : public AbstractDecoder { t1.release(); #ifdef PIPELINE_PARALLEL + int curr_world_rank = ctx->ppRank * ctx->tpSize + ctx->tpRank; + int prev_world_rank = (ctx->ppRank - 1) * ctx->tpSize + ctx->tpRank; // if current pipeline parallel stage rank isn't the first stage, should receive previous stage data - if (ctx->ppSize > 1 && ctx->ppRank > 0) { - int curr_world_rank = ctx->ppRank * ctx->tpSize + ctx->tpRank; - int prev_world_rank = (ctx->ppRank - 1) * ctx->tpSize + ctx->tpRank; - int count = batchSize * inputSeqLen * ctx->hiddenSize; - MPI_Recv(embBuf, count, MPI_FLOAT, prev_world_rank, curr_world_rank, MPI_COMM_WORLD, MPI_STATUS_IGNORE); - // TODO: Error: different scope when dynamic loading so file - // this->messenger.worldRecvFP32(embBuf, count, prev_world_rank, curr_world_rank); + if (ctx->ppSize > 1 && ctx->ppRank > 0 && enabledBackgroundSync == false) { + enabledBackgroundSync = true; + // int64_t count = batchSize * inputSeqLen * hiddenSize; + ThreadPool::getInstance().addTask([curr_world_rank, prev_world_rank, seqLen, hiddenSize, pastSeqLen, this] { + while (true) { + int64_t recvBuf[2] = {0, 0}; + MPI_Recv(&recvBuf, 2, MPI_INT64_T, prev_world_rank, curr_world_rank, MPI_COMM_WORLD, + MPI_STATUS_IGNORE); + int32_t sequenceID = recvBuf[0]; + int64_t count = recvBuf[1]; + // TODO: Error: different scope when dynamic loading so file + // this->messenger.worldRecvFP32(embBuf, count, prev_world_rank, curr_world_rank); + TimeLine t("Decoder.Seq" + std::to_string(sequenceID) + ".MPI_Recv"); + if (!SequencePool::getInstance().has(sequenceID)) { + SequenceMeta *sequence = SequencePool::getInstance().createMeta(sequenceID, seqLen); + sequence->setHiddenStatesSize(count); + // sequence->setPastSeqLen(pastSeqLen); + // sequence->allocBuffer(hiddenSize, embBuf); + SequencePool::getInstance().add(sequence->getSequenceID(), sequence); + } + TaskWaitingQueue::getInstance().push(SequencePool::getInstance().get(sequenceID)); + } + }); + } + + while (!InputQueue::getInstance().empty()) { + if (!TaskWaitingQueue::getInstance().isFull()) { + auto sequence = InputQueue::getInstance().pop(); + // sequence->setPastSeqLen(pastSeqLen); + // sequence->allocBuffer(hiddenSize, embBuf); + SequencePool::getInstance().add(sequence->getSequenceID(), sequence); + TaskWaitingQueue::getInstance().push(SequencePool::getInstance().get(sequence->getSequenceID())); + } } + + while (TaskWaitingQueue::getInstance().empty()); + + SequenceMeta *runningTask = nullptr; + int32_t sequenceID = -1; + if (!TaskWaitingQueue::getInstance().empty()) { + runningTask = TaskWaitingQueue::getInstance().front(); + sequenceID = runningTask->getSequenceID(); + ctx->sequenceID = runningTask->getSequenceID(); + // runningTask->setPastSeqLen(pastSeqLen); + // runningTask->allocBuffer(hiddenSize, embBuf); + MPI_Recv(embBuf, TaskWaitingQueue::getInstance().front()->getHiddenStatesSize(), MPI_FLOAT, prev_world_rank, + curr_world_rank + 1000, MPI_COMM_WORLD, MPI_STATUS_IGNORE); + + TimeLine t("Decoder.Seq" + std::to_string(sequenceID) + ".Step"); #endif // Decoder: forward - int hiddenSize = ctx->hiddenSize; int layers_per_pp_stage = this->decoders.size(); for (int i = 0; i < layers_per_pp_stage; ++i) { int workers = this->messenger.getSize(); @@ -402,11 +446,16 @@ class CommonDecoder : public AbstractDecoder { } #ifdef PIPELINE_PARALLEL + } + // If current pipeline stage isn't the end of stage, should send data to next stage and return nullptr if (ctx->ppSize > 1 && ctx->ppRank < ctx->ppSize - 1) { + TimeLine t("Decoder.Seq" + std::to_string(sequenceID) + ".MPI_Send"); int next_world_rank = (ctx->ppRank + 1) * ctx->tpSize + ctx->tpRank; - int count = batchSize * inputSeqLen * ctx->hiddenSize; - MPI_Send(embBuf, count, MPI_FLOAT, next_world_rank, next_world_rank, MPI_COMM_WORLD); + int64_t count = batchSize * inputSeqLen * hiddenSize; + int64_t sendBuf[2] = {sequenceID, count}; + MPI_Send(&sendBuf, 2, MPI_INT64_T, next_world_rank, next_world_rank, MPI_COMM_WORLD); + MPI_Send(embBuf, count, MPI_FLOAT, next_world_rank, next_world_rank + 1000, MPI_COMM_WORLD); // TODO: Error: different scope when dynamic loading so file // this->messenger.worldSendFP32(embBuf, count, next_world_rank, next_world_rank); return std::tuple(nullptr, 0, 0); @@ -982,6 +1031,8 @@ class CommonDecoder : public AbstractDecoder { int startId; int endId; + bool enabledBackgroundSync = false; + #ifdef DEBUG Debugger dbg; #endif diff --git a/src/models/models.cpp b/src/models/models.cpp index fe08ed89..18f08236 100644 --- a/src/models/models.cpp +++ b/src/models/models.cpp @@ -32,6 +32,7 @@ #include "searcher.h" #include "timeline.h" #include "yarn_llama.h" +#include "sequence.h" namespace xft { enum class GenerationMode { GREEDY_SEARCH, BEAM_SEARCH, SAMPLE }; @@ -84,6 +85,21 @@ void Model::input(std::vector &inputIds_, int batchSize_) { inputIds.resize(dims[1]); if (decoder->getRank() == 0) { inputIds = inputIds_; } messenger.broadcast(inputIds.data(), dims[1]); + + if (this->isMaster()) { + for (int i = 0; i < 2; ++i) { + int sequenceID = SequencePool::getInstance().createSequenceID(); + InputQueue::getInstance().push(SequencePool::getInstance().createMeta(sequenceID, seqLen, inputIds)); + } + + while (!InputQueue::getInstance().empty()) { + if (!TaskWaitingQueue::getInstance().isFull()) { + auto sequence = InputQueue::getInstance().pop(); + SequencePool::getInstance().add(sequence->getSequenceID(), sequence); + TaskWaitingQueue::getInstance().push(SequencePool::getInstance().get(sequence->getSequenceID())); + } + } + } } void Model::config(int maxLen_, int numBeams_, int numBeamHypsToKeep_, float lenPenalty_, bool doEarlyStopping_, @@ -114,6 +130,11 @@ void Model::config(SearcherConfig &config_, const std::vector> // Slaves get exit flags and exit directly if (decoder->getRank() > 0 && configuration.numBeams == 0) { exit(0); } + InputQueue::getInstance().clear(); + TaskWaitingQueue::getInstance().clear(); + SequencePool::getInstance().clear(); + // ThreadPool::getInstance().clear(); + createSearcher(configuration); setStopWords(stopWordsList_); } @@ -141,12 +162,26 @@ std::vector Model::generate() { exit(-1); } - if (isNewInput) { + std::vector token; + if (!this->isMaster() && isNewInput) { isNewInput = false; - return searcher->getNextToken(inputIds.data(), batchSize, inputIds.size() / batchSize); + token = searcher->getNextToken(inputIds.data(), batchSize, inputIds.size() / batchSize); + TaskWaitingQueue::getInstance().front()->stepForward(); } else { - return searcher->getNextToken(); + while(TaskWaitingQueue::getInstance().empty()); + + if (TaskWaitingQueue::getInstance().front()->getStep() == 0) { + isNewInput = false; + token = searcher->getNextToken(inputIds.data(), batchSize, inputIds.size() / batchSize); + TaskWaitingQueue::getInstance().front()->stepForward(); + } else { + token = searcher->getNextToken(); + TaskWaitingQueue::getInstance().front()->stepForward(token[0]); + } } + + TaskWaitingQueue::getInstance().pop(); + return token; } void Model::createSearcher(SearcherConfig &config_) { diff --git a/src/searchers/greedy_search.cpp b/src/searchers/greedy_search.cpp index 0e55648e..493fb717 100644 --- a/src/searchers/greedy_search.cpp +++ b/src/searchers/greedy_search.cpp @@ -14,10 +14,14 @@ // ============================================================================ #include "greedy_search.h" #include "messenger.h" +#include "sequence.h" #include "search_utils.h" +#include "thread_util.h" + +using namespace xft; GreedySearch::GreedySearch(AbstractDecoder &dec, const SearcherConfig &config) - : decoder(dec), maxLen(config.maxLen), step(0), repetitionPenalty(config.repetitionPenalty) { + : decoder(dec), maxLen(config.maxLen), step(0), repetitionPenalty(config.repetitionPenalty), enabledBackgroundSync(false) { eosTokenId = config.eosTokenId == -1 ? decoder.getEndId() : config.eosTokenId; padTokenId = config.padTokenId == -1 ? eosTokenId : config.padTokenId; if (repetitionPenalty <= 0) { @@ -35,21 +39,42 @@ std::vector GreedySearch::syncToken(std::tuple &result) // Messenger &messenger = decoder.getMessenger(); if (std::get<0>(result) == nullptr) { // The first embedding pipeline parallel stage - this->nextTokens = std::vector(batchSize, 0); - if (ctx->ppSize > 1 && ctx->ppRank == 0) { + if (ctx->ppSize > 1 && ctx->ppRank == 0 && enabledBackgroundSync == false) { + enabledBackgroundSync = true; int predictor_world_rank = (ctx->ppSize - 1) * ctx->tpSize + ctx->tpRank; - MPI_Recv(this->nextTokens.data(), batchSize, MPI_INT32_T, predictor_world_rank, predictor_world_rank, - MPI_COMM_WORLD, MPI_STATUS_IGNORE); - // TODO: Error: different scope when dynamic loading so file - // messenger.worldRecvINT32(this->nextTokens.data(), batchSize, predictor_world_rank, predictor_world_rank); + ThreadPool::getInstance().addTask([predictor_world_rank, this] { + while (true) { + int32_t recvBuf[2]; + MPI_Recv(&recvBuf, 2, MPI_INT32_T, predictor_world_rank, predictor_world_rank, MPI_COMM_WORLD, + MPI_STATUS_IGNORE); + int32_t sequenceID = recvBuf[0]; + this->nextTokens[0] = recvBuf[1]; + // MPI_Recv(&sequenceID, 1, MPI_INT32_T, predictor_world_rank, predictor_world_rank, MPI_COMM_WORLD, + // MPI_STATUS_IGNORE); + TimeLine t("GreedySearch.Seq" + std::to_string(sequenceID) + ".MPI_Recv"); + // MPI_Recv(this->nextTokens.data(), this->batchSize, MPI_INT32_T, predictor_world_rank, + // predictor_world_rank, MPI_COMM_WORLD, MPI_STATUS_IGNORE); + if (SequencePool::getInstance().has(sequenceID)) { + auto sequence = SequencePool::getInstance().get(sequenceID); + TaskWaitingQueue::getInstance().push(sequence); + } else { + printf("Error: should have sequenceID\n"); + fflush(stdout); + } + } + }); } } else { // The last predictor pipeline parallel stage this->nextTokens = this->search(result); if (ctx->ppSize > 1 && ctx->ppRank == ctx->ppSize - 1) { + TimeLine t("GreedySearch.Seq" + std::to_string(ctx->sequenceID) + ".MPI_Send"); int embedding_world_rank = 0 * ctx->tpSize + ctx->tpRank; int predictor_world_rank = (ctx->ppSize - 1) * ctx->tpSize + ctx->tpRank; - MPI_Send(this->nextTokens.data(), batchSize, MPI_INT32_T, embedding_world_rank, predictor_world_rank, - MPI_COMM_WORLD); + int32_t sendBuf[2] = {ctx->sequenceID, nextTokens[0]}; + MPI_Send(&sendBuf, 2, MPI_INT32_T, embedding_world_rank, predictor_world_rank, MPI_COMM_WORLD); + // MPI_Send(&ctx->sequenceID, 1, MPI_INT32_T, embedding_world_rank, predictor_world_rank, MPI_COMM_WORLD); + // MPI_Send(this->nextTokens.data(), batchSize, MPI_INT32_T, embedding_world_rank, predictor_world_rank, + // MPI_COMM_WORLD); // TODO: Error: different scope when dynamic loading so file // messenger.worldSendINT32(this->nextTokens.data(), batchSize, embedding_world_rank, predictor_world_rank); } @@ -82,6 +107,8 @@ std::vector GreedySearch::getNextToken(int *ids, int batchSize, int seqLen) std::copy(ids, ids + batchSize * seqLen, output.begin()); int64_t dims[3] = {batchSize, 1, seqLen}; + if (this->nextTokens.size() != batchSize) + this->nextTokens.resize(batchSize, 0); std::tuple result = decoder.forward(ids, dims, this->step++); @@ -92,6 +119,8 @@ std::vector GreedySearch::getNextToken(int *ids, int batchSize, int seqLen) std::vector GreedySearch::getNextToken() { TimeLine t("Next Token"); int64_t dims[3] = {batchSize, 1, 1}; + if (this->nextTokens.size() != batchSize) + this->nextTokens.resize(batchSize, 0); std::tuple result = decoder.forward(nextTokens.data(), dims, this->step++); diff --git a/src/searchers/greedy_search.h b/src/searchers/greedy_search.h index 607d9737..5b4ec164 100644 --- a/src/searchers/greedy_search.h +++ b/src/searchers/greedy_search.h @@ -47,6 +47,7 @@ class GreedySearch : public AbstractSearcher { std::vector> cachedRepetVec; std::vector doneBatch; + bool enabledBackgroundSync; int batchSize; int step; int curLen; diff --git a/src/utils/environment.h b/src/utils/environment.h index e6d94338..2630ff74 100644 --- a/src/utils/environment.h +++ b/src/utils/environment.h @@ -41,6 +41,9 @@ class Env { // get Engine Kind and Index int getPipelineStage() { return pipelineStageValue; } + // get Engine Kind and Index + int getMaxRequestNum() { return maxRequestNumValue; } + // get AMX Threshold M int getAMXThresholdM() { return AMXThresholdMValue; } @@ -73,6 +76,9 @@ class Env { // init Pipeline Parallel initPipelineStage(); + // init Max request number + initMaxRequestNum(); + // init Engine Kind and Index initEngineKindIndex(); @@ -173,6 +179,21 @@ class Env { } } + // Max request number + int maxRequestNumValue = 1; + void initMaxRequestNum() { + char *xft_max_request_num_value = getenv("XFT_MAX_REQUEST_NUM"); + if (xft_max_request_num_value != NULL) { + int value = atoi(xft_max_request_num_value); + if (value >= 1) + maxRequestNumValue = value; + else + printf("[ERROR] XFT_MAX_REQUEST_NUM value need to be greater than 0.\n"); + } else { + maxRequestNumValue = 1; + } + } + // AMX Threshold M int AMXThresholdMValue = 1; void initAMXThresholdM() { diff --git a/src/utils/thread_util.h b/src/utils/thread_util.h index c6826051..f22c59ac 100644 --- a/src/utils/thread_util.h +++ b/src/utils/thread_util.h @@ -1,6 +1,29 @@ +// Copyright (c) 2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================ #pragma once #include +#include +#include +#include +#include +#include +#include + +namespace xft { + template void parallel_for(int tasks, const Lambda &fn) { #pragma omp parallel for @@ -15,4 +38,58 @@ void parallel_for_dschedule(int tasks, const Lambda &fn) { for (int i = 0; i < tasks; i++) { fn(i); } -} \ No newline at end of file +} + +class ThreadPool { +public: + static ThreadPool &getInstance() { + static ThreadPool instance; + return instance; + } + + template + void addTask(F &&f, Args &&...args) { + { + std::unique_lock lock(queueMutex); + tasks.emplace(std::bind(std::forward(f), std::forward(args)...)); + } + condition.notify_one(); + } + + void clear() { + stop = true; + condition.notify_all(); + for (std::thread &worker : workers) { + worker.join(); + } + } + +private: + ThreadPool() : stop(false) { + for (size_t i = 0; i < numThreads; ++i) { + workers.emplace_back([this] { + while (true) { + std::function task; + { + std::unique_lock lock(queueMutex); + condition.wait(lock, [this] { return stop || !tasks.empty(); }); + if (stop && tasks.empty()) { return; } + task = std::move(tasks.front()); + tasks.pop(); + } + task(); + } + }); + } + } + + static constexpr size_t numThreads = 1; + std::vector workers; + std::queue> tasks; + + std::mutex queueMutex; + std::condition_variable condition; + bool stop; +}; + +} // namespace xft \ No newline at end of file