From 8de11f9a3e658f2f960483bc89c75b4f0c6c1b70 Mon Sep 17 00:00:00 2001 From: Shantanu Gontia Date: Sun, 9 Mar 2025 14:23:16 -0700 Subject: [PATCH 1/4] Add Float8E4M3 Code --- pkg/float8e4m3bits/float8e4m3.go | 148 +++++++++++++++++++++ pkg/float8e4m3bits/float8e4m3_constants.go | 26 ++++ pkg/float8e4m3bits/float8e4m3_test.go | 44 ++++++ pkg/float8e4m3bits/rounddown.go | 1 + pkg/float8e4m3bits/roundhalfdown.go | 1 + pkg/float8e4m3bits/roundhalftowardszero.go | 1 + pkg/float8e4m3bits/roundhalfup.go | 2 + pkg/float8e4m3bits/roundnearesteven.go | 1 + pkg/float8e4m3bits/roundnearestodd.go | 1 + pkg/float8e4m3bits/roundtowardszero.go | 1 + pkg/float8e4m3bits/roundup.go | 1 + 11 files changed, 227 insertions(+) create mode 100644 pkg/float8e4m3bits/float8e4m3.go create mode 100644 pkg/float8e4m3bits/float8e4m3_constants.go create mode 100644 pkg/float8e4m3bits/float8e4m3_test.go create mode 100644 pkg/float8e4m3bits/rounddown.go create mode 100644 pkg/float8e4m3bits/roundhalfdown.go create mode 100644 pkg/float8e4m3bits/roundhalftowardszero.go create mode 100644 pkg/float8e4m3bits/roundhalfup.go create mode 100644 pkg/float8e4m3bits/roundnearesteven.go create mode 100644 pkg/float8e4m3bits/roundnearestodd.go create mode 100644 pkg/float8e4m3bits/roundtowardszero.go create mode 100644 pkg/float8e4m3bits/roundup.go diff --git a/pkg/float8e4m3bits/float8e4m3.go b/pkg/float8e4m3bits/float8e4m3.go new file mode 100644 index 0000000..0407c9f --- /dev/null +++ b/pkg/float8e4m3bits/float8e4m3.go @@ -0,0 +1,148 @@ +package F8E4M3 + +import ( + "math" + "math/big" + floatBit "github.com/shantanu-gontia/float-conv/pkg" + F32 "github.com/shantanu-gontia/float-conv/pkg/float32bits" +) + +// Some constants that will help with bit manipulation we'll need to +// perform +const ( + // Mantissa bits retained in float8e4m3 + f32Float8E4M3MantissaMask uint32 = 0b0_00000000_11100000000000000000000 + // Mantissa bits not retained in float8e4m3 + f32Float8E4M3HalfSubnormalMask uint32 = + 0b0_00000000_00011111111111111111111 + // LSB of float8e4m3 and the rest of the extra precision + f32Float8E4M3SubnormalMask uint32 = 0b0_00000000_00111111111111111111111 + // LSB of float8e4m3 + f32Float8E4M3SubnormalLSB uint32 = 0b0_00000000_00100000000000000000000 + // Most significant bit not retained in float8e4m3 + f32Float8E4M3HalfSubnormalLSB uint32 = 0b0_00000000_00010000000000000000000 +) + +// Alias type for uint8. This is used to represent the bits that make up a +// Float8E4M3 number. This type also comes with utility methods to support +// Floating point conversions with different Rounding Modes and Out of Bounds +// responses +type Bits uint8 + +// Alias type for uint8. This is used to represent the FP8E8M0 format which +// is used to specify the scale factor for a OCP MXFP8 floating point format. +// The actual floating point represented by an OCP MXPF8 number is the +// value encoded in the bits scaled by 2^(scale_factor) when scale_factor +// is interpreted as an fp8e8m0 number. For example, to apply no scaling, +// the number passed must be 127, because the actual scale that is +// multiplied is 2^(scale_factor - 127). If the exponent in the resulting +// number is out of the range supported by the format, then the result is +// undefined per the spec. In our case, we clamp to Infinity of the same sign +// as the original number +type ScaleFactor int8 + +// Convert the given [Bits] type to the floating point number it represents, +// inside a float32 value. This is effectively, a bit_cast to float8e4m3, +// followed by a upcast to float32. Since Go doesn't natively support +// float8e4m3 values, this method performs some bit-twiddling, +// to align the bits per the float32 bit representation and then scaling +// the final result with the [ScaleFactor] +func (input Bits) ToFloat32(scaleFactor ScaleFactor) float32 { + asUint8 := uint8(input) + // Extract the Sign, Exponent and Mantissa + signBit := (asUint8 & SignMask) >> 7 + exponentBits := (asUint8 & ExponentMask) >> 3 + mantissaBits := (asUint8 & MantissaMask) + + // Special Values like Inf, NaN etc. need to be handled before applying + // the general algorithm to calculate the number + if asUint8 == PositiveNaN { + return math.Float32frombits(F32.PositiveNaN) + } + if asUint8 == NegativeNaN { + return math.Float32frombits(F32.NegativeNaN) + } + if asUint8 == PositiveZero { + return math.Float32frombits(F32.PositiveZero) + } + if asUint8 == NegativeZero { + return math.Float32frombits(F32.NegativeZero) + } + + // Variables to store the sign, exponent and mantissa bits that will be + // used to construct the float32 number + var float32SignBit, float32ExponentBits, float32MantissaBits uint32 + + float32SignBit = uint32(signBit) << 31 + float32Exponent := 0 + + if exponentBits == 0 { + // Subnormals in F8E4M3, are normals in F32. So, they need special + // handling. If the float8e4m3 number is subnornmal, then it is + // evaluated as (assuming default scale-factor) + // (-1)^sign * 2^(1-7) * (m0/2 + m1/4 + m2/8) + // = (-1)^sign * 2^-6 * (m0/2 + m1/4 + m2/8) + // So, what we want to do is , find the first bit in the mantissa that + // is 1. This will become the implicit precision bit in the float32 + // value. Let's suppose it is m0. In that case, we have + // (-1)^sign * 2^-6 * (1/2 + m1/4 + m2/8) + // = (-1)^sign * 2^-7 * (1 + m1/2 + m2/4) + // So, if the MSB set bit is m0, then the result exponent = Emin - 1 + // and, we need to shift the mantissa to the right when it's in the + // float32 container. And, there is an extra mantissa left-shift by 1 + // Suppose now it's m2. In that case, we have + // (-1)^sign * 2^(-6) * (0/2 + 0/4 + 1/8) + // = (-1)^sign * 2^(-9) * ([m2=1]) + // So, the result exponent is Emin - 3 and, in the float32, the + // mantissa is shifted to the left by 3 bits. + + // We find the value we need to subtract from the minimum exponent + // F32 mantissa starts at the bit index 31 and ends at 23 (inclusive) + // 31 | 30 29 28 27 26 25 24 23 | 22 ... + currMantissaBitMask := uint8(0b0_0000_100) + resultMantissaBits := mantissaBits + resultExponent := ExponentMin + extraShift := 0 + for ; currMantissaBitMask != 0; currMantissaBitMask >>= 1 { + currMantissaBit := currMantissaBitMask & mantissaBits + resultExponent -= 1 + extraShift++ + if currMantissaBit != 0 { + // We need to zero out this one bit, since this is what + // becomes the implicit bit in the float32 + resultMantissaBits = mantissaBits & ^currMantissaBitMask + break + } + } + + // F32 has 23 mantissa bits, and F8E4M3 has 3. Therefore, to align the + // bits, we need to shift to the left by 20 bits. To account for the + // right shift (since in float32 the first set bit is implicit), + // we can actually just shift left by 19 bits instead of 20. + float32MantissaBits = uint32(resultMantissaBits) << (20 + extraShift) + float32Exponent = resultExponent + } else { + // For the normal case, all we need to do is correct the exponent to + // use the bias of the float32 format + float32MantissaBits = uint32(mantissaBits) << 20 + actualExponent := int(exponentBits) - ExponentBias + float32Exponent = actualExponent + } + + // Apply the Scale Factor + float32Exponent = float32Exponent + (int(scaleFactor) - 127) + if (float32Exponent > F32.ExponentMax) { + return math.Float32frombits(float32SignBit | F32.PositiveInfinity) + } else if (float32Exponent < F32.ExponentMin) { + return math.Float32frombits(float32SignBit | F32.PositiveZero) + } + + float32ExponentBits = uint32(float32Exponent + F32.ExponentBias) << 23 + return math.Float32frombits(float32SignBit | float32ExponentBits | + float32MantissaBits) +} + +func handleOverflow(signBit uint32, om floatBit.OverflowMode) (Bits, + big.Accuracy, floatBit.Status) { + return Bits(0), big.Exact, floatBit.Fits +} diff --git a/pkg/float8e4m3bits/float8e4m3_constants.go b/pkg/float8e4m3bits/float8e4m3_constants.go new file mode 100644 index 0000000..70b4ffd --- /dev/null +++ b/pkg/float8e4m3bits/float8e4m3_constants.go @@ -0,0 +1,26 @@ +package F8E4M3 + +const ( + SignMask uint8 = 0b1_0000_000 + ExponentMask uint8 = 0b0_1111_000 + MantissaMask uint8 = 0b0_0000_111 + + // Float8E4M3 doesn't support Infinities + + PositiveMaxNormal uint8 = 0b0_1111_110 + NegativeMaxNormal uint8 = 0b1_1111_110 + + PositiveZero uint8 = 0b0_0000_000 + NegativeZero uint8 = 0b1_0000_000 + + PositiveMinSubnormal uint8 = 0b0_0000_001 + NegativeMinSubnormal uint8 = 0b1_0000_001 + + NaN uint8 = 0b0_1111_111 + PositiveNaN uint8 = 0b0_1111_111 + NegativeNaN uint8 = 0b1_1111_111 + + ExponentBias int = 7 + ExponentMin int = -6 + ExponentMax int = 8 +) diff --git a/pkg/float8e4m3bits/float8e4m3_test.go b/pkg/float8e4m3bits/float8e4m3_test.go new file mode 100644 index 0000000..24b739a --- /dev/null +++ b/pkg/float8e4m3bits/float8e4m3_test.go @@ -0,0 +1,44 @@ +package F8E4M3 + +import ( + "testing" + "math" + F32 "github.com/shantanu-gontia/float-conv/pkg/float32bits" +) + +func TestToFloat32(t *testing.T) { + testCases := []struct{ + // Input + input Bits + scaleFactor int8 + // Output + golden float32 + }{ + { + input: 0b0_0000_000, + scaleFactor: 127, + golden: math.Float32frombits(F32.PositiveZero), + }, + { + input: 0b1_0000_000, + scaleFactor: 127, + golden: math.Float32frombits(F32.NegativeZero), + }, + { + input: 0b1_0000_000, + scaleFactor: 254, + golden: math.Float32frombits(F32.NegativeZero), + } + } + + for _, tt := range testCases { + t.Run("ToFloat32", func(t* testing.T) { + result := tt.input.ToFloat32() + if result != tt.golden { + t.Logf("Failed Input Set:\n") + t.Logf("Input: %0#16b (%0#4x)", tt.input, tt.input) + t.Errorf("Expected Output: %f (%0#8x). Got: %f (%0#8x)", tt.golden, math.Float32bits(tt.golden), result, math.Float32bits(result)) + } + }) + } +} diff --git a/pkg/float8e4m3bits/rounddown.go b/pkg/float8e4m3bits/rounddown.go new file mode 100644 index 0000000..f7a0f50 --- /dev/null +++ b/pkg/float8e4m3bits/rounddown.go @@ -0,0 +1 @@ +package F8E4M3 diff --git a/pkg/float8e4m3bits/roundhalfdown.go b/pkg/float8e4m3bits/roundhalfdown.go new file mode 100644 index 0000000..f7a0f50 --- /dev/null +++ b/pkg/float8e4m3bits/roundhalfdown.go @@ -0,0 +1 @@ +package F8E4M3 diff --git a/pkg/float8e4m3bits/roundhalftowardszero.go b/pkg/float8e4m3bits/roundhalftowardszero.go new file mode 100644 index 0000000..f7a0f50 --- /dev/null +++ b/pkg/float8e4m3bits/roundhalftowardszero.go @@ -0,0 +1 @@ +package F8E4M3 diff --git a/pkg/float8e4m3bits/roundhalfup.go b/pkg/float8e4m3bits/roundhalfup.go new file mode 100644 index 0000000..cf9475b --- /dev/null +++ b/pkg/float8e4m3bits/roundhalfup.go @@ -0,0 +1,2 @@ +package F8E4M3 + diff --git a/pkg/float8e4m3bits/roundnearesteven.go b/pkg/float8e4m3bits/roundnearesteven.go new file mode 100644 index 0000000..f7a0f50 --- /dev/null +++ b/pkg/float8e4m3bits/roundnearesteven.go @@ -0,0 +1 @@ +package F8E4M3 diff --git a/pkg/float8e4m3bits/roundnearestodd.go b/pkg/float8e4m3bits/roundnearestodd.go new file mode 100644 index 0000000..f7a0f50 --- /dev/null +++ b/pkg/float8e4m3bits/roundnearestodd.go @@ -0,0 +1 @@ +package F8E4M3 diff --git a/pkg/float8e4m3bits/roundtowardszero.go b/pkg/float8e4m3bits/roundtowardszero.go new file mode 100644 index 0000000..f7a0f50 --- /dev/null +++ b/pkg/float8e4m3bits/roundtowardszero.go @@ -0,0 +1 @@ +package F8E4M3 diff --git a/pkg/float8e4m3bits/roundup.go b/pkg/float8e4m3bits/roundup.go new file mode 100644 index 0000000..f7a0f50 --- /dev/null +++ b/pkg/float8e4m3bits/roundup.go @@ -0,0 +1 @@ +package F8E4M3 From f4db2589f8ad64e956ba449a93b13b6423b98e6e Mon Sep 17 00:00:00 2001 From: Shantanu Gontia Date: Sun, 9 Mar 2025 16:32:39 -0700 Subject: [PATCH 2/4] Float16 Typo --- pkg/float16bits/float16.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/float16bits/float16.go b/pkg/float16bits/float16.go index d21c967..5551eea 100644 --- a/pkg/float16bits/float16.go +++ b/pkg/float16bits/float16.go @@ -85,7 +85,7 @@ func (input Bits) ToFloat32() float32 { // = (-1)^sign * 2^(-15) * (1 + m1/2 + m2/4 + m3/8 + ...) // So, if the MSB set bit is m0, then the result exponent = Emin - 1 // and, we need to shift the mantissa to the right when it's in the - // float32 container. And, there as an extra mantissa left-shift by + // float32 container. And, there is an extra mantissa left-shift by // 1 // Let's now say it's m2. In that case, we have // (-1)^sign * 2^(-14) * (0/2 + 0/4 + 1/8 + m3/16 + ...) From 83e9c93a7c3dcdf776b2c37df5e7c7e06370b82c Mon Sep 17 00:00:00 2001 From: Shantanu Gontia Date: Fri, 14 Mar 2025 03:02:15 -0700 Subject: [PATCH 3/4] Add Rounding Modes for FP8E4M3, fix some F16 comments --- pkg/float16bits/float16.go | 14 +- pkg/float16bits/float16_test.go | 10 +- pkg/float16bits/rounddown.go | 9 +- pkg/float16bits/roundhalftowardszero.go | 2 +- pkg/float16bits/roundnearesteven.go | 12 +- pkg/float16bits/roundnearestodd.go | 14 +- pkg/float16bits/roundtowardszero.go | 2 +- pkg/float8e4m3bits/float8e4m3.go | 182 ++- pkg/float8e4m3bits/float8e4m3_test.go | 1377 +++++++++++++++++++- pkg/float8e4m3bits/rounddown.go | 54 + pkg/float8e4m3bits/roundhalfdown.go | 65 + pkg/float8e4m3bits/roundhalftowardszero.go | 58 + pkg/float8e4m3bits/roundhalfup.go | 65 + pkg/float8e4m3bits/roundnearesteven.go | 79 ++ pkg/float8e4m3bits/roundnearestodd.go | 80 ++ pkg/float8e4m3bits/roundtowardszero.go | 39 + pkg/float8e4m3bits/roundup.go | 53 + 17 files changed, 2068 insertions(+), 47 deletions(-) diff --git a/pkg/float16bits/float16.go b/pkg/float16bits/float16.go index 5551eea..61fce62 100644 --- a/pkg/float16bits/float16.go +++ b/pkg/float16bits/float16.go @@ -181,7 +181,7 @@ func FromBigFloat(input big.Float, rm floatBit.RoundingMode, // the underflow response in the BF16 methods. // F32.PositiveMinSubnormal will trigger underflow response in BF16 asFloat32 = math.Float32frombits(F32.PositiveMinSubnormal) - } else if closestFloat32 == -0.0 && fromBigFloatAcc == big.Above { + } else if closestFloat32 == math.Float32frombits(F32.NegativeZero) && fromBigFloatAcc == big.Above { // And for the negative case // F32.NegativeMinSubnormal will trigger underflow response in BF16 asFloat32 = math.Float32frombits(F32.NegativeMinSubnormal) @@ -386,9 +386,7 @@ func FromFloat32(input float32, rm floatBit.RoundingMode, // Utility function to check if the number with the given exponent and mantissa // bits would overflow when trying to represent it in a float16 value -// exponentBits should correspond to bits which are encoded with the float16 -// bias in mind. mantissaBits should occupy the bits with the float32 format -// in mind. +// mantissaBits should occupy the bits as they would in a float32 number func checkOverflow(actualExponent int, mantissaBits uint32) bool { // If the exponent is larger than the max, then it's overflow if actualExponent > ExponentMax { @@ -408,11 +406,11 @@ func checkOverflow(actualExponent int, mantissaBits uint32) bool { } // Utility function to check if the number with the given exponent and mantissa -// bits would overflow when trying to represent it in a float16 value +// bits would underflow when trying to represent it in a float16 value // Subnormals require shifting the mantissa to align the exponents. This might // cause loss of precision that cannot be detected by mantissaBits alone as // they are already shifted. The lostPrecision parameter helps us with that. -// If it's true then there was precision lost when mantissa was being aligned +// If it's true then there was precision lost when mantissa was being aligned. func checkUnderflow(mantissaBits uint32, lostPrecision bool) bool { // This assumes that the exponent is 0, so any extra precision in the // mantissa means underflow. @@ -496,7 +494,7 @@ func (b *Bits) ToFloatFormat() floatBit.FloatBitFormat { // 5 Exponent Bits exponentRetVal := make([]byte, 0, 5) - for i := 0; i < 5; i++ { + for range 5 { currentExponentBit := exponentBits & 0x1 var valueToAppend byte if currentExponentBit == 0 { @@ -510,7 +508,7 @@ func (b *Bits) ToFloatFormat() floatBit.FloatBitFormat { // 10 Mantissa Bits mantissaRetVal := make([]byte, 0, 10) - for i := 0; i < 10; i++ { + for range 10 { currentMantissaBit := mantissaBits & 0x1 var valueToAppend byte if currentMantissaBit == 0 { diff --git a/pkg/float16bits/float16_test.go b/pkg/float16bits/float16_test.go index 2e1e48d..85786e3 100644 --- a/pkg/float16bits/float16_test.go +++ b/pkg/float16bits/float16_test.go @@ -38,7 +38,7 @@ func TestToFloat32(t *testing.T) { }, { input: 0b1_00000_0000000000, - golden: -0.0, + golden: math.Float32frombits(F32.NegativeZero), }, { input: 0b0_00000_1111111111, @@ -400,8 +400,8 @@ func TestRoundTowardsPositiveInf(t *testing.T) { func TestRoundTowardsNegativeInf(t *testing.T) { - // Rounding towards positive infinity involves adding 1 if the number - // is positive, otherwise truncating, so that the number is closer to +inf + // Rounding towards negative infinity involves adding 1 if the number + // is negative, otherwise truncating, so that the number is closer to -inf testCases := []struct { // Inputs @@ -1021,7 +1021,7 @@ func TestRoundHalfTowardsNegativeInf(t *testing.T) { func TestRoundNearestEven(t *testing.T) { // Rounding half towards positive infinity involves rounding to the nearest // representable number, and breaking ties by rounding towards the - // number closer to +inf + // number with LSB = 0 testCases := []struct { // Inputs @@ -1240,7 +1240,7 @@ func TestRoundNearestEven(t *testing.T) { func TestRoundNearestOdd(t *testing.T) { // Rounding half towards positive infinity involves rounding to the nearest // representable number, and breaking ties by rounding towards the - // number closer to +inf + // number with LSB = 1 testCases := []struct { // Inputs diff --git a/pkg/float16bits/rounddown.go b/pkg/float16bits/rounddown.go index 295e76c..967c03b 100644 --- a/pkg/float16bits/rounddown.go +++ b/pkg/float16bits/rounddown.go @@ -28,12 +28,12 @@ func roundDown(signBit, exponentBits, float16Mantissa := uint16(mantissaF16Precision >> 13) // For this rounding mode, we only need to add 1 to the Least-precision - // mantissa, if the input was positive, to bring it closer to +inf. - // For negative numbers, this is achieved by simply truncating. + // mantissa, if the input was negative, to bring it closer to -inf. + // For positive numbers, this is achieved by simply truncating. exponentMantissaComposite := (float16Exponent | float16Mantissa) - // If positive and there is extra precision, then add 1 + // If negative and there is extra precision, then add 1 if (float16Sign != 0) && (mantissaExtraPrecision != 0 || lostPrecision) { exponentMantissaComposite += 1 } @@ -41,7 +41,8 @@ func roundDown(signBit, exponentBits, resultVal := Bits(float16Sign | exponentMantissaComposite) resultAcc := big.Exact - // If there was extra precision bits set, then we need to + // If there was extra precision bits set, then we need to update the + // accuracy if mantissaExtraPrecision != 0 || lostPrecision { // We always round to a smaller value resultAcc = big.Below diff --git a/pkg/float16bits/roundhalftowardszero.go b/pkg/float16bits/roundhalftowardszero.go index 066a53a..7cf95e4 100644 --- a/pkg/float16bits/roundhalftowardszero.go +++ b/pkg/float16bits/roundhalftowardszero.go @@ -24,7 +24,7 @@ func roundHalfTowardsZero(signBit, exponentBits, exponentMantissaComposite := float16Exponent | float16Mantissa // If the extra precision bits exceed 1 0 0 0 0.... - // we need to add 1 to LSB of F32 mantissa, otherwise truncate + // we need to add 1 to LSB of F32 mantissa. // For all other cases we truncate addedOne := false if mantissaExtraPrecision > f32Float16HalfSubnormalLSB { diff --git a/pkg/float16bits/roundnearesteven.go b/pkg/float16bits/roundnearesteven.go index 62dad33..6431e96 100644 --- a/pkg/float16bits/roundnearesteven.go +++ b/pkg/float16bits/roundnearesteven.go @@ -19,12 +19,12 @@ func roundNearestEven(signBit, exponentBits, mantissaBits uint32, // break ties by rounding towards the number that is even (LSB is 0) // LSB | Extra Precision Bits - // m9 m8 m8 m7 - // 1. if m29 m28 m27 ... > 1 0 0 0 ... (more than half) we round up - // 2. if m29 m28 m27 ... < 1 0 0 0 ... (less than half) we truncate - // 3. if m29 m28 m27 ... == 1 0 0 0 ... (exactly half), then - // 3.1 m30 == 0, we truncate - // 3.2 m30 == 1, we round up + // m13 m12 m11 m10 ... m0 + // 1. if m12 m11 m10 ... > 1 0 0 0 ... (more than half) we round up + // 2. if m12 m11 m10 ... < 1 0 0 0 ... (less than half) we truncate + // 3. if m12 m11 m10 ... == 1 0 0 0 ... (exactly half), then + // 3.1 m13 == 0, we truncate + // 3.2 m13 == 1, we round up mantissaF16Precision := mantissaBits & f32Float16MantissaMask mantissaExtraPrecision := mantissaBits & f32Float16HalfSubnormalMask diff --git a/pkg/float16bits/roundnearestodd.go b/pkg/float16bits/roundnearestodd.go index 0e35ecc..8357bfb 100644 --- a/pkg/float16bits/roundnearestodd.go +++ b/pkg/float16bits/roundnearestodd.go @@ -22,12 +22,12 @@ func roundNearestOdd(signBit, exponentBits, mantissaBits uint32, // break ties by rounding towards the number that is even (LSB is 0) // LSB | Extra Precision Bits - // m9 m8 m7 m6 - // 1. if m8 m7 m6 ... > 1 0 0 0 ... (more than half) we round up - // 2. if m8 m7 m6 ... < 1 0 0 0 ... (less than half) we truncate - // 3. if m8 m7 m6 ... == 1 0 0 0 ... (exactly half), then - // 3.1 m9 == 1, we truncate - // 3.2 m9 == 0, we round up + // m13 m12 m11 m10 ... m0 + // 1. if m12 m11 m10 ... > 1 0 0 0 ... (more than half) we round up + // 2. if m12 m11 m10 ... < 1 0 0 0 ... (less than half) we truncate + // 3. if m12 m11 m10 ... == 1 0 0 0 ... (exactly half), then + // 3.1 m13 == 1, we truncate + // 3.2 m13 == 0, we round up mantissaF16Precision := mantissaBits & f32Float16MantissaMask mantissaExtraPrecision := mantissaBits & f32Float16HalfSubnormalMask @@ -55,7 +55,7 @@ func roundNearestOdd(signBit, exponentBits, mantissaBits uint32, mantissaF32LSB := mantissaBits & f32Float16SubnormalLSB // In the case we're at the mid-point, we only add 1, if the LSB of the - // float32 retained mantissa is 0 + // float16 retained mantissa is 0 if (mantissaF32LSB == 0) && (mantissaExtraPrecision == f32Float16HalfSubnormalLSB) && !lostPrecision { exponentMantissaComposite += 1 diff --git a/pkg/float16bits/roundtowardszero.go b/pkg/float16bits/roundtowardszero.go index bd49122..97645c1 100644 --- a/pkg/float16bits/roundtowardszero.go +++ b/pkg/float16bits/roundtowardszero.go @@ -31,7 +31,7 @@ func truncate(signBit, exponentBits, mantissaBits uint32, resultAcc := big.Exact // If there was extra precision, then the number did not fit in the - // float32 format, so we need to report the status appropriately + // float16 format, so we need to report the status appropriately if mantissaExtraPrecision != 0 || lostPrecision { if signBit == 0 { resultAcc = big.Below diff --git a/pkg/float8e4m3bits/float8e4m3.go b/pkg/float8e4m3bits/float8e4m3.go index 0407c9f..5887961 100644 --- a/pkg/float8e4m3bits/float8e4m3.go +++ b/pkg/float8e4m3bits/float8e4m3.go @@ -1,10 +1,13 @@ package F8E4M3 import ( - "math" - "math/big" - floatBit "github.com/shantanu-gontia/float-conv/pkg" - F32 "github.com/shantanu-gontia/float-conv/pkg/float32bits" + "errors" + "math" + "math/big" + "slices" + + floatBit "github.com/shantanu-gontia/float-conv/pkg" + F32 "github.com/shantanu-gontia/float-conv/pkg/float32bits" ) // Some constants that will help with bit manipulation we'll need to @@ -39,7 +42,7 @@ type Bits uint8 // number is out of the range supported by the format, then the result is // undefined per the spec. In our case, we clamp to Infinity of the same sign // as the original number -type ScaleFactor int8 +type ScaleFactor uint8 // Convert the given [Bits] type to the floating point number it represents, // inside a float32 value. This is effectively, a bit_cast to float8e4m3, @@ -54,6 +57,14 @@ func (input Bits) ToFloat32(scaleFactor ScaleFactor) float32 { exponentBits := (asUint8 & ExponentMask) >> 3 mantissaBits := (asUint8 & MantissaMask) + // in F8E8M0, 255 signals NaN + if (scaleFactor == 255) { + if (signBit == 0) { + return math.Float32frombits(F32.PositiveNaN) + } + return math.Float32frombits(F32.NegativeNaN) + } + // Special Values like Inf, NaN etc. need to be handled before applying // the general algorithm to calculate the number if asUint8 == PositiveNaN { @@ -142,7 +153,164 @@ func (input Bits) ToFloat32(scaleFactor ScaleFactor) float32 { float32MantissaBits) } -func handleOverflow(signBit uint32, om floatBit.OverflowMode) (Bits, - big.Accuracy, floatBit.Status) { +// Convert the given [Bits] type to a [big.Float] arbitrary precision +// floating-point number with the given scale factor (Note that +// this works by intermediate conversion to float32, so if the number +// is not representable in float32 with the given scale factor, the result +// is same as what would be after conversion to float32 +func (input Bits) ToBigFloat(scaleFactor ScaleFactor) big.Float { + asFloat32 := input.ToFloat32(scaleFactor) + asBigFloat := *big.NewFloat(float64(asFloat32)) + return asBigFloat +} + +// Convert the given [big.Float] arbitrary-precison floating-point number +// to a [Bits] type representing the bits of a float8e4m3 number. The input +// is scaled by the given scaleFactor. If the number cannot be represented +// in the float8e4m3 format exactly, then the rounding mode, overflow mode, +// and underflow mode decide the result. Returns the result [Bits], +// a [big.Accuracy] which encodes whether the result value was the same, +// larger, or smaller than the input, and a [floatBit.Status] which encodes, +// whether the result ft in the [Bits], caused overflow, or underflow. +func FromBigFloat(input big.Float, scaleFactor ScaleFactor, + rm floatBit.RoundingMode, om floatBit.OverflowMode, + um floatBit.UnderflowMode) (Bits, big.Accuracy, floatBit.Status) { return Bits(0), big.Exact, floatBit.Fits } + +// Convert the given float32 number into a [Bits] type which represents the +// bits of a OCP MXFP8E4M3 number. Signature and usage is identical to +// [FromBigFloat] except the input is a float32 +func FromFloat32(input float32, scaleFactor ScaleFactor, + rm floatBit.RoundingMode, om floatBit.OverflowMode, + um floatBit.UnderflowMode) (Bits, big.Accuracy, floatBit.Status) { + return Bits(0), big.Exact, floatBit.Fits +} + +// Utility function to check if the number with the given exponent and mantissa +// bits would overflow when trying to represent it in a float8e4m3 value +// with the given scale factor. +// mantissaBits should occupy the bits as they would in a float32 number +func checkOverflow(actualExponent int, mantissaBits uint32, + scaleFactor ScaleFactor) bool { + + // Remove the scaling + scaledExponent := actualExponent - (int(scaleFactor) - 127) + + if scaledExponent > ExponentMax { + return true + } + + // If the exponent is equal to the maximum exponent, and all the float32 + // mantissa bits are set, but there is additional precision in the number + // than can be represented in float32, then it exceeds the maximum normal + // and so, overflows. + if (scaledExponent == ExponentMax) && + (mantissaBits & 0b0_00000000_11000000000000000000000 == + 0b0_00000000_11000000000000000000000) && + (mantissaBits & f32Float8E4M3SubnormalMask > 0){ + return true + } + return false +} + +// Utility function to check if the number with the given exponent and mantissa +// bits would underflow when trying to represent it in a float8e4m3 value +// with the given scaleFactor. Subnormals require shifting the mantissa to +// align the exponents. This might cause loss of precision that cannot be +// detected by the mantissaBits alone as they are already shifted. The +// lostPrecision parameter helps us with that. If it's set to true then there +// was precision lost when the mantissa was being aligned. +func checkUnderflow(mantissaBits uint32, lostPrecision bool) bool { + // This assumes that the exponent after scaling is 0, so any extra + // precision in the mantissa means underflow + f8e4m3PrecisionMantissa := mantissaBits & f32Float8E4M3MantissaMask + f8e4m3ExtraPrecisionMantissa := mantissaBits & + f32Float8E4M3HalfSubnormalMask + if f8e4m3PrecisionMantissa == 0 && (f8e4m3ExtraPrecisionMantissa != 0 || + lostPrecision) { + return true + } + return false +} + +// ToFloatFormat converts the given [Bits] type representing the bits that +// make up a OCP MXFP8E4M3 umber into [floatBit.FloatBitFormat] +// Implements the FloatBitFormatter interface +func (b* Bits) ToFloatFormat() floatBit.FloatBitFormat { + // Iterate over the bits and construct the return values + + asUint := uint8(*b) + signBits := (asUint & SignMask) >> 7 + exponentBits := (asUint & ExponentMask) >> 3 + mantissaBits := (asUint & MantissaMask) + + // 1 Sign Bit + signRetVal := make([]byte, 0, 1) + if signBits == 0 { + signRetVal = append(signRetVal, byte('0')) + } else { + signRetVal = append(signRetVal, byte('1')) + } + + // 4 Exponent Bits + exponentRetVal := make([]byte, 0, 4) + for range 5 { + currentExponentBit := exponentBits & 0x1 + var valueToAppend byte + if currentExponentBit == 0 { + valueToAppend = '0' + } else { + valueToAppend = '1' + } + exponentRetVal = append(exponentRetVal, valueToAppend) + exponentBits >>= 1 + } + + // 3 Mantissa Bits + mantissaRetVal := make([]byte, 0, 3) + for range 3 { + currentMantissaBit := mantissaBits & 0x1 + var valueToAppend byte + if currentMantissaBit == 0 { + valueToAppend = '0' + } else { + valueToAppend = '1' + } + mantissaRetVal = append(mantissaRetVal, valueToAppend) + mantissaBits >>= 1 + } + slices.Reverse(mantissaRetVal) + + return floatBit.FloatBitFormat{Sign: signRetVal, Exponent: exponentRetVal, + Mantissa: mantissaRetVal} +} + +// Conversion error returns the difference between the input [big.Float] +// number and the float8e4m3 number represented by the bits in the [Bits] +// receiver when it's scaled with the given scale factor +func (b* Bits) ConversionError(input* big.Float, scaleFactor ScaleFactor) ( + big.Float, error) { + // If the receiver is a NaN then we return an error + asFloat32 := b.ToFloat32(scaleFactor) + if math.IsNaN(float64(asFloat32)) { + return *big.NewFloat(0.0), errors.New("NaN encountered") + } + + // Positive Infinity == Positiive Infinity + if math.IsInf(float64(asFloat32), 1) && + (input.IsInf() && input.Sign() > 0) { + return *big.NewFloat(0), nil + } + + // Negative Infinity == Negative Infinity + if math.IsInf(float64(asFloat32), -1) && + (input.IsInf() && input.Sign() < 0) { + return *big.NewFloat(0), nil + } + + asBigFloat := b.ToBigFloat(scaleFactor) + convDiff := asBigFloat.Sub(&asBigFloat, input) + return *convDiff, nil +} + diff --git a/pkg/float8e4m3bits/float8e4m3_test.go b/pkg/float8e4m3bits/float8e4m3_test.go index 24b739a..c300690 100644 --- a/pkg/float8e4m3bits/float8e4m3_test.go +++ b/pkg/float8e4m3bits/float8e4m3_test.go @@ -1,16 +1,19 @@ package F8E4M3 import ( - "testing" - "math" - F32 "github.com/shantanu-gontia/float-conv/pkg/float32bits" + "math" + "math/big" + "testing" + + floatBit "github.com/shantanu-gontia/float-conv/pkg" + F32 "github.com/shantanu-gontia/float-conv/pkg/float32bits" ) func TestToFloat32(t *testing.T) { testCases := []struct{ // Input input Bits - scaleFactor int8 + scaleFactor ScaleFactor // Output golden float32 }{ @@ -28,17 +31,1375 @@ func TestToFloat32(t *testing.T) { input: 0b1_0000_000, scaleFactor: 254, golden: math.Float32frombits(F32.NegativeZero), - } + }, + { + input: 0b0_0000_001, + scaleFactor: 255, + golden: math.Float32frombits(F32.PositiveNaN), + }, + { + input: 0b1_0000_100, + scaleFactor: 255, + golden: math.Float32frombits(F32.NegativeNaN), + }, + { + input: 0b0_0000_001, + scaleFactor: 254, + // 2^(118) + golden: math.Float32frombits(0b0_11110101_00000000000000000000000), + }, + { + input: 0b1_0000_100, + scaleFactor: 254, + // -2^(-7) * 2^(127) = -2^(120) + golden: math.Float32frombits(0b1_11110111_00000000000000000000000), + }, + { + input: 0b0_1001_000, + scaleFactor: 254, + // 2^(9-7) => 2 + (254 - 127) = 129 => Infinity + golden: math.Float32frombits(F32.PositiveInfinity), + }, + { + input: 0b1_1001_000, + scaleFactor: 254, + golden: math.Float32frombits(F32.NegativeInfinity), + }, + { + input: 0b0_0000_001, + scaleFactor: 127, + // 2^(-9) => -9 + 127 = 118 + golden: math.Float32frombits(0b0_01110110_00000000000000000000000), + }, + { + input: 0b1_0000_001, + scaleFactor: 0, + // 2^(-9) => -9 -127 + 127 = -9 => 0 + golden: math.Float32frombits(F32.NegativeZero), + }, + { + input: 0b0_0111_001, + scaleFactor: 50, + // 2^(7-7) => 0 + (50 - 127) + 127 = 50 + golden: math.Float32frombits(0b0_00110010_00100000000000000000000), + }, + { + input: Bits(PositiveNaN), + scaleFactor: 127, + golden: math.Float32frombits(F32.PositiveNaN), + }, + { + input: Bits(NegativeNaN), + scaleFactor: 127, + golden: math.Float32frombits(F32.NegativeNaN), + }, + { + input: Bits(PositiveMaxNormal), + scaleFactor: 127, + // Exponent bits -> 2^(15-7) => 8 + (127 - 127) Scale + 127 f32Bias = 135 + golden: math.Float32frombits(0b0_10000111_11000000000000000000000), + }, + { + input: Bits(NegativeMaxNormal), + scaleFactor: 246, + // Exponent bits -> 2^(15-7) => 8 + (254 - 127) Scale + 127 f32Bias = 254 + golden: math.Float32frombits(0b1_11111110_11000000000000000000000), + }, } for _, tt := range testCases { t.Run("ToFloat32", func(t* testing.T) { - result := tt.input.ToFloat32() - if result != tt.golden { + result := tt.input.ToFloat32(tt.scaleFactor) + if math.Float32bits(result) != math.Float32bits(tt.golden) { t.Logf("Failed Input Set:\n") - t.Logf("Input: %0#16b (%0#4x)", tt.input, tt.input) + t.Logf("Input: %0#8b (%0#2x)", tt.input, tt.input) + t.Logf("Scale Factor: %d (%d)", tt.scaleFactor, int(tt.scaleFactor) - 127) t.Errorf("Expected Output: %f (%0#8x). Got: %f (%0#8x)", tt.golden, math.Float32bits(tt.golden), result, math.Float32bits(result)) } }) } } + +func TestCheckOverflow(t* testing.T) { + testCases := []struct{ + // Inputs + actualExponent int + mantissaBits uint32 + scaleFactor ScaleFactor + // Outputs + golden bool + }{ + { + actualExponent: 8, + mantissaBits: 0b0_00000000_11100000000000000000000, + scaleFactor: 127, + golden: true, + }, + { + actualExponent: 8, + mantissaBits: 0b0_00000000_11000000000000000000000, + scaleFactor: 127, + golden: false, + }, + { + actualExponent: 127, + mantissaBits: 0b0_00000000_11100000000000000000000, + scaleFactor: 254, + golden: false, + }, + { + actualExponent: 55, + mantissaBits: 0b0_00000000_11000000000000000000000, + scaleFactor: 55, + golden: true, + }, + } + + for _, tt := range testCases { + result := checkOverflow(tt.actualExponent, tt.mantissaBits, tt.scaleFactor) + if result != tt.golden { + t.Logf("Failed Input Set:\n") + t.Logf("Exponent: %d Mantissa Bits: %#08x, Scale Factor: %d (%d)", tt.actualExponent, tt.mantissaBits, tt.scaleFactor, int(tt.scaleFactor) - 127) + t.Errorf("Expected: %v, Got: %v", tt.golden, result) + } + } +} + + +func TestCheckUnderflow(t* testing.T) { + testCases := []struct{ + // Inputs + mantissaBits uint32 + lostPrecision bool + // Outputs + golden bool + }{ + { + mantissaBits: 0b0_00000000_11100000000000000000000, + lostPrecision: false, + golden: false, + }, + { + mantissaBits: 0b0_00000000_00010000000000000000000, + lostPrecision: false, + golden: true, + }, + { + mantissaBits: 0b0_00000000_00100000000000000000000, + lostPrecision: true, + golden: false, + }, + { + mantissaBits: 0b0_00000000_00000000000000000000000, + lostPrecision: true, + golden: true, + }, + { + mantissaBits: 0b0_00000000_00000000000000000000001, + lostPrecision: false, + golden: true, + }, + } + + for _, tt := range testCases { + result := checkUnderflow(tt.mantissaBits, tt.lostPrecision) + if result != tt.golden { + t.Logf("Failed Input Set:\n") + t.Logf("Mantissa Bits: %#016x, lostPrecison: %v", tt.mantissaBits, tt.lostPrecision) + t.Errorf("Expected: %v, Got: %v", tt.golden, result) + } + } +} + +func TestRoundTowardsPositiveInf(t *testing.T) { + testCases := []struct { + // Inputs + signBit uint32 + exponentBits uint32 + mantissaBits uint32 + lostPrecision bool + // Outputs + goldenVal Bits + goldenAcc big.Accuracy + }{ + // Exact + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_010_00000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0010_010), + goldenAcc: big.Exact, + }, + // Normal, Rounds Up (+ve) + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_010_00000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b0_0010_011), + goldenAcc: big.Above, + }, + // Normal, Truncates (-ve) + { + signBit: 0x1, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_010_00000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b1_0010_010), + goldenAcc: big.Above, + }, + // Subnormal, rounds up (+ve) + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_010_00000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b0_0000_011), + goldenAcc: big.Above, + }, + // Subornmal, truncates (-ve) + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_010_00000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b1_0000_010), + goldenAcc: big.Above, + }, + // Cases where precision was lost + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_001_00000000000000000000, + lostPrecision: true, + goldenVal: Bits(0b0_0000_010), + goldenAcc: big.Above, + }, + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_001_00000000000000000000, + lostPrecision: true, + goldenVal: Bits(0b1_0000_001), + goldenAcc: big.Above, + }, + } + + for _, tt := range testCases { + t.Run("RoundTowardsPositiveInf", func(t *testing.T) { + resultVal, resultAcc := roundTowardsPositiveInf(tt.signBit, tt.exponentBits, tt.mantissaBits, tt.lostPrecision) + if (resultVal != tt.goldenVal) || (resultAcc != tt.goldenAcc) { + t.Logf("Failed Input Set:\n") + t.Logf("signBit: %#08x, exponentBits: %#08x, mantissaBits: %#08x", tt.signBit, tt.exponentBits, tt.mantissaBits) + t.Logf("lostPrecision: %v", tt.lostPrecision) + t.Errorf("Expected Result: %0#2x, Got: %0#2x\n", tt.goldenVal, resultVal) + t.Errorf("Expected Accuracy: %v, Got: %v\n", tt.goldenAcc, resultAcc) + } + }) + } + +} + +func TestRoundTowardsNegativeInf(t *testing.T) { + testCases := []struct { + // Inputs + signBit uint32 + exponentBits uint32 + mantissaBits uint32 + lostPrecision bool + // Outputs + goldenVal Bits + goldenAcc big.Accuracy + }{ + // Exact + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_00000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0010_000), + goldenAcc: big.Exact, + }, + // Normal, Truncate (+ve) + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_001_00000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b0_0010_001), + goldenAcc: big.Below, + }, + // Normal, Rounds Up (-ve) + { + signBit: 0x1, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_001_00000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b1_0010_010), + goldenAcc: big.Below, + }, + // Subnormal, Truncate (+ve) + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_001_00000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b0_0000_001), + goldenAcc: big.Below, + }, + // Subnormal, Rounds Up (-ve) + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_001_00000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b1_0000_010), + goldenAcc: big.Below, + }, + // Cases where precision was lost + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_001_00000000000000000000, + lostPrecision: true, + goldenVal: Bits(0b0_0000_001), + goldenAcc: big.Below, + }, + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_001_00000000000000000000, + lostPrecision: true, + goldenVal: Bits(0b1_0000_010), + goldenAcc: big.Below, + }, + } + + for _, tt := range testCases { + t.Run("RoundTowardsNegativeInf", func(t *testing.T) { + resultVal, resultAcc := roundTowardsNegativeInf(tt.signBit, tt.exponentBits, tt.mantissaBits, tt.lostPrecision) + if (resultVal != tt.goldenVal) || (resultAcc != tt.goldenAcc) { + t.Logf("Failed Input Set:\n") + t.Logf("signBit: %#08x, exponentBits: %#08x, mantissaBits: %#08x", tt.signBit, tt.exponentBits, tt.mantissaBits) + t.Logf("lostPrecision: %v\n", tt.lostPrecision) + t.Errorf("Expected Result: %0#2x, Got: %0#2x\n", tt.goldenVal, resultVal) + t.Errorf("Expected Accuracy: %v, Got: %v\n", tt.goldenAcc, resultAcc) + } + }) + } +} + +func TestRoundTowardsZero(t *testing.T) { + testCases := []struct { + // Inputs + signBit uint32 + exponentBits uint32 + mantissaBits uint32 + lostPrecision bool + // Outputs + goldenVal Bits + goldenAcc big.Accuracy + }{ + // Zero -> Zero + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0x0, + lostPrecision: false, + goldenVal: Bits(0x0), + goldenAcc: big.Exact, + }, + // Exact + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_001_00000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0000_001), + goldenAcc: big.Exact, + }, + // Positive RTZ to below + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_001_01000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b0_0000_001), + goldenAcc: big.Below, + }, + // Negative RTZ to above + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_011_01000000000000100001, + lostPrecision: false, + goldenVal: Bits(0b1_0000_011), + goldenAcc: big.Above, + }, + // Lost precision before got pased into func + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_011_00000000000000000000, + lostPrecision: true, + goldenVal: Bits(0b1_0000_011), + goldenAcc: big.Above, + }, + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_011_00000000000000000000, + lostPrecision: true, + goldenVal: Bits(0b0_0000_011), + goldenAcc: big.Below, + }, + } + + for _, tt := range testCases { + t.Run("RoundTowardsZero", func(t *testing.T) { + resultVal, resultAcc := roundTowardsZero(tt.signBit, tt.exponentBits, tt.mantissaBits, tt.lostPrecision) + if (resultVal != tt.goldenVal) || (resultAcc != tt.goldenAcc) { + t.Logf("Failed Input Set:\n") + t.Logf("signBit: %#08x, exponentBits: %#08x, mantissaBits: %#08x", tt.signBit, tt.exponentBits, tt.mantissaBits) + t.Logf("lostPrecision: %v\n", tt.lostPrecision) + t.Errorf("Expected Result: %0#2x, Got: %0#2x\n", tt.goldenVal, resultVal) + t.Errorf("Expected Accuracy: %v, Got: %v\n", tt.goldenAcc, resultAcc) + } + }) + } +} + +func TestRoundHalfTowardsPositiveInf(t *testing.T) { + + // Rounding half towards positive infinity involves rounding to the nearest + // representable number, and breaking ties by rounding towards the + // number closer to +inf + + testCases := []struct { + // Inputs + signBit uint32 + exponentBits uint32 + mantissaBits uint32 + lostPrecision bool + // Outputs + goldenVal Bits + goldenAcc big.Accuracy + }{ + // Exact + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_010_00000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0010_010), + goldenAcc: big.Exact, + }, + // Normal, truncate, because closer to truncated value (-ve) + { + signBit: 0x1, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_00000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b1_0010_000), + goldenAcc: big.Above, + }, + // Normal, truncate, because closer to truncated value (+ve) + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_00000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b0_0010_000), + goldenAcc: big.Below, + }, + // Normal, round up, because closer to rounded up value (-ve) + { + signBit: 0x1, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_10000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b1_0010_001), + goldenAcc: big.Below, + }, + // Normal, round up, because closer to rounded up value (+ve) + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_10000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b0_0010_001), + goldenAcc: big.Above, + }, + // Normal, round up, because half-way (+ve) + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0010_001), + goldenAcc: big.Above, + }, + // Normal, truncate, because half-way (-ve) + { + signBit: 0x1, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b1_0010_000), + goldenAcc: big.Above, + }, + // Subnormal, truncate, because closer to truncated value (-ve) + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_01000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b1_0000_000), + goldenAcc: big.Above, + }, + // Subnormal, truncate, because closer to truncated value (+ve) + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_01000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0000_000), + goldenAcc: big.Below, + }, + // Subnormal, round up, because closer to rounded up value (-ve) + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_10000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b1_0000_001), + goldenAcc: big.Below, + }, + // Subnormal, round up, because closr to rounded up value (+ve) + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_10000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b0_0000_001), + goldenAcc: big.Above, + }, + // Subnormal, round up, because half-way (+ve) + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0000_001), + goldenAcc: big.Above, + }, + // Subnormal, truncate, because half-way (-ve) + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b1_0000_000), + goldenAcc: big.Above, + }, + // Lost Precision + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_001_10000000000000000000, + lostPrecision: true, + goldenVal: Bits(0b1_0000_010), + goldenAcc: big.Below, + }, + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_001_10000000000000000000, + lostPrecision: true, + goldenVal: Bits(0b0_0000_010), + goldenAcc: big.Above, + }, + } + + for _, tt := range testCases { + t.Run("RoundHalfTowardsPositiveInf", func(t *testing.T) { + resultVal, resultAcc := roundHalfTowardsPositiveInf(tt.signBit, tt.exponentBits, tt.mantissaBits, tt.lostPrecision) + if (resultVal != tt.goldenVal) || (resultAcc != tt.goldenAcc) { + t.Logf("Failed Input Set:\n") + t.Logf("signBit: %#08x, exponentBits: %#08x, mantissaBits: %#08x", tt.signBit, tt.exponentBits, tt.mantissaBits) + t.Logf("lostPrecision: %v", tt.lostPrecision) + t.Errorf("Expected Result: %0#2x, Got: %0#2x\n", tt.goldenVal, resultVal) + t.Errorf("Expected Accuracy: %v, Got: %v\n", tt.goldenAcc, resultAcc) + } + }) + } +} + +func TestRoundHalfTowardsNegativeInf(t *testing.T) { + + // Rounding half towards negative infinity involves rounding to the nearest + // representable number, and breaking ties by rounding towards the + // number closer to -inf + + testCases := []struct { + // Inputs + signBit uint32 + exponentBits uint32 + mantissaBits uint32 + lostPrecision bool + // Outputs + goldenVal Bits + goldenAcc big.Accuracy + }{ + // Exact + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_010_00000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0010_010), + goldenAcc: big.Exact, + }, + // Normal, truncate, because closer to truncated value (-ve) + { + signBit: 0x1, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_00000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b1_0010_000), + goldenAcc: big.Above, + }, + // Normal, truncate, because closer to truncated value (+ve) + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_00000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b0_0010_000), + goldenAcc: big.Below, + }, + // Normal, round up, because closer to rounded up value (-ve) + { + signBit: 0x1, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_10000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b1_0010_001), + goldenAcc: big.Below, + }, + // Normal, round up, because closer to rounded up value (+ve) + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_10000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b0_0010_001), + goldenAcc: big.Above, + }, + // Normal, round up, because half-way (-ve) + { + signBit: 0x1, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b1_0010_001), + goldenAcc: big.Below, + }, + // Normal, truncate, because half-way (+ve) + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0010_000), + goldenAcc: big.Below, + }, + // Subnormal, truncate, because closer to truncated value (-ve) + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_01000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b1_0000_000), + goldenAcc: big.Above, + }, + // Subnormal, truncate, because closer to truncated value (+ve) + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_01000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0000_000), + goldenAcc: big.Below, + }, + // Subnormal, round up, because closer to rounded up value (-ve) + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_10000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b1_0000_001), + goldenAcc: big.Below, + }, + // Subnormal, round up, because closer to rounded up value (+ve) + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_10000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b0_0000_001), + goldenAcc: big.Above, + }, + // Subnormal, round up, because half-way (-ve) + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b1_0000_001), + goldenAcc: big.Below, + }, + // Subnormal, truncate, because half-way (+ve) + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0000_000), + goldenAcc: big.Below, + }, + // Lost Precision + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_001_10000000000000000000, + lostPrecision: true, + goldenVal: Bits(0b1_0000_010), + goldenAcc: big.Below, + }, + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_001_10000000000000000000, + lostPrecision: true, + goldenVal: Bits(0b0_0000_010), + goldenAcc: big.Above, + }, + } + + for _, tt := range testCases { + t.Run("RoundHalfTowardsNegativeInf", func(t *testing.T) { + resultVal, resultAcc := roundHalfTowardsNegativeInf(tt.signBit, tt.exponentBits, tt.mantissaBits, tt.lostPrecision) + if (resultVal != tt.goldenVal) || (resultAcc != tt.goldenAcc) { + t.Logf("Failed Input Set:\n") + t.Logf("signBit: %#08x, exponentBits: %#08x, mantissaBits: %#08x", tt.signBit, tt.exponentBits, tt.mantissaBits) + t.Logf("lostPrecision: %v", tt.lostPrecision) + t.Errorf("Expected Result: %0#2x, Got: %0#2x\n", tt.goldenVal, resultVal) + t.Errorf("Expected Accuracy: %v, Got: %v\n", tt.goldenAcc, resultAcc) + } + }) + } +} + +func TestRoundHalfTowardsZero(t *testing.T) { + + // Rounding half towards zero involves rounding to the nearest + // representable number, and breaking ties by rounding towards the + // number closer to 0 + + testCases := []struct { + // Inputs + signBit uint32 + exponentBits uint32 + mantissaBits uint32 + lostPrecision bool + // Outputs + goldenVal Bits + goldenAcc big.Accuracy + }{ + // Exact + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_010_00000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0010_010), + goldenAcc: big.Exact, + }, + // Normal, truncate, because closer to truncated value (-ve) + { + signBit: 0x1, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_00000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b1_0010_000), + goldenAcc: big.Above, + }, + // Normal, truncate, because closer to truncated value (+ve) + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_00000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b0_0010_000), + goldenAcc: big.Below, + }, + // Normal, round up, because closer to rounded up value (-ve) + { + signBit: 0x1, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_10000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b1_0010_001), + goldenAcc: big.Below, + }, + // Normal, round up, because closer to rounded up value (+ve) + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_10000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b0_0010_001), + goldenAcc: big.Above, + }, + // Normal, truncate, because half-way (-ve) + { + signBit: 0x1, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b1_0010_000), + goldenAcc: big.Above, + }, + // Normal, truncate, because half-way (+ve) + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0010_000), + goldenAcc: big.Below, + }, + // Subnormal, truncate, because closer to truncated value (-ve) + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_01000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b1_0000_000), + goldenAcc: big.Above, + }, + // Subnormal, truncate, because closer to truncated value (+ve) + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_01000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0000_000), + goldenAcc: big.Below, + }, + // Subnormal, round up, because closer to rounded up value (-ve) + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_10000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b1_0000_001), + goldenAcc: big.Below, + }, + // Subnormal, round up, because closer to rounded up value (+ve) + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_10000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b0_0000_001), + goldenAcc: big.Above, + }, + // Subnormal, truncate, because half-way (-ve) + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b1_0000_000), + goldenAcc: big.Above, + }, + // Subnormal, truncate, because half-way (+ve) + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0000_000), + goldenAcc: big.Below, + }, + // Lost Precision + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_001_10000000000000000000, + lostPrecision: true, + goldenVal: Bits(0b1_0000_010), + goldenAcc: big.Below, + }, + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_001_10000000000000000000, + lostPrecision: true, + goldenVal: Bits(0b0_0000_010), + goldenAcc: big.Above, + }, + } + + for _, tt := range testCases { + t.Run("RoundHalfTowardsZero", func(t *testing.T) { + resultVal, resultAcc := roundHalfTowardsZero(tt.signBit, tt.exponentBits, tt.mantissaBits, tt.lostPrecision) + if (resultVal != tt.goldenVal) || (resultAcc != tt.goldenAcc) { + t.Logf("Failed Input Set:\n") + t.Logf("signBit: %#08x, exponentBits: %#08x, mantissaBits: %#08x", tt.signBit, tt.exponentBits, tt.mantissaBits) + t.Logf("lostPrecision: %v", tt.lostPrecision) + t.Errorf("Expected Result: %0#2x, Got: %0#2x\n", tt.goldenVal, resultVal) + t.Errorf("Expected Accuracy: %v, Got: %v\n", tt.goldenAcc, resultAcc) + } + }) + } +} + + +func TestRoundNearestOdd(t *testing.T) { + // Rounding to nearest even involves rounding to the nearest + // representable number, and breaking ties by rounding towards the + // number with LSB = 1 + + testCases := []struct { + // Inputs + signBit uint32 + exponentBits uint32 + mantissaBits uint32 + lostPrecision bool + // Outputs + goldenVal Bits + goldenAcc big.Accuracy + }{ + // Exact + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_010_00000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0010_010), + goldenAcc: big.Exact, + }, + // Normal, truncate, because closer to truncated value (-ve) + { + signBit: 0x1, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_00000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b1_0010_000), + goldenAcc: big.Above, + }, + // Normal, truncate, because closer to truncated value (+ve) + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_00000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b0_0010_000), + goldenAcc: big.Below, + }, + // Normal, round up, because closer to rounded up value (-ve) + { + signBit: 0x1, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_10000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b1_0010_001), + goldenAcc: big.Below, + }, + // Normal, round up, because closer to rounded up value (+ve) + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_10000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b0_0010_001), + goldenAcc: big.Above, + }, + // Normal, round up, because half-way, f32 LSB is zero (+ve) + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_110_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0010_111), + goldenAcc: big.Above, + }, + // Normal, round up, because half-way, f32 LSB is zero (-ve) + { + signBit: 0x1, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_110_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b1_0010_111), + goldenAcc: big.Below, + }, + // Normal, truncate, because half-way, f32 LSB is 1 (+ve) + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_101_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0010_101), + goldenAcc: big.Below, + }, + // Normal, truncate, because half-way, f32 LSB is 1 (-ve) + { + signBit: 0x1, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_101_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b1_0010_101), + goldenAcc: big.Above, + }, + // Subnormal, truncate, because closer to truncated value (-ve) + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_01000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b1_0000_000), + goldenAcc: big.Above, + }, + // Subnormal, truncate, because closer to truncated value (+ve) + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_01000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0000_000), + goldenAcc: big.Below, + }, + // Subnormal, round up, because closer to rounded up value (-ve) + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_10000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b1_0000_001), + goldenAcc: big.Below, + }, + // Subnormal, round up, because closer to rounded up value (+ve) + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_10000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b0_0000_001), + goldenAcc: big.Above, + }, + // Subnormal, round up, because half-way, f32 LSB is zero (+ve) + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_100_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0000_101), + goldenAcc: big.Above, + }, + // Subnormal, round up, because half-way, f32 LSB is zero (-ve) + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_100_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b1_0000_101), + goldenAcc: big.Below, + }, + // Subnormal, truncate, because half-way, f32 LSB is one (+ve) + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_101_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0000_101), + goldenAcc: big.Below, + }, + // Subnormal, truncate, because half-way, f32 LSB is one (-ve) + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_101_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b1_0000_101), + goldenAcc: big.Above, + }, + // Lost Precision + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_001_10000000000000000000, + lostPrecision: true, + goldenVal: Bits(0b1_0000_010), + goldenAcc: big.Below, + }, + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_001_10000000000000000000, + lostPrecision: true, + goldenVal: Bits(0b0_0000_010), + goldenAcc: big.Above, + }, + } + + for _, tt := range testCases { + t.Run("RoundNearestOdd", func(t *testing.T) { + resultVal, resultAcc := roundNearestOdd(tt.signBit, tt.exponentBits, tt.mantissaBits, tt.lostPrecision) + if (resultVal != tt.goldenVal) || (resultAcc != tt.goldenAcc) { + t.Logf("Failed Input Set:\n") + t.Logf("signBit: %#08x, exponentBits: %#08x, mantissaBits: %#08x", tt.signBit, tt.exponentBits, tt.mantissaBits) + t.Logf("lostPrecision: %v", tt.lostPrecision) + t.Errorf("Expected Result: %0#2x, Got: %0#2x\n", tt.goldenVal, resultVal) + t.Errorf("Expected Accuracy: %v, Got: %v\n", tt.goldenAcc, resultAcc) + } + }) + } +} + +func TestRoundNearestEven(t *testing.T) { + // Rounding to nearest even involves rounding to the nearest + // representable number, and breaking ties by rounding towards the + // number with LSB = 0 + + testCases := []struct { + // Inputs + signBit uint32 + exponentBits uint32 + mantissaBits uint32 + lostPrecision bool + // Outputs + goldenVal Bits + goldenAcc big.Accuracy + }{ + // Exact + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_010_00000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0010_010), + goldenAcc: big.Exact, + }, + // Normal, truncate, because closer to truncated value (-ve) + { + signBit: 0x1, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_00000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b1_0010_000), + goldenAcc: big.Above, + }, + // Normal, truncate, because closer to truncated value (+ve) + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_00000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b0_0010_000), + goldenAcc: big.Below, + }, + // Normal, round up, because closer to rounded up value (-ve) + { + signBit: 0x1, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_10000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b1_0010_001), + goldenAcc: big.Below, + }, + // Normal, round up, because closer to rounded up value (+ve) + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_000_10000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b0_0010_001), + goldenAcc: big.Above, + }, + // Normal, truncate, because half-way, f32 LSB is zero (+ve) + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_110_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0010_110), + goldenAcc: big.Below, + }, + // Normal, truncate, because half-way, f32 LSB is zero (-ve) + { + signBit: 0x1, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_110_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b1_0010_110), + goldenAcc: big.Above, + }, + // Normal, round up, because half-way, f32 LSB is 1 (+ve) + { + signBit: 0x0, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_101_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0010_110), + goldenAcc: big.Above, + }, + // Normal, round up, because half-way, f32 LSB is 1 (-ve) + { + signBit: 0x1, + exponentBits: 0x2, + mantissaBits: 0b0_00000000_101_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b1_0010_110), + goldenAcc: big.Below, + }, + // Subnormal, truncate, because closer to truncated value (-ve) + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_01000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b1_0000_000), + goldenAcc: big.Above, + }, + // Subnormal, truncate, because closer to truncated value (+ve) + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_01000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0000_000), + goldenAcc: big.Below, + }, + // Subnormal, round up, because closer to rounded up value (-ve) + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_10000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b1_0000_001), + goldenAcc: big.Below, + }, + // Subnormal, round up, because closer to rounded up value (+ve) + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_000_10000000000000000001, + lostPrecision: false, + goldenVal: Bits(0b0_0000_001), + goldenAcc: big.Above, + }, + // Subnormal, truncate, because half-way, f32 LSB is zero (+ve) + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_100_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0000_100), + goldenAcc: big.Below, + }, + // Subnormal, truncate, because half-way, f32 LSB is zero (-ve) + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_100_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b1_0000_100), + goldenAcc: big.Above, + }, + // Subnormal, round up, because half-way, f32 LSB is one (+ve) + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_101_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b0_0000_110), + goldenAcc: big.Above, + }, + // Subnormal, roudn up, because half-way, f32 LSB is one (-ve) + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_101_10000000000000000000, + lostPrecision: false, + goldenVal: Bits(0b1_0000_110), + goldenAcc: big.Below, + }, + // Lost Precision + { + signBit: 0x1, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_001_10000000000000000000, + lostPrecision: true, + goldenVal: Bits(0b1_0000_010), + goldenAcc: big.Below, + }, + { + signBit: 0x0, + exponentBits: 0x0, + mantissaBits: 0b0_00000000_001_10000000000000000000, + lostPrecision: true, + goldenVal: Bits(0b0_0000_010), + goldenAcc: big.Above, + }, + } + + for _, tt := range testCases { + t.Run("RoundNearestEven", func(t *testing.T) { + resultVal, resultAcc := roundNearestEven(tt.signBit, tt.exponentBits, tt.mantissaBits, tt.lostPrecision) + if (resultVal != tt.goldenVal) || (resultAcc != tt.goldenAcc) { + t.Logf("Failed Input Set:\n") + t.Logf("signBit: %#08x, exponentBits: %#08x, mantissaBits: %#08x", tt.signBit, tt.exponentBits, tt.mantissaBits) + t.Logf("lostPrecision: %v", tt.lostPrecision) + t.Errorf("Expected Result: %0#2x, Got: %0#2x\n", tt.goldenVal, resultVal) + t.Errorf("Expected Accuracy: %v, Got: %v\n", tt.goldenAcc, resultAcc) + } + }) + } +} + +func TestFromBigFloat(t* testing.T) { + testCases := []struct { + name string + // Inputs + input big.Float + scaleFactor ScaleFactor + rm floatBit.RoundingMode + um floatBit.UnderflowMode + om floatBit.OverflowMode + // Outputs + goldenVal Bits + goldenAcc big.Accuracy + goldenStatus floatBit.Status + }{ + } + + for _, tt := range testCases { + resultVal, resultAcc, resultStatus := FromBigFloat(tt.input, tt.scaleFactor, tt.rm, tt.om, tt.um) + if (resultVal != tt.goldenVal || resultAcc != tt.goldenAcc || resultStatus != tt.goldenStatus) { + t.Logf("Failed Input Set:\n") + t.Logf("Name: %s", tt.name) + t.Logf("Value: %s", tt.input.String()) + t.Logf("Scale Factor: %d (%d)", tt.scaleFactor, int(tt.scaleFactor) - 127) + t.Logf("Rounding Mode: %v", tt.rm) + t.Logf("Overflow Mode: %v", tt.om) + t.Logf("Underflow Mode: %v", tt.um) + t.Errorf("Expected result: %.10e (%0#2x), Got: %.10e (%0#2x)", tt.goldenVal.ToFloat32(tt.scaleFactor), tt.goldenVal, resultVal.ToFloat32(tt.scaleFactor), resultVal) + t.Errorf("Expect accuracy: %v, Got: %v", tt.goldenAcc, resultAcc) + t.Errorf("Expected status: %v, Got: %v", tt.goldenStatus, resultStatus) + } + } +} diff --git a/pkg/float8e4m3bits/rounddown.go b/pkg/float8e4m3bits/rounddown.go index f7a0f50..36e9bba 100644 --- a/pkg/float8e4m3bits/rounddown.go +++ b/pkg/float8e4m3bits/rounddown.go @@ -1 +1,55 @@ package F8E4M3 + +import ( + "math/big" +) + +// Utility function that returns the number rounded to a number that is +// representable in float8e4m3. If y is the input number and x < y < x + 1ULP +// where x is a float8e4m3 number. Then this rounding mode picks up x +// signBit, and exponentBits must be passed with their values shifted all the +// way to the right. +// exponentBits must be passed with the appropriate scale factor +// and float8e4m3 bias applied +// mantissaBits must be passed in their float32 locations. +// NOTE: This doesn't handle the underflow and overflow cases. +// The parameter lostPrecision indicates whether the mantissa passed had already +// lost precision during any preprocessing +func roundTowardsNegativeInf(signBit, exponentBits, + mantissaBits uint32, lostPrecision bool) (Bits, big.Accuracy) { + return roundDown(signBit, exponentBits, mantissaBits, lostPrecision) +} + +func roundDown(signBit, exponentBits, + mantissaBits uint32, lostPrecision bool) (Bits, big.Accuracy) { + + mantissaF8E4M3Precision := mantissaBits & f32Float8E4M3MantissaMask + mantissaExtraPrecision := mantissaBits & f32Float8E4M3HalfSubnormalMask + + float8E4M3Sign := uint8(signBit << 7) + float8E4M3Exponent := uint8(exponentBits << 3) + float8E4M3Mantissa := uint8(mantissaF8E4M3Precision >> 20) + + // For this rounding mode, we only need to add 1 to the least-precision + // mantissa bit, if the input was negative, to bring it closer to -inf. + // For negative numbers, this is achieved by simply truncating. + + exponentMantissaComposite := (float8E4M3Exponent | float8E4M3Mantissa) + + // If negative and there is extra precision, then add 1 + if (float8E4M3Sign != 0) && + (mantissaExtraPrecision != 0 || lostPrecision) { + exponentMantissaComposite += 1 + } + // Since we don't handle overflow, all we need to do now is attach the sign + resultVal := Bits(float8E4M3Sign | exponentMantissaComposite) + + resultAcc := big.Exact + // If there was extra precision bits set, then we need to update the + // accuracy + if mantissaExtraPrecision != 0 || lostPrecision { + resultAcc = big.Below + } + + return resultVal, resultAcc +} diff --git a/pkg/float8e4m3bits/roundhalfdown.go b/pkg/float8e4m3bits/roundhalfdown.go index f7a0f50..b38cda3 100644 --- a/pkg/float8e4m3bits/roundhalfdown.go +++ b/pkg/float8e4m3bits/roundhalfdown.go @@ -1 +1,66 @@ package F8E4M3 + +import "math/big" + + +// Utility function that returns the number rounded to the closest float8e4m3 +// value. Ties are broken by rounding to the value closest to zero (truncation) +// signBit, and exponentBits must be passed with their values shifted all the +// way to the right. +// exponentBits must be passed with the float8e4m3 bias and the appropriate +// scale-factor applied +// mantissaBits must be passed in their float32 locations. +// NOTE: This doesn't handle the underflow and overflow cases. +func roundHalfTowardsNegativeInf(signBit, exponentBits, + mantissaBits uint32, lostPrecision bool) (Bits, big.Accuracy) { + + mantissaF8E4M3Precision := mantissaBits & f32Float8E4M3MantissaMask + mantissaExtraPrecision := mantissaBits & f32Float8E4M3HalfSubnormalMask + + float8E4M3Sign := uint8(signBit << 7) + float8E4M3Exponent := uint8(exponentBits << 3) + float8E4M3Mantissa := uint8(mantissaF8E4M3Precision >> 20) + + exponentMantissaComposite := float8E4M3Exponent | float8E4M3Mantissa + + addedOne := false + // We definitely add 1, if we're greater than the mid-point + if mantissaExtraPrecision > f32Float8E4M3HalfSubnormalLSB { + exponentMantissaComposite += 1 + addedOne = true + } + + // If extra precision was lost before, then we need to add one if we're + // halfway through in the adjusted mantissa (because this means we're + // actually greater than the midpoint + if mantissaExtraPrecision == f32Float8E4M3HalfSubnormalLSB && + lostPrecision { + exponentMantissaComposite += 1 + addedOne = true + } + + // In the case that we're halfway through + // we add 1, only if the sign was negative, otherwise we truncate + if (mantissaExtraPrecision == f32Float8E4M3HalfSubnormalLSB) && + (float8E4M3Sign != 0) && !lostPrecision { + exponentMantissaComposite += 1 + addedOne = true + } + + // For all other cases, we truncate, so now we can construct the result + // by attaching the sign + resultVal := Bits(float8E4M3Sign | exponentMantissaComposite) + resultAcc := big.Exact + + // Result is larger if the input was positive and we added 1, or + // if the input was negative and we truncated. + if mantissaExtraPrecision != 0 || lostPrecision { + resultAcc = big.Below + if (float8E4M3Sign == 0) == addedOne { + resultAcc = big.Above + } + } + + return resultVal, resultAcc +} + diff --git a/pkg/float8e4m3bits/roundhalftowardszero.go b/pkg/float8e4m3bits/roundhalftowardszero.go index f7a0f50..f66fd5b 100644 --- a/pkg/float8e4m3bits/roundhalftowardszero.go +++ b/pkg/float8e4m3bits/roundhalftowardszero.go @@ -1 +1,59 @@ package F8E4M3 + +import "math/big" + +// Utility function that returns the number rounded to the closest float8e4m3 +// value. Ties are broken by rounding to the value closest to zero (truncation) +// signBit, and exponentBits must be passed with their values shifted all the +// way to the right. +// exponentBits must be passed with the float8e4m3 bias and appropriate +// scale-factor applied +// mantissaBits must be passed in their float32 locations. +// NOTE: This doesn't handle the underflow and overflow cases. +// The parameter lostPrecision indicates whether the mantissa passed had already +// lost precision during any preprocessing +func roundHalfTowardsZero(signBit, exponentBits, + mantissaBits uint32, lostPrecision bool) (Bits, big.Accuracy) { + + mantissaF8E4M3Precision := mantissaBits & f32Float8E4M3MantissaMask + mantissaExtraPrecision := mantissaBits & f32Float8E4M3HalfSubnormalMask + + float8E4M3Sign := uint8(signBit << 7) + float8E4M3Exponent := uint8(exponentBits << 3) + float8E4M3Mantissa := uint8(mantissaF8E4M3Precision >> 20) + + exponentMantissaComposite := float8E4M3Exponent | float8E4M3Mantissa + + // If the extra precision bits exceed 1 0 0 0 ... + // we need to add 1 to the LSB of the f32 mantissa. + // For all other cases we truncate + addedOne := false + if mantissaExtraPrecision > f32Float8E4M3HalfSubnormalLSB { + exponentMantissaComposite += 1 + addedOne = true + } + + // If extra precision was lost before, then we need to add 1 if we're + // halfway through in the adjusted mantissa (because this means we're + // actually greater than the midpoint) + if mantissaExtraPrecision == f32Float8E4M3HalfSubnormalLSB && + lostPrecision { + exponentMantissaComposite += 1 + addedOne = true + } + + // All we need to do now is attach the sign + resultVal := Bits(float8E4M3Sign | exponentMantissaComposite) + resultAcc := big.Exact + + // Result is larger if the input was positive and we added 1, or + // if the input was negative and we truncated + if mantissaExtraPrecision != 0 || lostPrecision { + resultAcc = big.Below + if (float8E4M3Sign == 0) == addedOne { + resultAcc = big.Above + } + } + + return resultVal, resultAcc +} diff --git a/pkg/float8e4m3bits/roundhalfup.go b/pkg/float8e4m3bits/roundhalfup.go index cf9475b..57f378d 100644 --- a/pkg/float8e4m3bits/roundhalfup.go +++ b/pkg/float8e4m3bits/roundhalfup.go @@ -1,2 +1,67 @@ package F8E4M3 +import ( + "math/big" +) + +// Utility function that returns the number rounded to the closest float8e4m3 +// value. Ties are broken by rounding towards the value closer to +Infinity. +// signBit, and exponentBits must be passed with their values shifted all the +// way to the right. +// exponentBits must be passed with the float8e4m3 bias applied and the +// scaleFactor as well +// mantissaBits must be passed in their float32 locations. +// NOTE: This doesn't handle the underflow and overflow cases. +func roundHalfTowardsPositiveInf(signBit, exponentBits, + mantissaBits uint32, lostPrecision bool) (Bits, big.Accuracy) { + + mantissaF8E4M3Precision := mantissaBits & f32Float8E4M3MantissaMask + mantissaExtraPrecision := mantissaBits & f32Float8E4M3HalfSubnormalMask + + float8E4M3Sign := uint8(signBit << 7) + float8E4M3Exponent := uint8(exponentBits << 3) + float8E4M3Mantissa := uint8(mantissaF8E4M3Precision >> 20) + + exponentMantissaComposite := float8E4M3Exponent | float8E4M3Mantissa + + addedOne := false + // We definitely add 1 if we're greater than the mid-point + if mantissaExtraPrecision > f32Float8E4M3HalfSubnormalLSB { + exponentMantissaComposite += 1 + addedOne = true + } + + // If extra precision was lost before, then we need to add one if we're + // halfway through in the adjusted mantissa (because this means we're + // actually larger than the midpoint) + if (mantissaExtraPrecision == f32Float8E4M3HalfSubnormalLSB) && + lostPrecision { + exponentMantissaComposite += 1 + addedOne = true + } + + // In the case that we're halfway, we add 1, only if the sign was + // positive, otherwise we truncate + if (mantissaExtraPrecision == f32Float8E4M3HalfSubnormalLSB) && + (float8E4M3Sign == 0) && !lostPrecision { + exponentMantissaComposite += 1 + addedOne = true + } + + // For all other cases, we truncate. So, now we can construct the result + // by attaching the sign + resultVal := Bits(float8E4M3Sign | exponentMantissaComposite) + resultAcc := big.Exact + + // Result is larger if the input was positive and we added 1, or if + // the input was negative and we truncated + if mantissaExtraPrecision != 0 || lostPrecision { + resultAcc = big.Below + if (float8E4M3Sign == 0) == addedOne { + resultAcc = big.Above + } + } + + return resultVal, resultAcc +} + diff --git a/pkg/float8e4m3bits/roundnearesteven.go b/pkg/float8e4m3bits/roundnearesteven.go index f7a0f50..79abc84 100644 --- a/pkg/float8e4m3bits/roundnearesteven.go +++ b/pkg/float8e4m3bits/roundnearesteven.go @@ -1 +1,80 @@ package F8E4M3 + +import "math/big" + +// Utility function that returns the number rounded to the closest float8e4m3 +// value. Ties are broken by rounding to the even value (the LSB mantissa +// bit is 0) +// signBit, and exponentBits must be passed with their values shifted all the +// way to the right. +// exponentBits must be passed with the float8e4m3 bias and appropriate +// scale-factor applied +// mantissaBits must be passed in their float32 locations. +// NOTE: This doesn't handle the underflow and overflow cases. + +func roundNearestEven(signBit, exponentBits, mantissaBits uint32, + lostPrecision bool) (Bits, big.Accuracy) { + + // For rounding to nearest even, we round to the number that is closest and + // break ties by rounding towards the number that is even (LSB is 0) + + // LSB | Extra Precision Bits + // m20 m19 m18 m17 ... m0 + // 1. if m19 m18 m17 ... > 1 0 0 0 ... (more than half) we round up + // 2. if m19 m18 m17 ... < 1 0 0 0 ... (less than half) we truncate + // 3. if m19 m18 m17 ... == 1 0 0 0 ... (exactly half), then + // 3.1 m30 == 0, we truncate + // 3.2 m30 == 1, we round up + + mantissaF8E4M3Precision := mantissaBits & f32Float8E4M3MantissaMask + mantissaExtraPrecision := mantissaBits & f32Float8E4M3HalfSubnormalMask + + float8E4M3Sign := uint8(signBit << 7) + float8E4M3Exponent := uint8(exponentBits << 3) + float8E4M3Mantissa := uint8(mantissaF8E4M3Precision >> 20) + + exponentMantissaComposite := float8E4M3Exponent | float8E4M3Mantissa + + addedOne := false + + // We definitely add 1, if we're greater than the mid-point + if mantissaExtraPrecision > f32Float8E4M3HalfSubnormalLSB { + exponentMantissaComposite += 1 + addedOne = true + } + + // If extra precision was lost before, then we need to add one if we're + // halfway through in the adjusted mantissa (because this means we're + // actually greater than the midpoint) + if mantissaExtraPrecision == f32Float8E4M3HalfSubnormalLSB && + lostPrecision { + exponentMantissaComposite += 1 + addedOne = true + } + + mantissaF32LSB := mantissaBits & f32Float8E4M3SubnormalLSB + // In the case we're at the mid-pint, we only add 1, if the LSB of the + // float32 retained mantissa is 1 + if (mantissaF32LSB != 0) && (mantissaExtraPrecision == + f32Float8E4M3HalfSubnormalLSB) && !lostPrecision { + exponentMantissaComposite += 1 + addedOne = true + } + + // FOr all other cases we truncate, so, now we can construct the result + // by attaching the sign + resultVal := Bits(float8E4M3Sign | exponentMantissaComposite) + resultAcc := big.Exact + + // Result is larger if the input was positive and we added 1, + // or if the input was negative and we truncated + if mantissaExtraPrecision != 0 || lostPrecision { + resultAcc = big.Below + if (float8E4M3Sign == 0) == addedOne { + resultAcc = big.Above + } + } + + return resultVal, resultAcc +} + diff --git a/pkg/float8e4m3bits/roundnearestodd.go b/pkg/float8e4m3bits/roundnearestodd.go index f7a0f50..cea0cac 100644 --- a/pkg/float8e4m3bits/roundnearestodd.go +++ b/pkg/float8e4m3bits/roundnearestodd.go @@ -1 +1,81 @@ package F8E4M3 + +import "math/big" + +// Utility function that returns the number rounded to the closest float8e4m3 +// value. Ties are broken by rounding to the odd value (the LSB mantissa +// bit is 1) +// signBit, and exponentBits must be passed with their values shifted all the +// way to the right. +// exponentBits must be passed with the float8e4m3 bias and appropriate +// scale-factor applied +// mantissaBits must be passed in their float32 locations. +// NOTE: This doesn't handle the underflow and overflow cases. +// The parameter lostPrecision indicates whether the mantissa passed had already +// lost precision during any preprocessing +func roundNearestOdd(signBit, exponentBits, mantissaBits uint32, + lostPrecision bool) (Bits, + big.Accuracy) { + + // For rounding to nearest even, we round to the number that is closest and + // break ties by rounding towards the number that is even (LSB is 0) + + // LSB | Extra Precision Bits + // m20 m19 m18 m17 ... m0 + // 1. if m19 m18 m17 ... > 1 0 0 0 ... (more than half) we round up + // 2. if m19 m18 m17 ... < 1 0 0 0 ... (less than half) we truncate + // 3. if m19 m18 m17 ... == 1 0 0 0 ... (exactly half), then + // 3.1 m20 == 1, we truncate + // 3.2 m20 == 0, we round up + + mantissaF8E4M3Precision := mantissaBits & f32Float8E4M3MantissaMask + mantissaExtraPrecision := mantissaBits & f32Float8E4M3HalfSubnormalMask + + float8E4M3Sign := uint8(signBit << 7) + float8E4M3Exponent := uint8(exponentBits << 3) + float8E4M3Mantissa := uint8(mantissaF8E4M3Precision >> 20) + + exponentMantissaComposite := float8E4M3Exponent | float8E4M3Mantissa + + addedOne := false + // We definitely add 1, if we're greater than the mid-point + if mantissaExtraPrecision > f32Float8E4M3HalfSubnormalLSB { + exponentMantissaComposite += 1 + addedOne = true + } + + // If extra precision was lost before, then we need to add one if we're + // halfway through in the adjusted mantissa (because this means we're + // actually greater than the midpoint) + if mantissaExtraPrecision == f32Float8E4M3HalfSubnormalLSB && + lostPrecision { + exponentMantissaComposite += 1 + addedOne = true + } + + mantissaF32LSB := mantissaBits & f32Float8E4M3SubnormalLSB + // In the case we're at the mid-point, we only add 1, if the LSB of the + // float8e4m3 retained mantissa is 0 + if (mantissaF32LSB == 0) && (mantissaExtraPrecision == + f32Float8E4M3HalfSubnormalLSB) && !lostPrecision { + exponentMantissaComposite += 1 + addedOne = true + } + + // For all other cases we truncate, so, now we can construct the result + // by attaching the sign + resultVal := Bits(float8E4M3Sign | exponentMantissaComposite) + resultAcc := big.Exact + + // Result is larger if the input was positive and we added 1, or, + // if the input was negative and we truncated + if mantissaExtraPrecision != 0 || lostPrecision { + resultAcc = big.Below + if (float8E4M3Sign == 0) == addedOne { + resultAcc = big.Above + } + } + + return resultVal, resultAcc +} + diff --git a/pkg/float8e4m3bits/roundtowardszero.go b/pkg/float8e4m3bits/roundtowardszero.go index f7a0f50..7eac6a9 100644 --- a/pkg/float8e4m3bits/roundtowardszero.go +++ b/pkg/float8e4m3bits/roundtowardszero.go @@ -1 +1,40 @@ package F8E4M3 + +import "math/big" + +// Utility function that returns the number truncated to a number that can +// be represented as a float8e4m3 number. +// signBit, and exponentBits must be passed with their values shifted all the +// way to the right. +// exponentBits must be passed with the float8e4m3 bias and appropriate +// scale-factor applied +// mantissaBits must be passed in their float32 locations. +// NOTE: This doesn't handle the underflow and overflow cases. + +func roundTowardsZero(signBit, exponentBits, mantissaBits uint32, + lostPrecision bool) (Bits, big.Accuracy) { + + mantissaF8E4M3Precision := mantissaBits & f32Float8E4M3MantissaMask + mantissaExtraPrecision := mantissaBits & f32Float8E4M3HalfSubnormalMask + + float8E4M3Sign := uint8(signBit << 7) + float8E4M3Exponent := uint8(exponentBits << 3) + float8E4M3Mantissa := uint8(mantissaF8E4M3Precision >> 20) + + resultVal := Bits(float8E4M3Sign | float8E4M3Exponent | float8E4M3Mantissa) + resultAcc := big.Exact + + // If there was extra precision, then the number did not fit in the + // float8e4m3 format, so we nede to report the status appropriately + if mantissaExtraPrecision != 0 || lostPrecision { + if signBit == 0 { + resultAcc = big.Below + } else { + resultAcc = big.Above + } + } + + return resultVal, resultAcc + +} + diff --git a/pkg/float8e4m3bits/roundup.go b/pkg/float8e4m3bits/roundup.go index f7a0f50..f492cdd 100644 --- a/pkg/float8e4m3bits/roundup.go +++ b/pkg/float8e4m3bits/roundup.go @@ -1 +1,54 @@ package F8E4M3 + +import ( + "math/big" +) + +// Utility function that returns the number rounded to an number that is +// representable in float8e4m3. If y is the input number and x < y < x + 1ULP +// where x is a float8e4m3 number. Then this rounding mode picks up x + 1ULP +// signBit, and exponentBits must be passed with their values shifted all the +// way to the right. +// exponentBits must be passed with the appropriate scaleFactor and float8e4m3 +// bias already applied +// mantissaBits must be passed in their float32 locations +// NOTE: This doesn't handle the underflow and overflow cases +func roundTowardsPositiveInf(signBit, exponentBits, mantissaBits uint32, + lostPrecision bool) (Bits, big.Accuracy) { + return roundUp(signBit, exponentBits, mantissaBits, lostPrecision) +} + +func roundUp(signBit, exponentBits, mantissaBits uint32, + lostPrecision bool) (Bits, big.Accuracy) { + + mantissaF8E4M3Precision := mantissaBits & f32Float8E4M3MantissaMask + mantissaExtraPrecision := mantissaBits & f32Float8E4M3HalfSubnormalMask + + float8E4M3Sign := uint32(signBit << 7) + float8E4M3Exponent := uint32(exponentBits << 3) + float8E4M3Mantissa := uint32(mantissaF8E4M3Precision >> 20) + + // For this rounding mode, we only need to add 1 to the least-precision + // mantissa, if the input was positive, to bring it closer to +inf. + // For negative numbers, this is achieved by simply truncating. + + exponentMantissaComposite := (float8E4M3Exponent | float8E4M3Mantissa) + + // If positive and there is extra precision, then add 1 + if (float8E4M3Sign == 0) && + (mantissaExtraPrecision != 0 || lostPrecision) { + exponentMantissaComposite += 1 + } + + // Since we don't handle overlfow, all we need to do now is attach the sign + resultVal := Bits(float8E4M3Sign | exponentMantissaComposite) + + resultAcc := big.Exact + // If there was any extra precision left, then we need to update the + // accuracy + if mantissaExtraPrecision != 0 || lostPrecision { + resultAcc = big.Above + } + + return resultVal, resultAcc +} From 048c2f273ce641640f20b333fd7522163b6f9498 Mon Sep 17 00:00:00 2001 From: Shantanu Gontia Date: Wed, 9 Apr 2025 08:20:24 -0700 Subject: [PATCH 4/4] Update Status --- pkg/float16bits/float16.go | 14 +- pkg/float16bits/float16_test.go | 4 +- pkg/float8e4m3bits/float8e4m3.go | 332 +++++++++++++++++++++++--- pkg/float8e4m3bits/float8e4m3_test.go | 72 +++++- pkg/float8e8m0/scalefactor.go | 40 ++++ 5 files changed, 421 insertions(+), 41 deletions(-) create mode 100644 pkg/float8e8m0/scalefactor.go diff --git a/pkg/float16bits/float16.go b/pkg/float16bits/float16.go index 61fce62..97d9ae1 100644 --- a/pkg/float16bits/float16.go +++ b/pkg/float16bits/float16.go @@ -254,7 +254,7 @@ func FromFloat32(input float32, rm floatBit.RoundingMode, } // Special Case #5: Input exceeds the maximum normal value (in magnitude) - // that can be represented in the float32 format. In this case, the input om + // that can be represented in the float16 format. In this case, the input om // [floatBit.OverflowMode] determines the response. // First off, we calculate the actual value of the exponent. To do this, @@ -285,8 +285,8 @@ func FromFloat32(input float32, rm floatBit.RoundingMode, lostPrecision := false // Before performing any rounding, we need to make sure this exponent - // can actually be represented in the float32 format. If the exponent, - // is smaller than the minimum exponent allowed in float32 (-126), this + // can actually be represented in the float16 format. If the exponent, + // is smaller than the minimum exponent allowed in float16 (-14), this // either results in underflow, or it rounds up or trunc to some subnormal // number in the float32 format. We will need to take special care for // the cases where we truncate, because we might underflow and we need @@ -325,7 +325,7 @@ func FromFloat32(input float32, rm floatBit.RoundingMode, // // In general, by following the pattern, this shift amount is equal // to the difference between the minimum representable exponent (actual) - // and the actual value of the exponent in float64. + // and the actual value of the exponent in float32. shiftAmount := uint32(ExponentMin - actualExponent) for ; shiftAmount > 0; shiftAmount-- { lastDigit := alignedMantissa & 0x1 @@ -338,9 +338,9 @@ func FromFloat32(input float32, rm floatBit.RoundingMode, // Now that we have the value for the mantissa, we can determine // the underflow case. There is underflow, in the case when the // part of the mantissa that has the precision that can be represented - // in float32 is 0 (bits m52 to m30), but the rest of the mantissa has + // in float16 is 0 (bits m22 to m13), but the rest of the mantissa has // atleast 1 bit set i.e. all of the precision in the number is higher - // than that could be represented in float32. In this case, the response + // than that could be represented in float16. In this case, the response // is handled by the input um [floatBit.UnderflowMode] if checkUnderflow(alignedMantissa, lostPrecision) { return handleUnderflow(signBit, um) @@ -463,7 +463,7 @@ func handleOverflow(signBit uint32, om floatBit.OverflowMode) (Bits, return Bits(NegativeNaN), big.Below, floatBit.Overflow case floatBit.SaturateMax: if signBit == 0 { - // The maximum normal in float32 is smaller than any number + // The maximum normal in float16 is smaller than any number // this function will be invoked for return Bits(PositiveMaxNormal), big.Below, floatBit.Overflow } diff --git a/pkg/float16bits/float16_test.go b/pkg/float16bits/float16_test.go index 85786e3..2cbc1d8 100644 --- a/pkg/float16bits/float16_test.go +++ b/pkg/float16bits/float16_test.go @@ -89,7 +89,7 @@ func TestHandleOverflow(t *testing.T) { if (resultVal != tt.goldenVal) || (resultAcc != tt.goldenAcc) || (resultStatus != tt.goldenStatus) { t.Logf("Failed Input Set:\n") t.Logf("SignBit: %v\tOverflowMode: %v\n", tt.signBit, tt.om) - t.Errorf("Expected Result: %0#8x, Got: %0#8x\n", tt.goldenVal, resultVal) + t.Errorf("Expected Result: %0#4x, Got: %0#4x\n", tt.goldenVal, resultVal) t.Errorf("Expected Accuracy: %v, Got: %v\n", tt.goldenAcc, resultAcc) t.Errorf("Expected Status: %v, Got: %v\n", tt.goldenStatus, resultStatus) } @@ -121,7 +121,7 @@ func TestHandleUnderflow(t *testing.T) { if (resultVal != tt.goldenVal) || (resultAcc != tt.goldenAcc) || (resultStatus != tt.goldenStatus) { t.Logf("Failed Input Set:\n") t.Logf("SignBit: %v\tUnderflowMode: %v\n", tt.signBit, tt.um) - t.Errorf("Expected Result: %0#8x, Got: %0#8x\n", tt.goldenVal, resultVal) + t.Errorf("Expected Result: %0#4x, Got: %0#4x\n", tt.goldenVal, resultVal) t.Errorf("Expected Accuracy: %v, Got: %v\n", tt.goldenAcc, resultAcc) t.Errorf("Expected Status: %v, Got: %v\n", tt.goldenStatus, resultStatus) } diff --git a/pkg/float8e4m3bits/float8e4m3.go b/pkg/float8e4m3bits/float8e4m3.go index 5887961..bcc56bb 100644 --- a/pkg/float8e4m3bits/float8e4m3.go +++ b/pkg/float8e4m3bits/float8e4m3.go @@ -8,6 +8,7 @@ import ( floatBit "github.com/shantanu-gontia/float-conv/pkg" F32 "github.com/shantanu-gontia/float-conv/pkg/float32bits" + F8E8M0 "github.com/shantanu-gontia/float-conv/pkg/float8e8m0" ) // Some constants that will help with bit manipulation we'll need to @@ -32,25 +33,13 @@ const ( // responses type Bits uint8 -// Alias type for uint8. This is used to represent the FP8E8M0 format which -// is used to specify the scale factor for a OCP MXFP8 floating point format. -// The actual floating point represented by an OCP MXPF8 number is the -// value encoded in the bits scaled by 2^(scale_factor) when scale_factor -// is interpreted as an fp8e8m0 number. For example, to apply no scaling, -// the number passed must be 127, because the actual scale that is -// multiplied is 2^(scale_factor - 127). If the exponent in the resulting -// number is out of the range supported by the format, then the result is -// undefined per the spec. In our case, we clamp to Infinity of the same sign -// as the original number -type ScaleFactor uint8 - // Convert the given [Bits] type to the floating point number it represents, // inside a float32 value. This is effectively, a bit_cast to float8e4m3, // followed by a upcast to float32. Since Go doesn't natively support // float8e4m3 values, this method performs some bit-twiddling, // to align the bits per the float32 bit representation and then scaling -// the final result with the [ScaleFactor] -func (input Bits) ToFloat32(scaleFactor ScaleFactor) float32 { +// the final result with the [F8E8M0.ScaleFactor] +func (input Bits) ToFloat32(scaleFactor F8E8M0.ScaleFactor) float32 { asUint8 := uint8(input) // Extract the Sign, Exponent and Mantissa signBit := (asUint8 & SignMask) >> 7 @@ -158,7 +147,7 @@ func (input Bits) ToFloat32(scaleFactor ScaleFactor) float32 { // this works by intermediate conversion to float32, so if the number // is not representable in float32 with the given scale factor, the result // is same as what would be after conversion to float32 -func (input Bits) ToBigFloat(scaleFactor ScaleFactor) big.Float { +func (input Bits) ToBigFloat(scaleFactor F8E8M0.ScaleFactor) big.Float { asFloat32 := input.ToFloat32(scaleFactor) asBigFloat := *big.NewFloat(float64(asFloat32)) return asBigFloat @@ -172,31 +161,268 @@ func (input Bits) ToBigFloat(scaleFactor ScaleFactor) big.Float { // a [big.Accuracy] which encodes whether the result value was the same, // larger, or smaller than the input, and a [floatBit.Status] which encodes, // whether the result ft in the [Bits], caused overflow, or underflow. -func FromBigFloat(input big.Float, scaleFactor ScaleFactor, +func FromBigFloat(input big.Float, scaleFactor F8E8M0.ScaleFactor, rm floatBit.RoundingMode, om floatBit.OverflowMode, um floatBit.UnderflowMode) (Bits, big.Accuracy, floatBit.Status) { - return Bits(0), big.Exact, floatBit.Fits + + // Since the [big] package's methods do not support rounding modes for + // direct conversion to float8e4m3. We convert to an intermediate float32 + // number and use our custom conversion function [FromFloat32] to convert + // to [Bits] + input.SetMode(big.ToZero) + closestFloat32, fromBigFloatAcc := input.Float32() + + var asFloat32 float32 + // [big.Float.Float32] returns the float32 closest to the input. + // This might cause it round UP for some cases. + // But, we need to get the value with extra precision truncated. + // To get the truncated result, we need to subtract 1 ULP of + // precision if the number is positive and the float32 is larger, + // or if the number is negative and the float32 is smaller, or + // alternatively if [big.Float.Float32] returns [big.Above] as the + // accuracy, because for truncation this should always be [big.Below] + + // Note that however, we need to exempt the case where the results + // become infinity because that counts not as rounding but overflow. + if math.IsInf(float64(closestFloat32), 1) && fromBigFloatAcc == big.Above { + // if the input was greater then the float32 maximum normal, then + // closestFloat32 would be +inf, and the accuracy returned would be + // [big.Above]. To pass through the overflow handling to + // [FromFloat32], we thus cap the infinities to the maximum normal + // numbers in float32. + // F32.PositiveMaxNormal will trigger overflow response + // in F8E4M3 + asFloat32 = math.Float32frombits(F32.PositiveMaxNormal) + } else if math.IsInf(float64(closestFloat32), -1) && fromBigFloatAcc == big.Below { + asFloat32 = math.Float32frombits(F32.NegativeMaxNormal) + } else if closestFloat32 == 0.0 && fromBigFloatAcc == big.Below { + // Similar to the infinity case, we also need to make sure that the + // underflow response handling is passed through to [FromFloat32]. + // So, if the closest float32 number results in 0 and the accuracy is + // big.Below or +ve numbers, then we just pass in F32.PositiveMinSubnormal + // which will trigger underflow in [FromFloat32] + asFloat32 = math.Float32frombits(F32.PositiveMinSubnormal) + } else if closestFloat32 == math.Float32frombits(F32.NegativeZero) && + fromBigFloatAcc == big.Above { + asFloat32 = math.Float32frombits(F32.NegativeMinSubnormal) + } else if (input.Sign() > 0 && fromBigFloatAcc == big.Above) || + (input.Sign() < 0 && fromBigFloatAcc == big.Below) { + // For positive numbers if the accuracy was [big.Above], then + // [big.Float.Float32] caused rounding away from zero. This is + // undesirable. To make it truncation we need to subtract 1 ULP + // from the number + closestFloat32Bits := math.Float32bits(closestFloat32) + asFloat32 = math.Float32frombits(closestFloat32Bits - 1) + } else { + asFloat32 = closestFloat32 + } + + resultBits, resultAcc, resultStatus := FromFloat32(asFloat32, scaleFactor, + rm, om, um) + return resultBits, resultAcc, resultStatus } + // Convert the given float32 number into a [Bits] type which represents the // bits of a OCP MXFP8E4M3 number. Signature and usage is identical to // [FromBigFloat] except the input is a float32 -func FromFloat32(input float32, scaleFactor ScaleFactor, +func FromFloat32(input float32, scaleFactor F8E8M0.ScaleFactor, rm floatBit.RoundingMode, om floatBit.OverflowMode, um floatBit.UnderflowMode) (Bits, big.Accuracy, floatBit.Status) { - return Bits(0), big.Exact, floatBit.Fits + + // we need to access the underlying bits of the float32 number + asUint32 := math.Float32bits(input) + + // With the number interpreted as uint32, we can now extract the underlying + // sign, exponent, and mantissa bits + signBit := (asUint32 & F32.SignMask) >> 31 + exponentBits := (asUint32 & F32.ExponentMask) >> 23 + mantissaBits := asUint32 & F32.MantissaMask + + // Special Case #1: + // scaleFactor is NaN. Then the result is also NaN + if scaleFactor == F8E8M0.NaN { + if signBit == 0 { + return Bits(PositiveNaN), big.Exact, floatBit.Overflow + } + return Bits(NegativeNaN), big.Exact, floatBit.Overflow + } + + // Special Case #2: Infinities -> NaN + // F8E8M0 does not have infinites, so we return a NaN + // and overflow + if math.IsInf(float64(input), 1) { + return Bits(PositiveNaN), big.Below, floatBit.Overflow + } + if math.IsInf(float64(input), -1) { + return Bits(NegativeNaN), big.Above, floatBit.Overflow + } + + // Special Case #3: NaN + // NaNs always convert to NaNs. For our case,we consider the conversion + // to be exact + if math.IsNaN(float64(input)) { + return Bits(NaN), big.Exact, floatBit.Fits + } + + // Special Case #4: Zeros + if asUint32 == F32.PositiveZero { + return Bits(PositiveZero), big.Exact, floatBit.Fits + } + if asUint32 == F32.NegativeZero { + return Bits(NegativeZero), big.Exact, floatBit.Fits + } + + // Special Case #4: scaledExponent is smaller than the minimum float8e4m3 + // representable exponent. (Underflow) + actualScaleFactor := int(scaleFactor) - int(F8E8M0.ExponentBias) + actualExponent := int(exponentBits) - F32.ExponentBias + scaledExponent := actualExponent - actualScaleFactor + // These exponents correspond to the subnormals in float32. + // They will underflow in float8e4m3 + if scaledExponent < F32.ExponentMin { + return handleUnderflow(signBit, um) + } + + // Special Case #5: scaledExponent is larger than the maximum float8e4m3 + // representable exponent. (Overflow) + if scaledExponent > F32.ExponentMax { + return handleOverflow(signBit, om) + } + + // Special Case #6: Input exceeds the maximum normal value (in magnitude) + // that can be represented in the float8e4m3 format. In this case, the + // input om [floatBit.OverflowMode] determines the response + if checkOverflow(scaledExponent, mantissaBits) { + return handleOverflow(signBit, om) + } + + // Variables to store the return values in + var resultVal Bits + var resultAcc big.Accuracy + + // To simplify the calculation, we need to calculate two quantities + // 1. aligned mantissa - For normal numbers this is the same as the + // original, but for subnormals we need to adjust it, because subnormals + // don't have an implicit 1.0 addition like normals do. + // 2. the adjusted exponent - For normal numbers, all we need to do is + // subtract the float32 bias and apply the float8e4m3 bias. But for + // subnormal numbers this should be exactly 0. + alignedMantissa := mantissaBits + adjustedExponent := uint32(scaledExponent + ExponentBias) + + // Value that indicates whether any precision was lost when preprocessing + // the mantissa before passing it down to the rounding routines + lostPrecision := false + + // Before performing any rounding, we need to make sure this exponent + // can actually be represented in the float32 format. If the exponent + // is smaller than the float8e4m3 format. If the exponent is smaller than + // the minimum exponent allowed in float8e4m3 (-6), this either results + // in underflow, or it rounds up or truncates to some subnormal number + // in the float8e4m3 format. We will need to take special care for the + // cases we truncate, because, we might underflow and we need to report + // that. + if actualExponent < ExponentMin { + // We start with the assumption that the number can be represented + // by a subnormal number in float8e4m3. Since subnormal numbers do not + // have an implicit 1.0 addition), we add an implicit 1.0, to the + // exponent LSB. + const ( + float32ExponentLSB uint32 = 0b0_00000001_00000000000000000000000 + ) + + // For subnormals, we need to appropriately calculate the aligned + // mantissa to accout for the deficit of the implicit 1.0 addition + alignedMantissa = mantissaBits | float32ExponentLSB + + // And also, for the subnormal case, we need to set the adjusted + // exponent bits to 0 + adjustedExponent = 0 + + // Now, to aligned the bits of the original format, with the mantissa + // of the destination format as a subnormal, we have to shift right + // until this implicit 1.0 addition falls to the mantissa bit + // corresponding to the appropriate power of 2 (this is the largest + // power of 2 in the number) + // + // Consider the following example. We have, + // f32_n = 2^(-8) * (1 + m_0/2 + m_1/4 + m_2/8 + ...) + // = 2^(-6) * (1/4 + m_0/8 + m_1/16 + m_2/32 + ...) + // + // Clearly, this caused the mantissa bit corresponding to 1/2 to + // turn to 0 and m_0 which originally would have been for 1/2, now + // corresponds to the 1/8 power. This is equivalent to a right-shift + // by 2 = (-6 - (-8)) + // + // In general, by following the pattern, this shift amount is equal + // to the difference between the minimum representable exponent in + // float8e4m3 and the actual value of the exponent in float32 after + // scaling + shiftAmount := uint32(ExponentMin - scaledExponent) + for ; shiftAmount > 0; shiftAmount-- { + lastDigit := alignedMantissa & 0x1 + if lastDigit == 1 { + // There was precision lost due to shifting which wouldn't + // be retained in the aligned mantissa. We need to track this + // to record accuracy + lostPrecision = true + } + alignedMantissa >>= 1 + } + + // Now that we have the value for the mantissa, we can determine + // the underflow case. There is underflow in the case when the part of + // the mantissa that has precision that can be represented in + // float8e4m3 is 0, but the rest of the mantissa has atleast 1 bit set + // i.e. all of the precision in the number is higher than that could + // be represented in float8e4m3. In this case, the response is + // handled by the input um [floatBit.UnderflowMode] + if checkUnderflow(alignedMantissa, lostPrecision) { + return handleUnderflow(signBit, um) + } + + } + + // Now that the mantissa bits are properly placed, and exponents are + // aligned and the overflow, underflow case is handled. We can handle the + // normal -> normal, subnormal case by performing the correct rounding + + switch rm { + case floatBit.RoundTowardsZero: + resultVal, resultAcc = roundTowardsZero(signBit, + adjustedExponent, alignedMantissa, lostPrecision) + case floatBit.RoundTowardsNegativeInf: + resultVal, resultAcc = roundTowardsNegativeInf(signBit, + adjustedExponent, alignedMantissa, lostPrecision) + case floatBit.RoundTowardsPositiveInf: + resultVal, resultAcc = roundTowardsPositiveInf(signBit, + adjustedExponent, alignedMantissa, lostPrecision) + case floatBit.RoundHalfTowardsZero: + resultVal, resultAcc = roundHalfTowardsZero(signBit, + adjustedExponent, alignedMantissa, lostPrecision) + case floatBit.RoundHalfTowardsNegativeInf: + resultVal, resultAcc = roundHalfTowardsNegativeInf(signBit, + adjustedExponent, alignedMantissa, lostPrecision) + case floatBit.RoundHalfTowardsPositiveInf: + resultVal, resultAcc = roundHalfTowardsPositiveInf(signBit, + adjustedExponent, alignedMantissa, lostPrecision) + case floatBit.RoundNearestEven: + resultVal, resultAcc = roundNearestEven(signBit, + adjustedExponent, alignedMantissa, lostPrecision) + case floatBit.RoundNearestOdd: + resultVal, resultAcc = roundNearestOdd(signBit, + adjustedExponent, alignedMantissa, lostPrecision) + } + + return resultVal, resultAcc, floatBit.Fits } // Utility function to check if the number with the given exponent and mantissa // bits would overflow when trying to represent it in a float8e4m3 value // with the given scale factor. // mantissaBits should occupy the bits as they would in a float32 number -func checkOverflow(actualExponent int, mantissaBits uint32, - scaleFactor ScaleFactor) bool { - - // Remove the scaling - scaledExponent := actualExponent - (int(scaleFactor) - 127) - +func checkOverflow(scaledExponent int, mantissaBits uint32) bool { if scaledExponent > ExponentMax { return true } @@ -214,15 +440,15 @@ func checkOverflow(actualExponent int, mantissaBits uint32, return false } -// Utility function to check if the number with the given exponent and mantissa +// Utility function to check if the number with the given mantissa // bits would underflow when trying to represent it in a float8e4m3 value -// with the given scaleFactor. Subnormals require shifting the mantissa to +// assuming that exponent is min. Subnormals require shifting the mantissa to // align the exponents. This might cause loss of precision that cannot be // detected by the mantissaBits alone as they are already shifted. The // lostPrecision parameter helps us with that. If it's set to true then there // was precision lost when the mantissa was being aligned. func checkUnderflow(mantissaBits uint32, lostPrecision bool) bool { - // This assumes that the exponent after scaling is 0, so any extra + // This assumes that the exponent bits after scaling is 0, so any extra // precision in the mantissa means underflow f8e4m3PrecisionMantissa := mantissaBits & f32Float8E4M3MantissaMask f8e4m3ExtraPrecisionMantissa := mantissaBits & @@ -234,6 +460,54 @@ func checkUnderflow(mantissaBits uint32, lostPrecision bool) bool { return false } + +// Utility function that resturns the result for the case when +// the conversion results in overflow. Since float8e4m3 does not +// support infinites, the [floatBit.SaturateInf] returns NaN instead +func handleOverflow(signBit uint32, om floatBit.OverflowMode) (Bits, + big.Accuracy, floatBit.Status) { + switch om { + case floatBit.SaturateInf: + fallthrough + case floatBit.MakeNaN: + if signBit == 0 { + return Bits(PositiveNaN), big.Above, floatBit.Overflow + } + return Bits(NegativeNaN), big.Below, floatBit.Overflow + case floatBit.SaturateMax: + if signBit == 0 { + // The maximum normal in float8e4m3 is smaller than any + // number this function will be invoked for + return Bits(PositiveMaxNormal), big.Below, floatBit.Overflow + } + return Bits(NegativeMaxNormal), big.Above, floatBit.Overflow + default: + panic("Unsupported OverflowMode encountered") + } +} + +// Utility function that resturns the result for the case when +// the conversion results in underflow +func handleUnderflow(signBit uint32, um floatBit.UnderflowMode) (Bits, +big.Accuracy, floatBit.Status) { + switch um { + case floatBit.FlushToZero: + if signBit == 0 { + // Zero is less than any positive subnormal + return Bits(PositiveZero), big.Below, floatBit.Underflow + } + return Bits(NegativeZero), big.Above, floatBit.Underflow + case floatBit.SaturateMin: + if signBit == 0 { + // Min subnormal of float32 is larger than any float64 subnormal + return Bits(PositiveMinSubnormal), big.Above, floatBit.Underflow + } + return Bits(NegativeMinSubnormal), big.Below, floatBit.Underflow + default: + panic("Unsupported UnderflowMode encountered") + } +} + // ToFloatFormat converts the given [Bits] type representing the bits that // make up a OCP MXFP8E4M3 umber into [floatBit.FloatBitFormat] // Implements the FloatBitFormatter interface @@ -289,7 +563,7 @@ func (b* Bits) ToFloatFormat() floatBit.FloatBitFormat { // Conversion error returns the difference between the input [big.Float] // number and the float8e4m3 number represented by the bits in the [Bits] // receiver when it's scaled with the given scale factor -func (b* Bits) ConversionError(input* big.Float, scaleFactor ScaleFactor) ( +func (b* Bits) ConversionError(input* big.Float, scaleFactor F8E8M0.ScaleFactor) ( big.Float, error) { // If the receiver is a NaN then we return an error asFloat32 := b.ToFloat32(scaleFactor) diff --git a/pkg/float8e4m3bits/float8e4m3_test.go b/pkg/float8e4m3bits/float8e4m3_test.go index c300690..082e2b2 100644 --- a/pkg/float8e4m3bits/float8e4m3_test.go +++ b/pkg/float8e4m3bits/float8e4m3_test.go @@ -7,13 +7,14 @@ import ( floatBit "github.com/shantanu-gontia/float-conv/pkg" F32 "github.com/shantanu-gontia/float-conv/pkg/float32bits" + F8E8M0 "github.com/shantanu-gontia/float-conv/pkg/float8e8m0" ) func TestToFloat32(t *testing.T) { testCases := []struct{ // Input input Bits - scaleFactor ScaleFactor + scaleFactor F8E8M0.ScaleFactor // Output golden float32 }{ @@ -125,7 +126,7 @@ func TestCheckOverflow(t* testing.T) { // Inputs actualExponent int mantissaBits uint32 - scaleFactor ScaleFactor + scaleFactor F8E8M0.ScaleFactor // Outputs golden bool }{ @@ -211,6 +212,70 @@ func TestCheckUnderflow(t* testing.T) { } } +func TestHandleOverflow(t *testing.T) { + testCases := []struct { + // In + signBit uint32 + om floatBit.OverflowMode + // Out + goldenVal Bits + goldenAcc big.Accuracy + goldenStatus floatBit.Status + }{ + {0, floatBit.SaturateInf, Bits(PositiveNaN), big.Above, floatBit.Overflow}, + {1, floatBit.SaturateInf, Bits(NegativeNaN), big.Below, floatBit.Overflow}, + {200, floatBit.SaturateInf, Bits(NegativeNaN), big.Below, floatBit.Overflow}, + {0, floatBit.SaturateMax, Bits(PositiveMaxNormal), big.Below, floatBit.Overflow}, + {1, floatBit.SaturateMax, Bits(NegativeMaxNormal), big.Above, floatBit.Overflow}, + {200, floatBit.SaturateMax, Bits(NegativeMaxNormal), big.Above, floatBit.Overflow}, + } + + for _, tt := range testCases { + t.Run("HandleOverflow", func(t *testing.T) { + resultVal, resultAcc, resultStatus := handleOverflow(tt.signBit, tt.om) + if (resultVal != tt.goldenVal) || (resultAcc != tt.goldenAcc) || (resultStatus != tt.goldenStatus) { + t.Logf("Failed Input Set:\n") + t.Logf("SignBit: %v\tOverflowMode: %v\n", tt.signBit, tt.om) + t.Errorf("Expected Result: %0#2x, Got: %0#2x\n", tt.goldenVal, resultVal) + t.Errorf("Expected Accuracy: %v, Got: %v\n", tt.goldenAcc, resultAcc) + t.Errorf("Expected Status: %v, Got: %v\n", tt.goldenStatus, resultStatus) + } + }) + } +} + +func TestHandleUnderflow(t *testing.T) { + testCases := []struct { + // In + signBit uint32 + um floatBit.UnderflowMode + // Out + goldenVal Bits + goldenAcc big.Accuracy + goldenStatus floatBit.Status + }{ + {0, floatBit.FlushToZero, Bits(PositiveZero), big.Below, floatBit.Underflow}, + {1, floatBit.FlushToZero, Bits(NegativeZero), big.Above, floatBit.Underflow}, + {200, floatBit.FlushToZero, Bits(NegativeZero), big.Above, floatBit.Underflow}, + {0, floatBit.SaturateMin, Bits(PositiveMinSubnormal), big.Above, floatBit.Underflow}, + {1, floatBit.SaturateMin, Bits(NegativeMinSubnormal), big.Below, floatBit.Underflow}, + {200, floatBit.SaturateMin, Bits(NegativeMinSubnormal), big.Below, floatBit.Underflow}, + } + + for _, tt := range testCases { + t.Run("HandleUnderflow", func(t *testing.T) { + resultVal, resultAcc, resultStatus := handleUnderflow(tt.signBit, tt.um) + if (resultVal != tt.goldenVal) || (resultAcc != tt.goldenAcc) || (resultStatus != tt.goldenStatus) { + t.Logf("Failed Input Set:\n") + t.Logf("SignBit: %v\tUnderflowMode: %v\n", tt.signBit, tt.um) + t.Errorf("Expected Result: %0#2x, Got: %0#2x\n", tt.goldenVal, resultVal) + t.Errorf("Expected Accuracy: %v, Got: %v\n", tt.goldenAcc, resultAcc) + t.Errorf("Expected Status: %v, Got: %v\n", tt.goldenStatus, resultStatus) + } + }) + } +} + func TestRoundTowardsPositiveInf(t *testing.T) { testCases := []struct { // Inputs @@ -1376,7 +1441,7 @@ func TestFromBigFloat(t* testing.T) { name string // Inputs input big.Float - scaleFactor ScaleFactor + scaleFactor F8E8M0.ScaleFactor rm floatBit.RoundingMode um floatBit.UnderflowMode om floatBit.OverflowMode @@ -1385,6 +1450,7 @@ func TestFromBigFloat(t* testing.T) { goldenAcc big.Accuracy goldenStatus floatBit.Status }{ + } for _, tt := range testCases { diff --git a/pkg/float8e8m0/scalefactor.go b/pkg/float8e8m0/scalefactor.go new file mode 100644 index 0000000..385b56b --- /dev/null +++ b/pkg/float8e8m0/scalefactor.go @@ -0,0 +1,40 @@ +package F8E8M0 + +import ( + "math" + + F32 "github.com/shantanu-gontia/float-conv/pkg/float32bits" +) + +// Alias type for uint8. This is used to represent the FP8E8M0 format which +// is used to specify the scale factor for a OCP MXFP8 floating point format. +// The actual floating point represented by an OCP MXPF8 number is the +// value encoded in the bits scaled by 2^(scale_factor) when scale_factor +// is interpreted as an fp8e8m0 number. For example, to apply no scaling, +// the number passed must be 127, because the actual scale that is +// multiplied is 2^(scale_factor - 127). If the exponent in the resulting +// number is out of the range supported by the format, then the result is +// undefined per the spec. In our case, we clamp to Infinity of the same sign +// as the original number +type ScaleFactor uint8 + +// Apply scale factor to the given float32 number +func ApplyScaleFactor(input float32, scaleFactor ScaleFactor) float32 { + inputAsBits := math.Float32bits(input) + inputAsBitsNoExponent := (inputAsBits & F32.SignMask) | + (inputAsBits & F32.MantissaMask) + inputAsBitsMaskedExponent := inputAsBits & F32.ExponentMask + exponentBits := (inputAsBitsMaskedExponent) >> 20 + actualExponent := int(exponentBits) - 127 + scaleFactorAsInt := int(scaleFactor) - 127 + resultExponent := actualExponent - scaleFactorAsInt + resultExponentBits := (uint32(resultExponent) + 127) << 20 + resultBits := inputAsBitsNoExponent | resultExponentBits + return math.Float32frombits(resultBits) +} + +const ( + NaN ScaleFactor = 255 + ExponentBias int = 127 +) +