From 66fec6209851a12c725641cf60da7806a824cea9 Mon Sep 17 00:00:00 2001 From: "deng.xiangyu" Date: Sat, 30 Aug 2025 23:36:42 +0800 Subject: [PATCH] fix cast from f32 to f16 --- src/utils/types.cpp | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/utils/types.cpp b/src/utils/types.cpp index 4163c214..b5710e45 100644 --- a/src/utils/types.cpp +++ b/src/utils/types.cpp @@ -52,8 +52,21 @@ fp16_t _f32_to_f16(float val) { // Infinity return fp16_t{static_cast(sign | 0x7C00)}; } else if (exponent >= -14) { // Normalized case - return fp16_t{(uint16_t)(sign | ((exponent + 15) << 10) | (mantissa >> 13))}; - } else if (exponent >= -24) { + // --- START MODIFICATION --- + // Add 0x1000 (2^12), which is half of the value of the 13th bit. + // This effectively rounds the 10-bit mantissa to the nearest value. + uint32_t rounded_mantissa = mantissa + 0x1000; + + // Check for overflow in the mantissa after rounding + if (rounded_mantissa & 0x800000) { + // If mantissa overflows, we need to increment the exponent + // and reset mantissa. This is rare but important for correctness. + return fp16_t{(uint16_t)(sign | ((exponent + 15 + 1) << 10))}; + } + + return fp16_t{(uint16_t)(sign | ((exponent + 15) << 10) | (rounded_mantissa >> 13))}; + // --- END MODIFICATION --- +} else if (exponent >= -24) { mantissa |= 0x800000; // Add implicit leading 1 mantissa >>= (-14 - exponent); return fp16_t{(uint16_t)(sign | (mantissa >> 13))};