diff --git a/src/infiniccl/moore/infiniccl_moore.cc b/src/infiniccl/moore/infiniccl_moore.cc index b58b2d63a..4dee4aebd 100644 --- a/src/infiniccl/moore/infiniccl_moore.cc +++ b/src/infiniccl/moore/infiniccl_moore.cc @@ -23,6 +23,8 @@ inline mcclDataType_t getMcclDtype(infiniDtype_t datatype) { return mcclFloat; case INFINI_DTYPE_F16: return mcclHalf; + case INFINI_DTYPE_BF16: + return mcclBfloat16; default: std::abort(); return mcclHalf; @@ -83,9 +85,7 @@ infiniStatus_t allReduce( infinicclComm_t comm, infinirtStream_t stream) { - if (datatype != INFINI_DTYPE_F32 && datatype != INFINI_DTYPE_F16) { - return INFINI_STATUS_BAD_PARAM; - } + CHECK_DTYPE(datatype, INFINI_DTYPE_F32, INFINI_DTYPE_F16, INFINI_DTYPE_BF16); CHECK_MCCL(mcclAllReduce(sendbuf, recvbuf, count, getMcclDtype(datatype), getMcclRedOp(op), getMcclComm(comm), getMusaStream(stream)));