4 changes: 2 additions & 2 deletions cmake/xdnn.cmake
@@ -26,8 +26,8 @@ include(ExternalProject)

# cmake-format: off
ExternalProject_Add(xdnn_lib
URL https://github.com/intel/xFasterTransformer/releases/download/IntrinsicGemm/xdnn_v1.5.9.tar.gz
URL_HASH MD5=3aa9cd15df3eb2a7a1c178f3edcf9d37
URL https://github.com/intel/xFasterTransformer/releases/download/IntrinsicGemm/xdnn_v1.6.1.tar.gz
URL_HASH MD5=309dcb57065642bd16e7d1e0863e4cdb
TIMEOUT 120
SOURCE_DIR ${CMAKE_SOURCE_DIR}/3rdparty/xdnn
CONFIGURE_COMMAND ""
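For context on the version bump: the `_MA` batch kernels called from matmul_helper.h further down are presumably what the xdnn v1.6.1 drop adds. A declaration sketch inferred from those call sites; the parameter names and exact types are assumptions, not the library's documented API:

// Inferred from the call sites below. XDNN_BF16 / XDNN_E4M3 are xdnn's bf16
// and fp8-e4m3 element types; "MA" = a per-row alpha vector per batch entry.
void xdnn_small_amx_sgemm_bf16f8bf16_compute_batch_AM_MA(
        int M, int N, const int *K, const XDNN_BF16 **A, const int *ldaList,
        const XDNN_E4M3 **packedBBatch, XDNN_BF16 *C, int ldc,
        const float **scalesList, const int *ldsList, int blockSize,
        const float **alphaList, int batchSize);

void xdnn_small_amx_sgemm_bf16f8bf16_compute_batch_A_MA(
        int M, int N, int K, const XDNN_BF16 **A, const int *ldaList,
        const XDNN_E4M3 **packedBBatch, XDNN_BF16 *C, int ldc,
        const float **scalesList, const int *ldsList, int blockSize,
        const float **alphaList, int batchSize);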
78 changes: 65 additions & 13 deletions src/layers/moe_deepseek.h
@@ -24,11 +24,14 @@
#include "mlp_llama.h"
#include "type_selector.h"
#include "timeline.h"
#include "numa_allocator.h"

template <typename WeiT, typename InT = bfloat16_t, typename ImT = bfloat16_t, typename OutT = bfloat16_t>
class DeepSeekMoE {
public:
DeepSeekMoE(int layerId, DecoderContext *ctx) : layerId(layerId), norm(ctx) {
// tensor-parallel mode: set NUMA affinity for this split
if (ctx->numSplit > 1) xft_set_preferred_node(ctx->splitIdx);
// dense mlp, or all shared experts concatenated
shared_expert = new LlamaMLP<WeiT, InT, ImT, OutT>(layerId, ctx);
if (layerId >= ctx->firstKDenseReplace) {
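The constructor change above biases allocations toward the split's NUMA node when tensor parallelism is active. A minimal sketch of what such a preferred-node hint typically amounts to, using libnuma directly; `xft_set_preferred_node` is the in-tree wrapper from numa_allocator.h, and this standalone version is an assumption about its behavior:

#include <numa.h>  // link with -lnuma

// Bias subsequent heap allocations toward the NUMA node owning this TP
// split, so expert weights land near the cores that will read them.
void setPreferredNodeForSplit(int splitIdx) {
    if (numa_available() < 0) return;              // host has no NUMA support
    int node = splitIdx % (numa_max_node() + 1);   // clamp to a valid node id
    numa_set_preferred(node);                      // a hint, not a hard bind
}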
@@ -289,7 +292,8 @@ class DeepSeekMoE {

// Call forward function of selected experts
// expert-wise for large M or bf16 for now
if (M > 128 || std::is_same_v<WeiT, bfloat16_t> || Env::getInstance().getMoEEngine() == 0) {
if (M * topkExpert / expertNum > 2 || std::is_same_v<WeiT, bfloat16_t> || (M != 1 && Env::getInstance().getMoEEngine() == 0)) {
//if (M * std::min(topkExpert, expertNum) / expertNum > 2 || std::is_same_v<WeiT, bfloat16_t> || Env::getInstance().getMoEEngine() == 0) {
// 5. Reorder the input and weight for each expert
std::vector<int> idx[expertNum]; // index for each expert
std::vector<float> weights[expertNum]; // weight for each expert
@@ -328,22 +332,64 @@
// Scatter output of expert i (critical section)
scatterOutput(output, oStride, expertData, hiddenSize, idx[i], weights[i]);
}
} else if (Env::getInstance().getMoEEngine() == 1) {
} else if (M == 1 || Env::getInstance().getMoEEngine() == 1) {
// call sparse mlp for each token
for (int i = 0; i < M; ++i) {
TimeLine t("MoE_TokenSparseFW");
OutT *tokenData = ctx->getBuffer<OutT>("tokenData", 1 * hiddenSize, ctx->device);
int nExperts = 0;
for (int j = 0; j < topkExpert; ++j) {
if (selExperts[i * topkExpert + j] < 0) break;
++nExperts;
}
if (nExperts == 0) continue;
sparseForward(ctx, normBuf + i * normStride, selExperts + i * topkExpert, expertWeight + i * topkExpert,
if (nExperts == 0) continue;
OutT *tokenData = ctx->getBuffer<OutT>("tokenData", 1 * hiddenSize, ctx->device);
sparseForward(ctx, normBuf + i * normStride, selExperts + i * topkExpert, expertWeight + i * topkExpert, 1,
nExperts, tokenData, hiddenSize, output + i * oStride, oStride);
}
} else if (Env::getInstance().getMoEEngine() == 2) {
// token-batched for small expertNum and medium M
//const char* env_toks_reorder = std::getenv("TOKS_REORDER");
//const char* env_zero_cpy = std::getenv("CPU_ZERO_COPY");
//if ((env_toks_reorder && std::strcmp(env_toks_reorder, "0") == 0) || (env_zero_cpy && std::strcmp(env_zero_cpy, "0") == 0)) {
//xft::Logger::warning("Unsupported MoE engine 2: TOKS_REORDER should be set as 1");
//exit(-1);
//}

// collect all selected experts without duplicates
std::vector<int> uniqueExperts;
std::unordered_map<int, int> expertIdMap;
for (int i = 0; i < M; ++i) {
for (int j = 0; j < topkExpert; ++j) {
int eid = selExperts[i * topkExpert + j];
if (eid < 0) break;
if (expertIdMap.find(eid) == expertIdMap.end()) {
expertIdMap[eid] = uniqueExperts.size();
uniqueExperts.push_back(eid);
}
}
}

// build the nExpertsPadded x M weight matrix; entries for experts a token did not select stay 0
int nExpertsPadded = uniqueExperts.size();
int *selExpertsPadded = uniqueExperts.data();
float *expertWeightPadded = ctx->getBuffer<float>("expertWeightPadded", nExpertsPadded * M, ctx->device);
memset(expertWeightPadded, 0, nExpertsPadded * M * sizeof(float));
for (int i = 0; i < M; ++i) {
for (int j = 0; j < std::min(topkExpert, expertNum); ++j) {
// Fill selected expert weights
int eid = selExperts[i * topkExpert + j];
if (eid < 0) break;
int idx = expertIdMap[eid];
expertWeightPadded[idx * M + i] = expertWeight[i * topkExpert + j];
//expertWeightPadded[i * nExpertsPadded + idx] = expertWeight[i * topkExpert + j];
}
}

OutT *tokenData = ctx->getBuffer<OutT>("tokenData", M * hiddenSize, ctx->device);
sparseForward(ctx, normBuf, selExpertsPadded, expertWeightPadded, M, nExpertsPadded, tokenData, hiddenSize,
output, oStride);
} else {
xft::Logger::error("Unsupported MoE engine: %d", Env::getInstance().getMoEEngine());
xft::Logger::error("Unsupported MoE engine");
exit(-1);
}
#ifdef XFT_DEBUG
@@ -535,7 +581,7 @@ class DeepSeekMoE {
}
}

void sparseForward(DecoderContext *ctx, ImT *input, int *selExperts, float *expertWeight, int nExperts, OutT *tokenData,
void sparseForward(DecoderContext *ctx, ImT *input, int *selExperts, float *expertWeight, int M, int nExperts, OutT *tokenData,
int hiddenSize, OutT *output, int oStride) {
const WeiT *weightsGUList[nExperts];
const WeiT *weightsDList[nExperts];
@@ -546,11 +592,9 @@
int blockSize = 128;
float alpha[nExperts];
OutT *imOuts[nExperts];
float *expertWeightPadded[nExperts];
int N1[nExperts], ldc1[nExperts], K2[nExperts];

// just for 1 token
int M = 1;

int K1 = hiddenSize;
int lda1 = hiddenSize;
int N2 = hiddenSize;
@@ -572,6 +616,7 @@
weightsDList[i] = this->experts[selExperts[i]]->downWeight.Data();
ldaDScales[i] = (K2[i] + blockSize - 1) / blockSize;
scalesDList[i] = this->experts[selExperts[i]]->downScales.Data();
expertWeightPadded[i] = expertWeight + i * M;
}
}

@@ -626,14 +671,21 @@
TimeLine t("SparseFW_Down");
if (Env::getInstance().getMoESplitBalanceDim() == 1) {
// For sparse mlp, we use compute_batch_AM to compute experts in different dimensions
ctx->mmHelper->compute_batch_AM(M, N2, K2, expertWeight, (const bfloat16_t**)imOuts, lda2, weightsDList,
ctx->mmHelper->compute_batch_AM_MA(M, N2, K2, expertWeightPadded, (const bfloat16_t**)imOuts, lda2, weightsDList,
scalesDList, tokenData, ldc2, ldaDScales, blockSize, nExperts);
//ctx->mmHelper->compute_batch_AM(M, N2, K2, expertWeight, (const bfloat16_t**)imOuts, lda2, weightsDList,
// scalesDList, tokenData, ldc2, ldaDScales, blockSize, nExperts);
} else {
// For sparse mlp with concat experts, we use compute_batch_A to compute experts in the same dimension
ctx->mmHelper->compute_batch_A(M, N2, K2[0], expertWeight, (const bfloat16_t**)imOuts, lda2,
//ctx->mmHelper->compute_batch_A(M, N2, K2[0], expertWeight, (const bfloat16_t**)imOuts, lda2,
// weightsDList, scalesDList, tokenData, ldc2, ldaDScales, blockSize, nExperts);
ctx->mmHelper->compute_batch_A_MA(M, N2, K2[0], expertWeightPadded, (const bfloat16_t**)imOuts, lda2,
weightsDList, scalesDList, tokenData, ldc2, ldaDScales, blockSize, nExperts);
}
xft::addto(output, tokenData, 1.0, hiddenSize);
#pragma omp parallel for
for (int i = 0; i < M; ++i) {
xft::addto(output + i * hiddenSize, tokenData + i * hiddenSize, 1.0, hiddenSize);
}
}
#ifdef XFT_DEBUG
dbg.debugPrint("tokenData (%d %d):\n", 1, hiddenSize);
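To make the new engine-2 path easier to follow in isolation, here is a self-contained sketch of its dedup-and-pad step; the function and variable names are illustrative, and the logic mirrors the diff above. Every unique expert later runs over all M tokens via the batched sparseForward call, with a zero gate weight masking tokens that were not routed to it, which is where the "redundant computing" noted in environment.h comes from:

#include <unordered_map>
#include <vector>

// Illustrative standalone version of the dedup-and-pad step added above.
// Output: uniqueExperts lists each routed expert once; weightPadded is an
// nExperts x M matrix where row e holds each token's gate weight for
// uniqueExperts[e], and 0 where the token was not routed to that expert.
static void buildPaddedWeights(const int *selExperts, const float *expertWeight,
        int M, int topkExpert,
        std::vector<int> &uniqueExperts, std::vector<float> &weightPadded) {
    std::unordered_map<int, int> expertIdMap;  // expert id -> row in the matrix
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < topkExpert; ++j) {
            int eid = selExperts[i * topkExpert + j];
            if (eid < 0) break;  // -1 terminates a token's top-k list
            if (expertIdMap.find(eid) == expertIdMap.end()) {
                expertIdMap[eid] = (int)uniqueExperts.size();
                uniqueExperts.push_back(eid);
            }
        }
    int nExperts = (int)uniqueExperts.size();
    weightPadded.assign((size_t)nExperts * M, 0.0f);
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < topkExpert; ++j) {
            int eid = selExperts[i * topkExpert + j];
            if (eid < 0) break;
            weightPadded[(size_t)expertIdMap[eid] * M + i]
                    = expertWeight[i * topkExpert + j];
        }
}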
1 change: 1 addition & 0 deletions src/utils/environment.h
@@ -369,6 +369,7 @@ class Env {
// XFT_MOE_ENGINE
// 0: batched tokens computing for each expert
// 1: batched experts computing for each token
// 2: batched experts computing for batched tokens (may include redundant computation)
int moeEngine = 1;
void initMoEEngine() {
char *xFTMoEEngineValue = getenv("XFT_MOE_ENGINE");
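Putting the three engine values together with the dispatch added in moe_deepseek.h: a paraphrase of the selection logic, not the exact code path. `isBf16Weights` stands in for the `std::is_same_v<WeiT, bfloat16_t>` check, and the default of 1 matches `moeEngine` above:

#include <cstdlib>

// Which MoE path runs for a given batch; mirrors the condition chain in
// DeepSeekMoE. M = tokens in the batch, topkExpert = experts per token.
int pickMoEPath(int M, int topkExpert, int expertNum, bool isBf16Weights) {
    const char *env = std::getenv("XFT_MOE_ENGINE");
    int engine = env ? std::atoi(env) : 1;  // default matches moeEngine = 1
    if (M * topkExpert / expertNum > 2      // roughly >2 tokens per expert
            || isBf16Weights || (M != 1 && engine == 0))
        return 0;  // expert-wise: gather each expert's tokens, GEMM per expert
    if (M == 1 || engine == 1)
        return 1;  // token-wise: batch the selected experts of a single token
    if (engine == 2)
        return 2;  // token-batched: every unique expert over all M tokens
    return -1;     // unsupported engine value
}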
36 changes: 33 additions & 3 deletions src/utils/matmul_helper.h
@@ -1721,7 +1721,7 @@ class MMHelper {
}
}

template <typename InT, typename WeiT, typename OutT>
template <typename InT, typename WeiT, typename OutT>
void compute_batch_CM(int M, int *N, int K, float *alphaList, const InT *A, int lda, const WeiT *packedBBatch[],
const float *scalesList[], OutT *CList[], int *ldcList, int *ldsList, int blockSize = 128,
int batchSize = 1) {
@@ -1736,7 +1736,8 @@
exit(-1);
}
}
template <typename InT, typename WeiT, typename OutT>

template <typename InT, typename WeiT, typename OutT>
void compute_batch_AM(int M, int N, int *K, float *alphaList, const InT *A[], int *ldaList,
const WeiT *packedBBatch[], const float *scalesList[], OutT *C, int ldc, int *ldsList, int blockSize = 128,
int batchSize = 1) {
@@ -1763,7 +1764,36 @@
GEMMVERBOSE("xdnn_small_amx_sgemm_bf16bf16bf16_compute_batch_A",
xdnn_small_amx_sgemm_bf16bf16bf16_compute_BA16a64b2a_batch_A(M, N, K, (const XDNN_BF16 **)A, ldaList, (const XDNN_BF16 **)packedBBatch,
(XDNN_BF16 *)C, ldc, alphaList, batchSize));
} else{
} else {
printf("%s:%d: Unsupported data type for compute_residential_batch_A", __FILE__, __LINE__);
exit(-1);
}
}

template <typename InT, typename WeiT, typename OutT>
void compute_batch_AM_MA(int M, int N, int *K, float *alphaList[], const InT *A[], int *ldaList,
const WeiT *packedBBatch[], const float *scalesList[], OutT *C, int ldc, int *ldsList, int blockSize = 128,
int batchSize = 1) {
if constexpr (std::is_same_v<WeiT, e4m3_t> && std::is_same_v<OutT, bfloat16_t>
&& std::is_same_v<InT, bfloat16_t>) {
GEMMVERBOSE("xdnn_small_amx_sgemm_bf16f8bf16_compute_batch_AM_MA",
xdnn_small_amx_sgemm_bf16f8bf16_compute_batch_AM_MA(M, N, K, (const XDNN_BF16 **)A, ldaList,
(const XDNN_E4M3 **)packedBBatch, (XDNN_BF16 *)C, ldc, scalesList, ldsList, blockSize,
(const float**)alphaList, batchSize));
} else {
printf("%s:%d: Unsupported data type for compute_batch_AM_MA", __FILE__, __LINE__);
exit(-1);
}
}

template <typename InT, typename WeiT, typename OutT>
void compute_batch_A_MA(int M, int N, int K, float *alphaList[], const InT *A[], int *ldaList, const WeiT *packedBBatch[],
const float *scalesList[], OutT *C, int ldc, int *ldsList, int blockSize = 128, int batchSize = 1) {
if constexpr (std::is_same_v<WeiT, e4m3_t> && std::is_same_v<OutT, bfloat16_t> && std::is_same_v<InT, bfloat16_t>) {
GEMMVERBOSE("xdnn_small_amx_sgemm_bf16f8bf16_compute_batch_A",
xdnn_small_amx_sgemm_bf16f8bf16_compute_batch_A_MA(M, N, K, (const XDNN_BF16 **)A, ldaList, (const XDNN_E4M3 **)packedBBatch,
(XDNN_BF16 *)C, ldc, scalesList, ldsList, blockSize, (const float**)alphaList, batchSize));
} else {
printf("%s:%d: Unsupported data type for compute_batch_A_MA", __FILE__, __LINE__);
exit(-1);
}
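Finally, the assumed semantics of the new *_MA variants, in contrast to the scalar-alpha compute_batch_AM above: each batch entry carries a per-row alpha vector (one gate weight per token) rather than a single scalar. A plain-float reference sketch; the real kernels run bf16/e4m3 AMX tiles with blockwise dequantization, and whether the kernel zeroes C or accumulates on top is an assumption here:

#include <cstring>

// Reference semantics: C[m, :] += alphaList[b][m] * (row m of A[b]) @ B[b]
// for every batch entry b, accumulating all experts into one output C.
void computeBatchAM_MA_Reference(int M, int N, const int *K,
        const float *const *alphaList, const float *const *A, const int *lda,
        const float *const *B, float *C, int ldc, int batchSize) {
    for (int m = 0; m < M; ++m)
        std::memset(C + m * ldc, 0, N * sizeof(float));  // assumed: kernel owns C
    for (int b = 0; b < batchSize; ++b)
        for (int m = 0; m < M; ++m) {
            if (alphaList[b][m] == 0.0f) continue;  // token not routed to expert b
            for (int n = 0; n < N; ++n) {
                float acc = 0.0f;
                for (int k = 0; k < K[b]; ++k)
                    acc += A[b][m * lda[b] + k] * B[b][k * N + n];
                C[m * ldc + n] += alphaList[b][m] * acc;
            }
        }
}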