diff --git a/.gitignore b/.gitignore index afc7f3d..babf1e4 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,12 @@ Cargo.lock **/*.rs.bk .vscode + +.env/ + +*.egg-info +__pycache__ +*.so + +build/ + diff --git a/Cargo.toml b/Cargo.toml index 3475519..fcedfbf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,15 +12,16 @@ readme = "README.md" repository = "https://github.com/nadavrot/arpfloat" [dependencies] - +pyo3 = { version = "0.24.1", optional = true } [dev-dependencies] -criterion = "0.4" +criterion = "0.5" [[bench]] name = "main_benchmark" harness = false [features] -default = ["std"] +default = ["std", "python"] std = [] +python=["pyo3", "std"] diff --git a/README.md b/README.md index 715aae0..77b8023 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,8 @@ types can scale to hundreds of digits, and perform very accurate calculations. In ARPFloat the rounding mode is a part of the type-system, and this defines away a number of problem that show up when using fenv.h. -`no_std` environments are supported by disabling the `std` feature. +`no_std` environments are supported by disabling the `std` feature. +`python` bindings are supported by enabling the `python` feature. ### Example ```rust @@ -125,9 +126,60 @@ The program above will print this output: .... ``` - The [examples](examples) directory contains a few programs that demonstrate the use of this library. +### Python Bindings + +The library has Python bindings that can be installed with 'pip install -e .' + +```python + >>> from arpfloat import Float, Semantics, FP16, BF16, FP32, fp64, pi + + >>> x = fp64(2.5).cast(FP16) + >>> y = fp64(1.5).cast(FP16) + >>> x + y + 4. + + >>> sem = Semantics(10, 10, "NearestTiesToEven") + >>> sem + Semantics { exponent: 10, precision: 10, mode: NearestTiesToEven } + >>> Float(sem, False, 0b1000000001, 0b1100101) + 4.789062 + + >>> pi(FP32) + 3.1415927 + >>> pi(FP16) + 3.140625 + >>> pi(BF16) + 3.140625 +``` + +Arpfloat allows you to experiment with new floating point formats. 
For example, +Nvidia's new [FP8](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html) +format can be defined as: + +```python + import numpy as np + from arpfloat import FP32, fp64, Semantics, zero + + # Create two random numpy arrays in the range [0,1) + A0 = np.random.rand(1000000) + A1 = np.random.rand(1000000) + + # Calculate the numpy dot product of the two arrays + print("Using fp32 arithmetic : ", np.dot(A0, A1)) + + # Create the fp8 format (4 exponent bits, 3 mantissa bits + 1 implicit bit) + FP8 = Semantics(4, 3 + 1, "NearestTiesToEven") + + # Convert the arrays to fp8 + A0 = [fp64(x).cast(FP8) for x in A0] + A1 = [fp64(x).cast(FP8) for x in A1] + + dot = sum([x.cast(FP32)*y.cast(FP32) for x, y in zip(A0, A1)]) + print("Using fp8/fp32 arithmetic: ", dot) +``` + ### Resources There are excellent resources out there, some of which are referenced in the code: diff --git a/arpfloat/__init__.py b/arpfloat/__init__.py new file mode 100644 index 0000000..3ece9cd --- /dev/null +++ b/arpfloat/__init__.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +""" +ARPFloat: Arbitrary Precision Floating-Point Library + +This library provides arbitrary precision floating-point arithmetic with +configurable precision and rounding modes. It implements IEEE 754 +semantics and supports standard arithmetic operations. 
+ +Examples: + >>> from arpfloat import Float, Semantics, FP16, FP32, BF16, fp64, pi + >>> x = fp64(2.5).cast(FP16) + >>> y = fp64(1.5).cast(FP16) + >>> x + y + 4. + + >>> sem = Semantics(10, 10, "Zero") + >>> sem + Semantics { exponent: 10, precision: 10, mode: Zero } + >>> Float(sem, False, 1, 13) + .0507 + + >>> pi(FP32) + 3.1415927 + >>> pi(FP16) + 3.140625 + >>> pi(BF16) + 3.140625 + +Constants: + BF16, FP16, FP32, FP64, FP128, FP256: Standard floating-point formats + pi, e, ln2, zero: Mathematical constants + Float, Semantics: Classes for representing floating-point numbers and their semantics + i64, fp64: Constructors for creating Float objects from integers and floats +""" + +from ._arpfloat import PyFloat as Float +from ._arpfloat import PySemantics as Semantics +from ._arpfloat import pi, e, ln2, zero, fma +from ._arpfloat import from_fp64 as fp64 +from ._arpfloat import from_i64 as i64 + +# Add __radd__ method to Float class for sum() compatibility + + +def _float_radd(self, other): + if isinstance(other, (int, float)) and other == 0: + return self + return self.__add__(other) + +Float.__radd__ = _float_radd + +# Define standard floating-point types +# Parameters match IEEE 754 standard formats +BF16 = Semantics(8, 8, "NearestTiesToEven") # BFloat16 +FP16 = Semantics(5, 11, "NearestTiesToEven") # Half precision +FP32 = Semantics(8, 24, "NearestTiesToEven") # Single precision +FP64 = Semantics(11, 53, "NearestTiesToEven") # Double precision +FP128 = Semantics(15, 113, "NearestTiesToEven") # Quadruple precision +FP256 = Semantics(19, 237, "NearestTiesToEven") # Octuple precision + +version = "0.1.10" diff --git a/examples/fma.py b/examples/fma.py new file mode 100644 index 0000000..d38a21f --- /dev/null +++ b/examples/fma.py @@ -0,0 +1,20 @@ +import numpy as np +from arpfloat import FP32, fp64, Semantics, zero, fma + +# Create two random numpy arrays in the range [0,1) +A0 = np.random.rand(1024) +A1 = np.random.rand(1024) + +# Create the fp8 format (4 
exponent bits, 3 mantissa bits + 1 implicit bit) +FP8 = Semantics(4, 3 + 1, "NearestTiesToEven") + +# Convert the arrays to FP8 +B0 = [fp64(x).cast(FP8) for x in A0] +B1 = [fp64(x).cast(FP8) for x in A1] + +acc = zero(FP32) +for x, y in zip(B0, B1): + acc = fma(x.cast(FP32), y.cast(FP32), acc) + +print("Using fp8/fp32 arithmetic: ", acc) +print("Using fp32 arithmetic : ", np.dot(A0, A1)) \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..aba8419 --- /dev/null +++ b/setup.py @@ -0,0 +1,29 @@ +from setuptools import setup +from setuptools_rust import Binding, RustExtension + +setup( + name="arpfloat", + version="0.1.10", # Match the version in Cargo.toml + description="Arbitrary-precision floating point library", + author="Nadav Rotem", + author_email="nadav256@gmail.com", + url="https://github.com/nadavrot/arpfloat", + rust_extensions=[ + RustExtension( + "arpfloat._arpfloat", + binding=Binding.PyO3, + debug=False, + features=["python"], + ) + ], + package_data={"arpfloat": ["py.typed"]}, + packages=["arpfloat"], + zip_safe=False, + python_requires=">=3.6", + classifiers=[ + "Programming Language :: Python :: 3", + "Programming Language :: Rust", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + ], +) diff --git a/src/arithmetic.rs b/src/arithmetic.rs index f8cab20..9762c67 100644 --- a/src/arithmetic.rs +++ b/src/arithmetic.rs @@ -4,7 +4,7 @@ extern crate alloc; use crate::bigint::BigInt; use super::bigint::LossFraction; -use super::float::{shift_right_with_loss, Category, Float, RoundingMode}; +use super::float::{Category, Float, RoundingMode}; use core::cmp::Ordering; use core::ops::{ Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign, @@ -71,7 +71,7 @@ impl Float { ab_mantissa = a_mantissa - b_mantissa - c; } ( - Self::new(sem, sign, a.get_exp(), ab_mantissa), + Self::from_parts(sem, sign, a.get_exp(), ab_mantissa), loss.invert(), ) } else { @@ -85,7 +85,10 
@@ impl Float { } debug_assert!(a.get_exp() == b.get_exp()); let ab_mantissa = a.get_mantissa() + b.get_mantissa(); - (Self::new(sem, a.get_sign(), a.get_exp(), ab_mantissa), loss) + ( + Self::from_parts(sem, a.get_sign(), a.get_exp(), ab_mantissa), + loss, + ) } } @@ -122,7 +125,7 @@ impl Float { Self::inf(sem, b.get_sign() ^ subtract) } - (Category::Zero, Category::Normal) => Self::new( + (Category::Zero, Category::Normal) => Self::from_parts( sem, b.get_sign() ^ subtract, b.get_exp(), @@ -141,6 +144,16 @@ impl Float { } (Category::Normal, Category::Normal) => { + // The IEEE 754 spec (section 6.3) states that cancellation + // results in a positive zero, except for the case of the + // negative rounding mode. + let cancellation = subtract == (a.get_sign() == b.get_sign()); + let same_absolute_number = a.same_absolute_value(b); + if cancellation && same_absolute_number { + let is_negative = RoundingMode::Negative == rm; + return Self::zero(sem, is_negative); + } + let mut res = Self::add_or_sub_normals(a, b, subtract); res.0.normalize(rm, res.1); res.0 @@ -192,6 +205,17 @@ fn test_addition() { ); } } + + // Check that adding a negative and positive results in a positive zero for + // the default rounding mode. + let a = Float::from_f64(4.0); + let b = Float::from_f64(-4.0); + let c = Float::add(a.clone(), b); + let d = Float::sub(a.clone(), a); + assert!(c.is_zero()); + assert!(!c.is_negative()); + assert!(d.is_zero()); + assert!(!d.is_negative()); } // Pg 120. Chapter 4. Basic Properties and Algorithms. @@ -371,29 +395,17 @@ impl Float { // log(2^(e_a+1)*2^(e_b+1)) = e_a + e_b + 2. 
let mut exp = a.get_exp() + b.get_exp(); - let mut loss = LossFraction::ExactlyZero; - let a_significand = a.get_mantissa(); let b_significand = b.get_mantissa(); - - let mut ab_significand = a_significand * b_significand; - let first_non_zero = ab_significand.msb_index(); + let ab_significand = a_significand * b_significand; // The exponent is correct, but the bits are not in the right place. // Set the right exponent for where the bits are placed, and fix the // exponent below. exp -= sem.get_mantissa_len() as i64; - let precision = a.get_semantics().get_precision(); - if first_non_zero > precision { - let bits = first_non_zero - precision; - - (ab_significand, loss) = - shift_right_with_loss(&ab_significand, bits); - exp += bits as i64; - } - - (Self::new(sem, sign, exp, ab_significand), loss) + let loss = LossFraction::ExactlyZero; + (Self::from_parts(sem, sign, exp, ab_significand), loss) } } @@ -572,7 +584,7 @@ impl Float { Ordering::Greater => LossFraction::MoreThanHalf, }; - let x = Self::new(sem, sign, exp, a_mantissa); + let x = Self::from_parts(sem, sign, exp, a_mantissa); (x, loss) } } @@ -769,3 +781,164 @@ fn test_famous_pentium4_bug() { let result = res.to_string(); assert!(result.starts_with("1.333820449136241002")); } + +impl Float { + // Perform a fused multiply-add of normal numbers, without rounding. + fn fused_mul_add_normals( + a: &Self, + b: &Self, + c: &Self, + ) -> (Self, LossFraction) { + debug_assert_eq!(a.get_semantics(), b.get_semantics()); + let sem = a.get_semantics(); + + // Multiply a and b, without rounding. + let sign = a.get_sign() ^ b.get_sign(); + let mut ab = Self::mul_normals(a, b, sign).0; + + // Shift the product, to allow enough precision for the addition. + // Notice that this can be implemented more efficiently with 3 extra + // bits and sticky bits. + // See 8.5. Floating-Point Fused Multiply-Add, Page 255. 
+ let mut c = c.clone(); + let extra_bits = sem.get_precision() + 1; + ab.shift_significand_left(extra_bits as u64); + c.shift_significand_left(extra_bits as u64); + + // Perform the addition, without rounding. + Self::add_or_sub_normals(&ab, &c, false) + } + + /// Compute a*b + c, with the rounding mode `rm`. + pub fn fused_mul_add_with_rm( + a: &Self, + b: &Self, + c: &Self, + rm: RoundingMode, + ) -> Self { + if a.is_normal() && b.is_normal() && c.is_normal() { + let (mut res, loss) = Self::fused_mul_add_normals(a, b, c); + res.normalize(rm, loss); // Finally, round the result. + res + } else { + // Perform two operations. First, handle non-normal values. + + // NaN anything = NaN + if a.is_nan() || b.is_nan() || c.is_nan() { + return Self::nan(a.get_semantics(), a.get_sign()); + } + // (infinity * 0) + c = NaN + if (a.is_inf() && b.is_zero()) || (a.is_zero() && b.is_inf()) { + return Self::nan(a.get_semantics(), a.get_sign()); + } + // (normal * normal) + infinity = infinity + if a.is_normal() && b.is_normal() && c.is_inf() { + return c.clone(); + } + // (normal * 0) + c = c + if a.is_zero() || b.is_zero() { + return c.clone(); + } + + // Multiply (with rounding), and add (with rounding). + let ab = Self::mul_with_rm(a, b, rm); + Self::add_with_rm(&ab, c, rm) + } + } + + /// Compute a*b + c. + pub fn fma(a: &Self, b: &Self, c: &Self) -> Self { + Self::fused_mul_add_with_rm(a, b, c, c.get_rounding_mode()) + } +} + +#[test] +fn test_fma() { + let v0 = -10.; + let v1 = -1.1; + let v2 = 0.000000000000000000000000000000000000001; + let af = Float::from_f64(v0); + let bf = Float::from_f64(v1); + let cf = Float::from_f64(v2); + + let r = Float::fused_mul_add_with_rm( + &af, + &bf, + &cf, + RoundingMode::NearestTiesToEven, + ); + + assert_eq!(f64::mul_add(v0, v1, v2), r.as_f64()); +} + +#[cfg(feature = "std")] +#[test] +fn test_fma_simple() { + use super::utils; + // Test the multiplication of various irregular values. 
+ let values = utils::get_special_test_values(); + for a in values { + for b in values { + for c in values { + let af = Float::from_f64(a); + let bf = Float::from_f64(b); + let cf = Float::from_f64(c); + + let rf = Float::fused_mul_add_with_rm( + &af, + &bf, + &cf, + RoundingMode::NearestTiesToEven, + ); + + let r0 = rf.as_f64(); + let r1: f64 = a.mul_add(b, c); + assert_eq!(r0.is_finite(), r1.is_finite()); + assert_eq!(r0.is_nan(), r1.is_nan()); + assert_eq!(r0.is_infinite(), r1.is_infinite()); + // Check that the results are bit identical, or are both NaN. + assert!(r1.is_nan() || r1.is_infinite() || r0 == r1); + } + } + } +} + +#[test] +fn test_fma_random_vals() { + use super::utils; + + let mut lfsr = utils::Lfsr::new(); + + fn mul_f32(a: f32, b: f32, c: f32) -> f32 { + let a = Float::from_f32(a); + let b = Float::from_f32(b); + let c = Float::from_f32(c); + let k = Float::fused_mul_add_with_rm( + &a, + &b, + &c, + RoundingMode::NearestTiesToEven, + ); + k.as_f32() + } + + for _ in 0..50000 { + let v0 = lfsr.get64() as u32; + let v1 = lfsr.get64() as u32; + let v2 = lfsr.get64() as u32; + + let f0 = f32::from_bits(v0); + let f1 = f32::from_bits(v1); + let f2 = f32::from_bits(v2); + + let r0 = mul_f32(f0, f1, f2); + let r1 = f32::mul_add(f0, f1, f2); + assert_eq!(r0.is_finite(), r1.is_finite()); + assert_eq!(r0.is_nan(), r1.is_nan()); + assert_eq!(r0.is_infinite(), r1.is_infinite()); + let r0_bits = r0.to_bits(); + let r1_bits = r1.to_bits(); + // Check that the results are bit identical, or are both NaN. 
+ assert!(r1.is_nan() || r0_bits == r1_bits); + } +} diff --git a/src/bigint.rs b/src/bigint.rs index 15e629b..64b0406 100644 --- a/src/bigint.rs +++ b/src/bigint.rs @@ -536,6 +536,11 @@ impl BigInt { use std::println; println!("[{}]", self.as_binary()); } + + #[cfg(not(feature = "std"))] + pub fn dump(&self) { + // No-op in no_std environments + } } impl Default for BigInt { diff --git a/src/cast.rs b/src/cast.rs index e7cee21..6774b39 100644 --- a/src/cast.rs +++ b/src/cast.rs @@ -20,7 +20,8 @@ impl Float { /// Load the big int `val` into the float. Notice that the number may /// overflow, or rounded to the nearest even integer. pub fn from_bigint(sem: Semantics, val: BigInt) -> Self { - let mut a = Self::new(sem, false, sem.get_mantissa_len() as i64, val); + let mut a = + Self::from_parts(sem, false, sem.get_mantissa_len() as i64, val); a.normalize(sem.get_rounding_mode(), LossFraction::ExactlyZero); a } @@ -50,7 +51,7 @@ impl Float { return i64::MAX; } } - let rm = self.get_semantics().get_rounding_mode(); + let rm = self.get_rounding_mode(); let val = self.convert_normal_to_integer(rm); if self.get_sign() { -(val.as_u64() as i64) @@ -85,7 +86,12 @@ impl Float { let mut m = self.get_mantissa(); m.shift_right(trim); m.shift_left(trim); - Self::new(self.get_semantics(), self.get_sign(), self.get_exp(), m) + Self::from_parts( + self.get_semantics(), + self.get_sign(), + self.get_exp(), + m, + ) } /// Returns a number rounded to nearest integer, away from zero. 
@@ -120,7 +126,7 @@ impl Float { let trim = (self.get_mantissa_len() as i64 - exp) as usize; let (mut m, loss) = shift_right_with_loss(&self.get_mantissa(), trim); m.shift_left(trim); - let t = Self::new(sem, self.get_sign(), self.get_exp(), m); + let t = Self::from_parts(sem, self.get_sign(), self.get_exp(), m); if loss.is_lt_half() { t @@ -183,7 +189,7 @@ impl Float { } let mantissa = BigInt::from_u64(mantissa); - Self::new(sem, sign, exp, mantissa) + Self::from_parts(sem, sign, exp, mantissa) } /// Cast to another float using the non-default rounding mode `rm`. @@ -216,7 +222,7 @@ impl Float { } /// Convert from one float format to another. pub fn cast(&self, to: Semantics) -> Float { - self.cast_with_rm(to, self.get_semantics().get_rounding_mode()) + self.cast_with_rm(to, self.get_rounding_mode()) } fn as_native_float(&self) -> u64 { diff --git a/src/float.rs b/src/float.rs index 3e9d58a..f621439 100644 --- a/src/float.rs +++ b/src/float.rs @@ -17,6 +17,20 @@ pub enum RoundingMode { Negative, } +impl RoundingMode { + /// Create a rounding mode from a string, if valid, or return none. + pub fn from_string(s: &str) -> Option { + match s { + "NearestTiesToEven" => Some(RoundingMode::NearestTiesToEven), + "NearestTiesToAway" => Some(RoundingMode::NearestTiesToAway), + "Zero" => Some(RoundingMode::Zero), + "Positive" => Some(RoundingMode::Positive), + "Negative" => Some(RoundingMode::Negative), + _ => None, + } + } +} + /// Controls the semantics of a floating point number with: /// 'precision', that determines the number of bits, 'exponent' that controls /// the dynamic range of the number, and rounding mode that controls how @@ -150,7 +164,12 @@ impl Float { } /// Create a new normal floating point number. 
- pub fn new(sem: Semantics, sign: bool, exp: i64, mantissa: BigInt) -> Self { + pub fn from_parts( + sem: Semantics, + sign: bool, + exp: i64, + mantissa: BigInt, + ) -> Self { if mantissa.is_zero() { return Float::zero(sem, sign); } @@ -267,6 +286,11 @@ impl Float { self.sem } + /// Returns the rounding mode of the number. + pub fn get_rounding_mode(&self) -> RoundingMode { + self.sem.get_rounding_mode() + } + /// Update the sign of the float to `sign`. True means negative. pub fn set_sign(&mut self, sign: bool) { self.sign = sign @@ -337,6 +361,11 @@ impl Float { } } + #[cfg(not(feature = "std"))] + pub fn dump(&self) { + // No-op in no_std environments + } + /// Returns the exponent bias for the number, as a positive number. /// https://en.wikipedia.org/wiki/IEEE_754#Basic_and_interchange_formats pub(crate) fn get_bias(&self) -> i64 { @@ -356,6 +385,8 @@ impl Float { // Table 3.5 — Binary interchange format parameters. use RoundingMode::NearestTiesToEven as nte; +/// Predefined BF16 float with 8 exponent bits, and 7 mantissa bits. +pub const BF16: Semantics = Semantics::new(8, 8, nte); /// Predefined FP16 float with 5 exponent bits, and 10 mantissa bits. pub const FP16: Semantics = Semantics::new(5, 11, nte); /// Predefined FP32 float with 8 exponent bits, and 23 mantissa bits. @@ -416,7 +447,7 @@ impl Float { fn overflow(&mut self, rm: RoundingMode) { let bounds = self.get_exp_bounds(); let inf = Self::inf(self.sem, self.sign); - let max = Self::new( + let max = Self::from_parts( self.sem, self.sign, bounds.1, @@ -489,6 +520,21 @@ impl Float { } } + /// Returns true if the absolute value of the two numbers are the same. 
+ pub(crate) fn same_absolute_value(&self, other: &Self) -> bool { + if self.category != other.category { + return false; + } + match self.category { + Category::Infinity => true, + Category::NaN => true, + Category::Zero => true, + Category::Normal => { + self.exp == other.exp && self.mantissa == other.mantissa + } + } + } + /// Normalize the number by adjusting the exponent to the legal range, shift /// the mantissa to the msb, and round the number if bits are lost. This is /// based on Neil Booth' implementation in APFloat. diff --git a/src/lib.rs b/src/lib.rs index 6e54a15..98ed239 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -135,4 +135,8 @@ pub use self::bigint::BigInt; pub use self::float::Float; pub use self::float::RoundingMode; pub use self::float::Semantics; -pub use self::float::{FP128, FP16, FP256, FP32, FP64}; +pub use self::float::{BF16, FP128, FP16, FP256, FP32, FP64}; + +// Conditionally include a module based on feature flag +#[cfg(feature = "python")] +pub mod py; diff --git a/src/operations/functions.rs b/src/operations/functions.rs index 1264387..8816722 100644 --- a/src/operations/functions.rs +++ b/src/operations/functions.rs @@ -212,7 +212,7 @@ impl Float { return self.clone(); } - let mut r = Self::new( + let mut r = Self::from_parts( self.get_semantics(), self.get_sign(), self.get_exp() + scale, diff --git a/src/py.rs b/src/py.rs new file mode 100644 index 0000000..3eaf630 --- /dev/null +++ b/src/py.rs @@ -0,0 +1,372 @@ +use crate::{BigInt, Float, RoundingMode, Semantics}; +use core::ops::{Add, Div, Mul, Sub}; +use pyo3::prelude::*; +use std::format; +use std::string::String; +use std::string::ToString; + +/// Semantics class defining precision and rounding behavior. +/// +/// This class encapsulates the parameters that define the precision and +/// rounding behavior of floating-point operations. +#[pyclass] +struct PySemantics { + inner: Semantics, +} + +#[pymethods] +impl PySemantics { + /// Create a new semantics object. 
+ /// + /// Args: + /// exp_size: The size of the exponent in bits + /// mantissa_size: The size of the mantissa, including the implicit bit + /// rounding_mode: The rounding mode to use: + /// "NearestTiesToEven", "NearestTiesToAway", + /// "Zero", "Positive", "Negative" + #[new] + fn new(exp_size: i64, mantissa_size: u64, rounding_mode_str: &str) -> Self { + let rm = RoundingMode::from_string(rounding_mode_str); + assert!(rm.is_some(), "Invalid rounding mode"); + let sem = Semantics::new( + exp_size as usize, + mantissa_size as usize, + rm.unwrap(), + ); + PySemantics { inner: sem } + } + /// Returns the length of the exponent in bits. + fn get_exponent_len(&self) -> usize { + self.inner.get_exponent_len() + } + /// Returns the length of the mantissa in bits. + fn get_mantissa_len(&self) -> usize { + self.inner.get_mantissa_len() + } + /// Returns the rounding mode as a string. + fn get_rounding_mode(&self) -> String { + self.inner.get_rounding_mode().as_string().to_string() + } + fn __str__(&self) -> String { + format!("{:?}", self.inner) + } + fn __repr__(&self) -> String { + self.__str__() + } +} + +/// A class representing arbitrary precision floating-point numbers. +/// +/// This class implements IEEE 754-like floating-point arithmetic with +/// configurable precision and rounding modes. +#[pyclass] +struct PyFloat { + inner: Float, +} + +#[pymethods] +impl PyFloat { + /// Create a new floating-point number. + /// + /// Args: + /// sem: The semantics (precision and rounding mode) for this number + /// is_negative: Whether the number is negative (sign bit) + /// exp: The biased exponent value (integer) + /// mantissa: The mantissa value (integer) + #[new] + fn new( + sem: &Bound<'_, PyAny>, + is_negative: bool, + exp: i64, + mantissa: u64, + ) -> Self { + let sem: PyRef = sem.extract().unwrap(); + let mut man = BigInt::from_u64(mantissa); + man.flip_bit(sem.inner.get_mantissa_len()); // Add the implicit bit. 
+ let bias = sem.inner.get_bias(); + PyFloat { + inner: Float::from_parts(sem.inner, is_negative, exp - bias, man), + } + } + + fn __str__(&self) -> String { + self.inner.to_string() + } + fn __repr__(&self) -> String { + self.__str__() + } + /// Returns the mantissa of the float. + fn get_mantissa(&self) -> u64 { + self.inner.get_mantissa().as_u64() + } + /// Returns the exponent of the float. + fn get_exponent(&self) -> i64 { + self.inner.get_exp() + } + /// Returns the category of the float. + fn get_category(&self) -> String { + format!("{:?}", self.inner.get_category()) + } + /// Returns the semantics of the float. + fn get_semantics(&self) -> PySemantics { + PySemantics { + inner: self.inner.get_semantics(), + } + } + /// Get rounding mode of the number. + fn get_rounding_mode(&self) -> String { + self.inner.get_rounding_mode().as_string().to_string() + } + /// Returns true if the Float is negative + fn is_negative(&self) -> bool { + self.inner.is_negative() + } + /// Returns true if the Float is +-inf. + fn is_inf(&self) -> bool { + self.inner.is_inf() + } + /// Returns true if the Float is a +- NaN. + fn is_nan(&self) -> bool { + self.inner.is_nan() + } + /// Returns true if the Float is a +- zero. + fn is_zero(&self) -> bool { + self.inner.is_zero() + } + + /// Returns true if this number is normal (not Zero, Nan, Inf). 
+ fn is_normal(&self) -> bool { + self.inner.is_normal() + } + + fn __add__(&self, other: &PyFloat) -> PyFloat { + self.add(other) + } + + fn __sub__(&self, other: &PyFloat) -> PyFloat { + self.sub(other) + } + + fn __mul__(&self, other: &PyFloat) -> PyFloat { + self.mul(other) + } + fn __truediv__(&self, other: &PyFloat) -> PyFloat { + self.div(other) + } + fn add(&self, other: &PyFloat) -> PyFloat { + let val = self.inner.clone().add(other.inner.clone()); + PyFloat { inner: val } + } + fn mul(&self, other: &PyFloat) -> PyFloat { + let val = self.inner.clone().mul(other.inner.clone()); + PyFloat { inner: val } + } + fn sub(&self, other: &PyFloat) -> PyFloat { + let val = self.inner.clone().sub(other.inner.clone()); + PyFloat { inner: val } + } + fn div(&self, other: &PyFloat) -> PyFloat { + let val = self.inner.clone().div(other.inner.clone()); + PyFloat { inner: val } + } + /// Returns the number raised to the power of `exp` which is an integer. + fn powi(&self, exp: u64) -> PyFloat { + PyFloat { + inner: self.inner.powi(exp), + } + } + /// Returns the number raised to the power of `exp` which is a float. + fn pow(&self, exp: &PyFloat) -> PyFloat { + PyFloat { + inner: self.inner.pow(&exp.inner), + } + } + /// Returns the exponential of the number. + fn exp(&self) -> PyFloat { + PyFloat { + inner: self.inner.exp(), + } + } + /// Returns the natural logarithm of the number. + fn log(&self) -> PyFloat { + PyFloat { + inner: self.inner.log(), + } + } + /// Returns the sigmoid of the number. + fn sigmoid(&self) -> PyFloat { + PyFloat { + inner: self.inner.sigmoid(), + } + } + /// Returns the absolute value of the number. + fn abs(&self) -> PyFloat { + PyFloat { + inner: self.inner.abs(), + } + } + /// Returns the maximum of two numbers (as defined by IEEE 754). + fn max(&self, other: &PyFloat) -> PyFloat { + PyFloat { + inner: self.inner.max(&other.inner), + } + } + /// Returns the minimum of two numbers (as defined by IEEE 754). 
+ fn min(&self, other: &PyFloat) -> PyFloat { + PyFloat { + inner: self.inner.min(&other.inner), + } + } + /// Returns the remainder of the division of two numbers. + fn rem(&self, other: &PyFloat) -> PyFloat { + PyFloat { + inner: self.inner.rem(&other.inner), + } + } + /// Cast the number to another semantics. + fn cast(&self, sem: &Bound<'_, PyAny>) -> PyFloat { + let sem: PyRef = sem.extract().unwrap(); + PyFloat { + inner: self.inner.cast(sem.inner), + } + } + /// Cast the number to another semantics with a specific rounding mode. + fn cast_with_rm(&self, sem: &Bound<'_, PyAny>, rm: &str) -> PyFloat { + let sem: PyRef = sem.extract().unwrap(); + let rm = RoundingMode::from_string(rm); + assert!(rm.is_some(), "Invalid rounding mode"); + PyFloat { + inner: self.inner.cast_with_rm(sem.inner, rm.unwrap()), + } + } + /// Returns the sine of the number. + fn sin(&self) -> PyFloat { + PyFloat { + inner: self.inner.sin(), + } + } + /// Returns the cosine of the number. + fn cos(&self) -> PyFloat { + PyFloat { + inner: self.inner.cos(), + } + } + /// Returns the tangent of the number. + fn tan(&self) -> PyFloat { + PyFloat { + inner: self.inner.tan(), + } + } + /// convert to f64. + fn to_float64(&self) -> f64 { + self.inner.as_f64() + } + /// Convert the number to a Continued Fraction of two integers. + /// Take 'n' iterations. + fn as_fraction(&self, n: usize) -> (u64, u64) { + let (a, b) = self.inner.as_fraction(n); + (a.as_u64(), b.as_u64()) + } + /// Prints the number using the internal representation. + fn dump(&self) { + self.inner.dump(); + } +} // impl PyFloat + +/// Returns the mathematical constant pi with the given semantics. +/// +/// Args: +/// sem: The semantics to use for representing pi +#[pyfunction] +fn pi(sem: &Bound<'_, PyAny>) -> PyResult { + let sem: PyRef = sem.extract()?; + Ok(PyFloat { + inner: Float::pi(sem.inner), + }) +} + +/// Returns the fused multiply-add operation of three numbers. 
+/// +/// Computes (a * b) + c. +#[pyfunction] +fn fma(a: &PyFloat, b: &PyFloat, c: &PyFloat) -> PyResult { + Ok(PyFloat { + inner: Float::fma(&a.inner, &b.inner, &c.inner), + }) +} + +/// Returns the mathematical constant e (Euler's number) with the given semantics. +/// +/// Args: +/// sem: The semantics to use for representing e +#[pyfunction] +fn e(sem: &Bound<'_, PyAny>) -> PyResult { + let sem: PyRef = sem.extract()?; + Ok(PyFloat { + inner: Float::e(sem.inner), + }) +} + +/// Returns the natural logarithm of 2 (ln(2)) with the given semantics. +/// +/// Args: +/// sem: The semantics to use for representing ln(2) +#[pyfunction] +fn ln2(sem: &Bound<'_, PyAny>) -> PyResult { + let sem: PyRef = sem.extract()?; + Ok(PyFloat { + inner: Float::ln2(sem.inner), + }) +} + +/// Returns the number zero with the given semantics. +/// +/// Args: +/// sem: The semantics to use for representing zero +#[pyfunction] +fn zero(sem: &Bound<'_, PyAny>) -> PyResult { + let sem: PyRef = sem.extract()?; + Ok(PyFloat { + inner: Float::zero(sem.inner, false), + }) +} + +/// Returns a new float with the integer value 'val' with the given semantics. +/// +/// Args: +/// sem: The semantics to use +/// val: The integer value +#[pyfunction] +fn from_i64(sem: &Bound<'_, PyAny>, val: i64) -> PyResult { + let sem: PyRef = sem.extract()?; + Ok(PyFloat { + inner: Float::from_i64(sem.inner, val), + }) +} + +/// Returns a new float with the fp64 value 'val'. 
+/// +/// Args: +/// val: The f64 value +#[pyfunction] +fn from_fp64(val: f64) -> PyResult { + Ok(PyFloat { + inner: Float::from_f64(val), + }) +} + +#[pymodule] +fn _arpfloat(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + + // Add the functions to the module + m.add_function(wrap_pyfunction!(pi, m)?)?; + m.add_function(wrap_pyfunction!(e, m)?)?; + m.add_function(wrap_pyfunction!(ln2, m)?)?; + m.add_function(wrap_pyfunction!(zero, m)?)?; + m.add_function(wrap_pyfunction!(fma, m)?)?; + m.add_function(wrap_pyfunction!(from_i64, m)?)?; + m.add_function(wrap_pyfunction!(from_fp64, m)?)?; + Ok(()) +} diff --git a/src/string.rs b/src/string.rs index 5c1843d..e63ff80 100644 --- a/src/string.rs +++ b/src/string.rs @@ -119,6 +119,14 @@ impl Float { /// representation of numbers. For all of that that check out the paper: /// "How to Print Floating-Point Numbers Accurately" by Steele and White. fn convert_to_string(&self) -> String { + // In order to print decimal digits we need a minimum number of mantissa + // bits for the conversion. Small floats (such as BF16) don't have + // enough bits, so we cast to a larger number. + if self.get_semantics().get_mantissa_len() < 16 { + use crate::FP32; + return self.cast(FP32).to_string(); + } + let result = if self.get_sign() { "-" } else { "" }; let mut result: String = result.to_string(); @@ -299,7 +307,7 @@ mod from { fn parse_with_exp( value: &str, ) -> Result<((BigInt, usize), Option), ParseError> { - let idx = value.find(|c| c == 'e' || c == 'E'); + let idx = value.find(['e', 'E']); // Split the number to the digits and the exponent. 
let (num_raw, exp) = if let Some(idx) = idx { let (l, r) = value.split_at(idx); @@ -404,29 +412,37 @@ mod from { fn test_convert_to_string() { use crate::FP16; use crate::FP64; + use core::f64; use std::format; fn to_str_w_fp16(val: f64) -> String { format!("{}", Float::from_f64(val).cast(FP16)) } + fn to_str_w_bf16(val: f64) -> String { + use crate::BF16; + format!("{}", Float::from_f64(val).cast(BF16)) + } + fn to_str_w_fp64(val: f64) -> String { format!("{}", Float::from_f64(val).cast(FP64)) } assert_eq!("-0.0", to_str_w_fp16(-0.)); - assert_eq!(".3", to_str_w_fp16(0.3)); + assert_eq!(".30004882", to_str_w_fp16(0.3)); assert_eq!("4.5", to_str_w_fp16(4.5)); assert_eq!("256.", to_str_w_fp16(256.)); assert_eq!("Inf", to_str_w_fp16(65534.)); assert_eq!("-Inf", to_str_w_fp16(-65534.)); - assert_eq!(".0999", to_str_w_fp16(0.1)); + assert_eq!(".09997558", to_str_w_fp16(0.1)); assert_eq!(".1", to_str_w_fp64(0.1)); assert_eq!(".29999999999999998", to_str_w_fp64(0.3)); assert_eq!("2251799813685248.", to_str_w_fp64((1u64 << 51) as f64)); assert_eq!("1995.1994999999999", to_str_w_fp64(1995.1995)); + assert_eq!("3.140625", to_str_w_bf16(f64::consts::PI)); } +#[cfg(feature = "std")] #[test] fn test_from_string() { assert_eq!("-3.", Float::try_from("-3.0").unwrap().to_string());