From ff424de10dd9a9b63dbdd74b2a87a737c39945c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?ZHU=20Yuhao=20=E6=9C=B1=E5=AE=87=E6=B5=A9?= Date: Tue, 4 Mar 2025 12:23:03 +0100 Subject: [PATCH 1/3] Implement Roundable trait --- README.md | 17 +-- decimojo/__init__.mojo | 2 +- decimojo/decimal.mojo | 90 ++++++------ decimojo/mathematics.mojo | 205 ++++++++++++++++++++++++++++ examples/examples.mojo | 22 +-- tests/test_decimal_arithmetics.mojo | 13 +- tests/test_decimal_rounding.mojo | 53 +++---- 7 files changed, 313 insertions(+), 89 deletions(-) diff --git a/README.md b/README.md index 806bc9b3..73cd27c3 100644 --- a/README.md +++ b/README.md @@ -72,18 +72,19 @@ var reciprocal = Decimal("2") ** (-1) # 0.5 ```mojo from decimojo import Decimal from decimojo.rounding_mode import RoundingMode +from decimojo.mathematics import round var num = Decimal("123.456789") # Round to various decimal places with different modes -var default_round = num.round(2) # 123.46 (HALF_EVEN) -var down = num.round(2, RoundingMode.DOWN()) # 123.45 (truncate) -var up = num.round(2, RoundingMode.UP()) # 123.46 (away from zero) -var half_up = num.round(2, RoundingMode.HALF_UP()) # 123.46 (≥0.5 rounds up) +var default_round = round(num, 2) # 123.46 (HALF_EVEN) +var down = round(num, 2, RoundingMode.DOWN()) # 123.45 (truncate) +var up = round(num, 2, RoundingMode.UP()) # 123.46 (away from zero) +var half_up = round(num, 2, RoundingMode.HALF_UP()) # 123.46 (≥0.5 rounds up) # Rounding special cases -var half_value = Decimal("123.5").round(0) # 124 (banker's rounding) -var half_odd = Decimal("124.5").round(0) # 124 (banker's rounding to even) +var half_value = round(Decimal("123.5"), 0) # 124 (banker's rounding) +var half_odd = round(Decimal("124.5"), 0) # 124 (banker's rounding to even) ``` ### 4. Working with Scale and Precision @@ -93,10 +94,10 @@ var d = Decimal("123.45") # Scale is 2 print(d.scale()) # Prints: 2 # Changing scale through rounding -var more_precise = d.round(4) # 123.4500 +var more_precise = round(d, 4) # 123.4500 print(more_precise.scale()) # Prints: 4 -var less_precise = d.round(1) # 123.5 (rounds up) +var less_precise = round(d, 1) # 123.5 (rounds up) print(less_precise.scale()) # Prints: 1 # Scale after operations diff --git a/decimojo/__init__.mojo b/decimojo/__init__.mojo index e605a0b2..3b45c928 100644 --- a/decimojo/__init__.mojo +++ b/decimojo/__init__.mojo @@ -6,4 +6,4 @@ from .decimal import Decimal -from .mathematics import power +from .mathematics import power, round diff --git a/decimojo/decimal.mojo b/decimojo/decimal.mojo index 004d7666..92ef84c0 100644 --- a/decimojo/decimal.mojo +++ b/decimojo/decimal.mojo @@ -26,7 +26,7 @@ import math.math as mt from .rounding_mode import RoundingMode -struct Decimal(Writable): +struct Decimal(Roundable, Writable): """ Correctly-rounded fixed-precision number. @@ -1198,6 +1198,57 @@ struct Decimal(Writable): return decimal.power(self, exponent) + # ===------------------------------------------------------------------=== # + # Other dunders that implements tratis + # round + # ===------------------------------------------------------------------=== # + + fn __round__( + self, ndigits: Int = 0, mode: RoundingMode = RoundingMode.HALF_EVEN() + ) raises -> Self: + """ + Rounds this Decimal to the specified number of decimal places. + + Args: + ndigits: Number of decimal places to round to. + If 0 (default), rounds to the nearest integer. + If positive, rounds to the given number of decimal places. + If negative, rounds to the left of the decimal point. + mode: The rounding mode to use. Defaults to RoundingMode.HALF_EVEN. + + Returns: + A new Decimal rounded to the specified precision + + Raises: + Error: If the operation would result in overflow. + + Examples: + ``` + round(Decimal("3.14159"), 2) # Returns 3.14 + round("3.14159") # Returns 3 + round("1234.5", -2) # Returns 1200 + ``` + . + """ + + return decimojo.round(self, ndigits, mode) + + fn __round__(self, ndigits: Int = 0) -> Self: + """ + **OVERLOAD** + Rounds this Decimal to the specified number of decimal places. + """ + + return decimojo.round(self, ndigits, RoundingMode.HALF_EVEN()) + + fn __round__(self) -> Self: + """ + **OVERLOAD** + Rounds this Decimal to the specified number of decimal places. + """ + + return decimojo.round(self, 0, RoundingMode.HALF_EVEN()) + # ===------------------------------------------------------------------=== # # Other methods # ===------------------------------------------------------------------=== # @@ -1312,43 +1363,6 @@ struct Decimal(Writable): """Returns the scale (number of decimal places) of this Decimal.""" return Int((self.flags & Self.SCALE_MASK) >> Self.SCALE_SHIFT) - fn round( - self, - decimal_places: Int, - rounding_mode: RoundingMode = RoundingMode.HALF_EVEN(), - ) -> Decimal: - """ - Rounds the Decimal to the specified number of decimal places. - - Args: - decimal_places: Number of decimal places to round to. - rounding_mode: Rounding mode to use (defaults to HALF_EVEN/banker's rounding). - - Returns: - A new Decimal rounded to the specified number of decimal places - - Examples: - ``` - var d = Decimal("123.456789") - var rounded = d.round(2) # Returns 123.46 (using banker's rounding) - var down = d.round(3, RoundingMode.DOWN()) # Returns 123.456 (truncated) - var up = d.round(1, RoundingMode.UP()) # Returns 123.5 (rounded up) - ``` - . - """ - var current_scale = self.scale() - - # If already at the desired scale, return a copy - if current_scale == decimal_places: - return self - - # If we need more decimal places, scale up - if decimal_places > current_scale: - return self._scale_up(decimal_places - current_scale) - - # Otherwise, scale down with the specified rounding mode - return self._scale_down(current_scale - decimal_places, rounding_mode) - # ===------------------------------------------------------------------=== # # Internal methods # ===------------------------------------------------------------------=== # diff --git a/decimojo/mathematics.mojo b/decimojo/mathematics.mojo index f28cdcdc..296ae14b 100644 --- a/decimojo/mathematics.mojo +++ b/decimojo/mathematics.mojo @@ -8,12 +8,40 @@ # which supports correctly-rounded, fixed-point arithmetic. # # ===----------------------------------------------------------------------=== # +# +# List of functions in this module: +# +# power(base: Decimal, exponent: Decimal): Raises base to the power of exponent (integer exponents only) +# power(base: Decimal, exponent: Int): Convenience method for integer exponents +# sqrt(x: Decimal): Computes the square root of x using Newton-Raphson method +# root(x: Decimal, n: Int): Computes the nth root of x using Newton's method +# +# TODO Additional functions planned for future implementation: +# +# exp(x: Decimal): Computes e raised to the power of x +# ln(x: Decimal): Computes the natural logarithm of x +# log10(x: Decimal): Computes the base-10 logarithm of x +# sin(x: Decimal): Computes the sine of x (in radians) +# cos(x: Decimal): Computes the cosine of x (in radians) +# tan(x: Decimal): Computes the tangent of x (in radians) +# abs(x: Decimal): Returns the absolute value of x +# round(x: Decimal, places: Int, mode: RoundingMode): Rounds x to specified decimal places +# floor(x: Decimal): Returns the largest integer <= x +# ceil(x: Decimal): Returns the smallest integer >= x +# gcd(a: Decimal, b: Decimal): Returns greatest common divisor of a and b +# lcm(a: Decimal, b: Decimal): Returns least common multiple of a and b +# ===----------------------------------------------------------------------=== # """ Implements functions for mathematical operations on Decimal objects. """ from decimojo.decimal import Decimal +from decimojo.rounding_mode import RoundingMode + +# ===----------------------------------------------------------------------=== # +# Arithmetic operations functions +# ===----------------------------------------------------------------------=== # fn power(base: Decimal, exponent: Decimal) raises -> Decimal: @@ -99,3 +127,180 @@ fn power(base: Decimal, exponent: Int) raises -> Decimal: A new Decimal containing the result. """ return power(base, Decimal(exponent)) + + +# fn sqrt(x: Decimal) raises -> Decimal: +# """ +# Computes the square root of a Decimal value using Newton-Raphson method. + +# Args: +# x: The Decimal value to compute the square root of. + +# Returns: +# A new Decimal containing the square root of x. + +# Raises: +# Error: If x is negative. +# """ +# # Special cases +# if x.is_negative(): +# raise Error("Cannot compute square root of negative number") + +# if x.is_zero(): +# return Decimal.ZERO() + +# if x == Decimal.ONE(): +# return Decimal.ONE() + +# # Working precision - we'll compute with extra digits and round at the end +# var working_precision = x.scale() * 2 +# working_precision = max(working_precision, UInt32(10)) # At least 10 digits + +# # Initial guess - a good guess helps converge faster +# # For numbers near 1, use the number itself +# # For very small or large numbers, scale appropriately +# var guess: Decimal +# var exponent = len(x.coefficient()) - x.scale() + +# if exponent >= 0 and exponent <= 3: +# # For numbers between 0.1 and 1000, start with x/2 + 0.5 +# guess = (x / Decimal("2")) + Decimal("0.5") +# else: +# # For larger/smaller numbers, make a smarter guess +# # This scales based on the magnitude of the number +# var shift = exponent / 2 +# if exponent % 2 != 0: +# # For odd exponents, adjust +# shift = (exponent + 1) / 2 + +# # Use an approximation based on the exponent +# if exponent > 0: +# guess = Decimal("10") ** shift +# else: +# guess = Decimal("0.1") ** (-shift) + +# # Newton-Raphson iterations +# # x_n+1 = (x_n + S/x_n) / 2 +# var prev_guess = Decimal.ZERO() +# var iteration_count = 0 +# var max_iterations = 100 # Prevent infinite loops + +# while guess != prev_guess and iteration_count < max_iterations: +# prev_guess = guess +# guess = (guess + (x / guess)) / Decimal("2") +# iteration_count += 1 + +# # Round to appropriate precision - typically half the working precision +# var result_precision = x.scale() +# if result_precision % 2 == 1: +# # For odd scales, add 1 to ensure proper rounding +# result_precision += 1 + +# # The result scale should be approximately half the input scale +# result_precision = result_precision / 2 + +# # Format to the appropriate number of decimal places +# var result_str = String(guess) +# var rounded_result = Decimal(result_str) + +# return rounded_result + + +# # Additional helper function to compute nth root +# fn root(x: Decimal, n: Int) raises -> Decimal: +# """ +# Computes the nth root of x using Newton's method. + +# Args: +# x: The base value +# n: The root to compute (must be positive) + +# Returns: +# The nth root of x as a Decimal + +# Raises: +# Error: If x is negative and n is even, or if n is not positive +# """ +# # Special cases +# if n <= 0: +# raise Error("Root index must be positive") + +# if n == 1: +# return x + +# if n == 2: +# return sqrt(x) + +# if x.is_zero(): +# return Decimal.ZERO() + +# if x == Decimal.ONE(): +# return Decimal.ONE() + +# if x.is_negative() and n % 2 == 0: +# raise Error("Cannot compute even root of negative number") + +# # The sign of the result +# var result_is_negative = x.is_negative() and n % 2 == 1 +# var abs_x = x.abs() + +# # Newton's method for nth root: +# # x_k+1 = ((n-1)*x_k + num/(x_k^(n-1)))/n + +# # Initial guess +# var guess = Decimal.ONE() +# if abs_x > Decimal.ONE(): +# guess = abs_x / Decimal(n) # Simple initial approximation + +# var prev_guess = Decimal.ZERO() +# var n_minus_1 = Decimal(n - 1) +# var n_decimal = Decimal(n) +# var max_iterations = 100 +# var iteration_count = 0 + +# while guess != prev_guess and iteration_count < max_iterations: +# prev_guess = guess + +# # Compute x_k^(n-1) +# var guess_pow_n_minus_1 = power(guess, n - 1) + +# # Next approximation +# guess = (n_minus_1 * guess + abs_x / guess_pow_n_minus_1) / n_decimal +# iteration_count += 1 + +# return -guess if result_is_negative else guess + + +# ===------------------------------------------------------------------------===# +# Rounding +# ===------------------------------------------------------------------------===# + + +fn round( + number: Decimal, + decimal_places: Int, + rounding_mode: RoundingMode = RoundingMode.HALF_EVEN(), +) -> Decimal: + """ + Rounds the Decimal to the specified number of decimal places. + + Args: + number: The Decimal to round. + decimal_places: Number of decimal places to round to. + rounding_mode: Rounding mode to use (defaults to HALF_EVEN/banker's rounding). + + Returns: + A new Decimal rounded to the specified number of decimal places. + """ + var current_scale = number.scale() + + # If already at the desired scale, return a copy + if current_scale == decimal_places: + return number + + # If we need more decimal places, scale up + if decimal_places > current_scale: + return number._scale_up(decimal_places - current_scale) + + # Otherwise, scale down with the specified rounding mode + return number._scale_down(current_scale - decimal_places, rounding_mode) diff --git a/examples/examples.mojo b/examples/examples.mojo index e10bde41..9244c8d9 100644 --- a/examples/examples.mojo +++ b/examples/examples.mojo @@ -1,7 +1,7 @@ """ Examples demonstrating the usage of Decimojo's Decimal type. """ -from decimojo import Decimal +from decimojo import Decimal, round from decimojo.rounding_mode import RoundingMode @@ -181,18 +181,18 @@ fn rounding_and_precision() raises: print("Original value:", d) # Default rounding (HALF_EVEN/banker's rounding) - var r1 = d.round(2) - var r2 = d.round(4) - var r3 = d.round(0) + var r1 = round(d, 2) + var r2 = round(d, 4) + var r3 = round(d, 0) print("Rounded to 2 places (default):", r1) print("Rounded to 4 places (default):", r2) print("Rounded to 0 places (default):", r3) # Using different rounding modes - var down = d.round(2, RoundingMode.DOWN()) - var up = d.round(2, RoundingMode.UP()) - var half_up = d.round(2, RoundingMode.HALF_UP()) - var half_even = d.round(2, RoundingMode.HALF_EVEN()) + var down = round(d, 2, RoundingMode.DOWN()) + var up = round(d, 2, RoundingMode.UP()) + var half_up = round(d, 2, RoundingMode.HALF_UP()) + var half_even = round(d, 2, RoundingMode.HALF_EVEN()) print("\nRounding 123.456789 to 2 places with different modes:") print("DOWN():", down) print("UP():", up) @@ -201,9 +201,9 @@ fn rounding_and_precision() raises: # Rounding special cases var half_val = Decimal("123.5") - var half_rounded = half_val.round(0, RoundingMode.HALF_EVEN()) + var half_rounded = round(half_val, 0, RoundingMode.HALF_EVEN()) var val = Decimal("124.5") - var even_rounded = val.round(0, RoundingMode.HALF_EVEN()) + var even_rounded = round(val, 0, RoundingMode.HALF_EVEN()) print("\nRounding special cases (banker's rounding):") print("123.5 rounded to 0 places:", half_rounded) print("124.5 rounded to 0 places:", even_rounded) @@ -213,7 +213,7 @@ fn rounding_and_precision() raises: print("Value:", d_scale, "- Scale:", d_scale.scale()) # Round to different scales - var rounded = d_scale.round(3) + var rounded = round(d_scale, 3) print("Rounded to 3 places:", rounded, "- New scale:", rounded.scale()) # Scale after arithmetic operations diff --git a/tests/test_decimal_arithmetics.mojo b/tests/test_decimal_arithmetics.mojo index e13bc992..0148b982 100644 --- a/tests/test_decimal_arithmetics.mojo +++ b/tests/test_decimal_arithmetics.mojo @@ -2,6 +2,7 @@ Test Decimal arithmetic operations including addition, subtraction, and negation. """ from decimojo import Decimal +from decimojo.mathematics import round import testing @@ -322,35 +323,35 @@ fn test_subtraction() raises: var value1 = Decimal("0") testing.assert_equal( String(value1 - value1), - String(Decimal("0").round(value1.scale())), + String(round(Decimal("0"), value1.scale())), "Self subtraction should yield zero (0)", ) var value2 = Decimal("123.45") testing.assert_equal( String(value2 - value2), - String(Decimal("0").round(value2.scale())), + String(round(Decimal("0"), value2.scale())), "Self subtraction should yield zero (123.45)", ) var value3 = Decimal("-987.654") testing.assert_equal( String(value3 - value3), - String(Decimal("0").round(value3.scale())), + String(round(Decimal("0"), value3.scale())), "Self subtraction should yield zero (-987.654)", ) var value4 = Decimal("0.0001") testing.assert_equal( String(value4 - value4), - String(Decimal("0").round(value4.scale())), + String(round(Decimal("0"), value4.scale())), "Self subtraction should yield zero (0.0001)", ) var value5 = Decimal("-99999.99999") testing.assert_equal( String(value5 - value5), - String(Decimal("0").round(value5.scale())), + String(round(Decimal("0"), value5.scale())), "Self subtraction should yield zero (-99999.99999)", ) @@ -836,12 +837,10 @@ fn test_power_precision() raises: # These tests assume we have overloaded the ** operator # and we have a way to control precision similar to pow() - try: # Test with precision control var a1 = Decimal("1.5") var result1 = a1**2 - # Test equality including precision testing.assert_equal( String(result1), "2.25", "1.5^2 should be exactly 2.25" diff --git a/tests/test_decimal_rounding.mojo b/tests/test_decimal_rounding.mojo index 3b06667b..2b5fba04 100644 --- a/tests/test_decimal_rounding.mojo +++ b/tests/test_decimal_rounding.mojo @@ -3,6 +3,7 @@ Test Decimal rounding methods with different rounding modes and precision levels """ from decimojo import Decimal from decimojo.rounding_mode import RoundingMode +from decimojo.mathematics import round import testing @@ -11,26 +12,26 @@ fn test_basic_rounding() raises: # Test case 1: Round to 2 decimal places (banker's rounding) var d1 = Decimal("123.456") - var result1 = d1.round(2) + var result1 = round(d1, 2) testing.assert_equal( String(result1), "123.46", "Basic rounding to 2 decimal places" ) # Test case 2: Round to 0 decimal places var d2 = Decimal("123.456") - var result2 = d2.round(0) + var result2 = round(d2, 0) testing.assert_equal(String(result2), "123", "Rounding to 0 decimal places") # Test case 3: Round to more decimal places than original (should pad with zeros) var d3 = Decimal("123.45") - var result3 = d3.round(4) + var result3 = round(d3, 4) testing.assert_equal( String(result3), "123.4500", "Rounding to more decimal places" ) # Test case 4: Round number that's already at target precision var d4 = Decimal("123.45") - var result4 = d4.round(2) + var result4 = round(d4, 2) testing.assert_equal( String(result4), "123.45", "Rounding to same precision" ) @@ -44,19 +45,19 @@ fn test_different_rounding_modes() raises: var test_value = Decimal("123.456") # Test case 1: Round down (truncate) - var result1 = test_value.round(2, RoundingMode.DOWN()) + var result1 = round(test_value, 2, RoundingMode.DOWN()) testing.assert_equal(String(result1), "123.45", "Rounding down") # Test case 2: Round up (away from zero) - var result2 = test_value.round(2, RoundingMode.UP()) + var result2 = round(test_value, 2, RoundingMode.UP()) testing.assert_equal(String(result2), "123.46", "Rounding up") # Test case 3: Round half up - var result3 = test_value.round(2, RoundingMode.HALF_UP()) + var result3 = round(test_value, 2, RoundingMode.HALF_UP()) testing.assert_equal(String(result3), "123.46", "Rounding half up") # Test case 4: Round half even (banker's rounding) - var result4 = test_value.round(2, RoundingMode.HALF_EVEN()) + var result4 = round(test_value, 2, RoundingMode.HALF_EVEN()) testing.assert_equal(String(result4), "123.46", "Rounding half even") print("Rounding mode tests passed!") @@ -69,23 +70,25 @@ fn test_edge_cases() raises: var half_value = Decimal("123.5") testing.assert_equal( - String(half_value.round(0, RoundingMode.DOWN())), + String(round(half_value, 0, RoundingMode.DOWN())), "123", "Rounding 0.5 down", ) testing.assert_equal( - String(half_value.round(0, RoundingMode.UP())), "124", "Rounding 0.5 up" + String(round(half_value, 0, RoundingMode.UP())), + "124", + "Rounding 0.5 up", ) testing.assert_equal( - String(half_value.round(0, RoundingMode.HALF_UP())), + String(round(half_value, 0, RoundingMode.HALF_UP())), "124", "Rounding 0.5 half up", ) testing.assert_equal( - String(half_value.round(0, RoundingMode.HALF_EVEN())), + String(round(half_value, 0, RoundingMode.HALF_EVEN())), "124", "Rounding 0.5 half even (even is 124)", ) @@ -93,7 +96,7 @@ fn test_edge_cases() raises: # Another test with half to even value var half_even_value = Decimal("124.5") testing.assert_equal( - String(half_even_value.round(0, RoundingMode.HALF_EVEN())), + String(round(half_even_value, 0, RoundingMode.HALF_EVEN())), "124", "Rounding 124.5 half even (even is 124)", ) @@ -103,7 +106,7 @@ fn test_edge_cases() raises: "0." + "0" * 27 + "1" ) # 0.0000...01 (1 at 28th place) testing.assert_equal( - String(small_value.round(27)), + String(round(small_value, 27)), "0." + "0" * 27, "Rounding tiny number to 27 places", ) @@ -112,19 +115,19 @@ fn test_edge_cases() raises: var negative_value = Decimal("-123.456") testing.assert_equal( - String(negative_value.round(2, RoundingMode.DOWN())), + String(round(negative_value, 2, RoundingMode.DOWN())), "-123.45", "Rounding negative number down", ) testing.assert_equal( - String(negative_value.round(2, RoundingMode.UP())), + String(round(negative_value, 2, RoundingMode.UP())), "-123.46", "Rounding negative number up", ) testing.assert_equal( - String(negative_value.round(2, RoundingMode.HALF_EVEN())), + String(round(negative_value, 2, RoundingMode.HALF_EVEN())), "-123.46", "Rounding negative number half even", ) @@ -133,13 +136,15 @@ fn test_edge_cases() raises: var carry_value = Decimal("9.999") testing.assert_equal( - String(carry_value.round(2)), "10.00", "Rounding with carry propagation" + String(round(carry_value, 2)), + "10.00", + "Rounding with carry propagation", ) # Test case 5: Rounding to maximum precision var max_precision = Decimal("0." + "1" * 28) # 0.1111...1 (28 digits) testing.assert_equal( - String(max_precision.round(14)), + String(round(max_precision, 14)), "0.11111111111111", "Rounding from maximum precision", ) @@ -159,16 +164,16 @@ fn test_rounding_consistency() raises: # Both should round the same way testing.assert_equal( - String(d1.round(1))[:3], - String(d2.round(1))[:3], + String(round(d1, 1))[:3], + String(round(d2, 1))[:3], "Rounding consistency across different constructors", ) # Test that repeated rounding is consistent var start = Decimal("123.456789") - var round_once = start.round(4) # 123.4568 - var round_twice = round_once.round(2) # 123.46 - var direct = start.round(2) # 123.46 + var round_once = round(start, 4) # 123.4568 + var round_twice = round(round_once, 2) # 123.46 + var direct = round(start, 2) # 123.46 testing.assert_equal( String(round_twice), From b0c67b024541b4c4b706745e797e2e2ef7439351 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?ZHU=20Yuhao=20=E6=9C=B1=E5=AE=87=E6=B5=A9?= Date: Tue, 4 Mar 2025 12:59:53 +0100 Subject: [PATCH 2/3] Add logic dunders --- decimojo/__init__.mojo | 9 +- decimojo/decimal.mojo | 79 +++++- decimojo/logic.mojo | 220 ++++++++++++++ mojoproject.toml | 14 +- tests/test_decimal_logic.mojo | 521 ++++++++++++++++++++++++++++++++++ 5 files changed, 835 insertions(+), 8 deletions(-) create mode 100644 decimojo/logic.mojo create mode 100644 tests/test_decimal_logic.mojo diff --git a/decimojo/__init__.mojo b/decimojo/__init__.mojo index 3b45c928..e543b9cc 100644 --- a/decimojo/__init__.mojo +++ b/decimojo/__init__.mojo @@ -4,6 +4,11 @@ # https://github.com/forFudan/decimal/blob/main/LICENSE # ===----------------------------------------------------------------------=== # -from .decimal import Decimal +""" +DeciMojo - Correctly-rounded, fixed-point Decimal library for Mojo. +""" -from .mathematics import power, round +from .decimal import Decimal +from .rounding_mode import RoundingMode +from .mathematics import round, power +from .logic import greater, greater_equal, less, less_equal, equal, not_equal diff --git a/decimojo/decimal.mojo b/decimojo/decimal.mojo index 92ef84c0..d9aae82b 100644 --- a/decimojo/decimal.mojo +++ b/decimojo/decimal.mojo @@ -614,7 +614,7 @@ struct Decimal(Roundable, Writable): return result # ===------------------------------------------------------------------=== # - # Basic binary operation dunders + # Basic binary arithmetic operation dunders # add, sub, mul, truediv, pow # ===------------------------------------------------------------------=== # fn __add__(self, other: Decimal) raises -> Self: @@ -1198,6 +1198,83 @@ struct Decimal(Roundable, Writable): return decimal.power(self, exponent) + # ===------------------------------------------------------------------=== # + # Basic binary logic operation dunders + # __gt__, __ge__, __lt__, __le__, __eq__, __ne__ + # ===------------------------------------------------------------------=== # + + fn __gt__(self, other: Decimal) -> Bool: + """ + Greater than comparison operator. + + Args: + other: The Decimal to compare with. + + Returns: + True if self is greater than other, False otherwise. + """ + return decimojo.greater(self, other) + + fn __ge__(self, other: Decimal) -> Bool: + """ + Greater than or equal comparison operator. + + Args: + other: The Decimal to compare with. + + Returns: + True if self is greater than or equal to other, False otherwise. + """ + return decimojo.greater_equal(self, other) + + fn __lt__(self, other: Decimal) -> Bool: + """ + Less than comparison operator. + + Args: + other: The Decimal to compare with. + + Returns: + True if self is less than other, False otherwise. + """ + return decimojo.less(self, other) + + fn __le__(self, other: Decimal) -> Bool: + """ + Less than or equal comparison operator. + + Args: + other: The Decimal to compare with. + + Returns: + True if self is less than or equal to other, False otherwise. + """ + return decimojo.less_equal(self, other) + + fn __eq__(self, other: Decimal) -> Bool: + """ + Equality comparison operator. + + Args: + other: The Decimal to compare with. + + Returns: + True if self is equal to other, False otherwise. + """ + return decimojo.equal(self, other) + + fn __ne__(self, other: Decimal) -> Bool: + """ + Inequality comparison operator. + + Args: + other: The Decimal to compare with. + + Returns: + True if self is not equal to other, False otherwise. + """ + return decimojo.not_equal(self, other) + # ===------------------------------------------------------------------=== # # Other dunders that implements tratis # round diff --git a/decimojo/logic.mojo b/decimojo/logic.mojo new file mode 100644 index 00000000..6fb3220a --- /dev/null +++ b/decimojo/logic.mojo @@ -0,0 +1,220 @@ +# ===----------------------------------------------------------------------=== # +# Distributed under the Apache 2.0 License with LLVM Exceptions. +# See LICENSE and the LLVM License for more information. +# https://github.com/forFudan/decimojo/blob/main/LICENSE +# ===----------------------------------------------------------------------=== # +# +# Implements logic operations for the Decimal type +# +# ===----------------------------------------------------------------------=== # +# +# List of functions in this module: +# +# greater(a: Decimal, b: Decimal) -> Bool: Returns True if a > b +# greater_equal(a: Decimal, b: Decimal) -> Bool: Returns True if a >= b +# less(a: Decimal, b: Decimal) -> Bool: Returns True if a < b +# less_equal(a: Decimal, b: Decimal) -> Bool: Returns True if a <= b +# equal(a: Decimal, b: Decimal) -> Bool: Returns True if a == b +# not_equal(a: Decimal, b: Decimal) -> Bool: Returns True if a != b +# +# List of internal functions in this module: +# +# _compare_abs(a: Decimal, b: Decimal) -> Int: Compares absolute values of two Decimals +# +# ===----------------------------------------------------------------------=== # + +""" +Implements functions for comparison operations on Decimal objects. +""" + +from decimojo.decimal import Decimal + + +fn greater(a: Decimal, b: Decimal) -> Bool: + """ + Returns True if a > b. + + Args: + a: First Decimal value. + b: Second Decimal value. + + Returns: + True if a is greater than b, False otherwise. + """ + # Handle special case where either or both are zero + if a.is_zero() and b.is_zero(): + return False # Zero equals zero + if a.is_zero(): + return b.is_negative() # a=0 > b only if b is negative + if b.is_zero(): + return ( + not a.is_negative() and not a.is_zero() + ) # a > b=0 only if a is positive and non-zero + + # If they have different signs, positive is always greater + if a.is_negative() != b.is_negative(): + return not a.is_negative() # a > b if a is positive and b is negative + + # Now we know they have the same sign + # Compare absolute values, considering the sign + var compare_result = _compare_abs(a, b) + + if a.is_negative(): + # For negative numbers, the one with smaller absolute value is greater + return compare_result < 0 + else: + # For positive numbers, the one with larger absolute value is greater + return compare_result > 0 + + +fn greater_equal(a: Decimal, b: Decimal) -> Bool: + """ + Returns True if a >= b. + + Args: + a: First Decimal value. + b: Second Decimal value. + + Returns: + True if a is greater than or equal to b, False otherwise. + """ + # Handle special case where either or both are zero + if a.is_zero() and b.is_zero(): + return True # Zero equals zero + if a.is_zero(): + return ( + b.is_zero() or b.is_negative() + ) # a=0 >= b only if b is zero or negative + if b.is_zero(): + return ( + a.is_negative() == False + ) # a >= b=0 only if a is positive or zero + + # If they have different signs, positive is always greater + if a.is_negative() != b.is_negative(): + return not a.is_negative() # a >= b if a is positive and b is negative + + # Now we know they have the same sign + # Compare absolute values, considering the sign + var compare_result = _compare_abs(a, b) + + if a.is_negative(): + # For negative numbers, the one with smaller or equal absolute value is greater or equal + return compare_result <= 0 + else: + # For positive numbers, the one with larger or equal absolute value is greater or equal + return compare_result >= 0 + + +fn less(a: Decimal, b: Decimal) -> Bool: + """ + Returns True if a < b. + + Args: + a: First Decimal value. + b: Second Decimal value. + + Returns: + True if a is less than b, False otherwise. + """ + # We can use the greater function with arguments reversed + return greater(b, a) + + +fn less_equal(a: Decimal, b: Decimal) -> Bool: + """ + Returns True if a <= b. + + Args: + a: First Decimal value. + b: Second Decimal value. + + Returns: + True if a is less than or equal to b, False otherwise. + """ + # We can use the greater_equal function with arguments reversed + return greater_equal(b, a) + + +fn equal(a: Decimal, b: Decimal) -> Bool: + """ + Returns True if a == b. + + Args: + a: First Decimal value. + b: Second Decimal value. + + Returns: + True if a equals b, False otherwise. + """ + # If both are zero, they are equal regardless of scale or sign + if a.is_zero() and b.is_zero(): + return True + + # If signs differ, they're not equal + if a.is_negative() != b.is_negative(): + return False + + # Compare absolute values + return _compare_abs(a, b) == 0 + + +fn not_equal(a: Decimal, b: Decimal) -> Bool: + """ + Returns True if a != b. + + Args: + a: First Decimal value. + b: Second Decimal value. + + Returns: + True if a is not equal to b, False otherwise. + """ + # Simply negate the equal function + return not equal(a, b) + + +fn _compare_abs(a: Decimal, b: Decimal) -> Int: + """ + Internal helper to compare absolute values of two Decimal numbers. + + Returns: + - Positive value if |a| > |b| + - Zero if |a| = |b| + - Negative value if |a| < |b| + """ + # Normalize scales by scaling up the one with smaller scale + var scale_a = a.scale() + var scale_b = b.scale() + + # Create temporary copies that we will scale + var a_copy = a + var b_copy = b + + # Scale up the decimal with smaller scale to match the other + if scale_a < scale_b: + a_copy = a._scale_up(scale_b - scale_a) + elif scale_b < scale_a: + b_copy = b._scale_up(scale_a - scale_b) + + # Now both have the same scale, compare integer components + # Compare high parts first (most significant) + if a_copy.high > b_copy.high: + return 1 + if a_copy.high < b_copy.high: + return -1 + + # High parts equal, compare mid parts + if a_copy.mid > b_copy.mid: + return 1 + if a_copy.mid < b_copy.mid: + return -1 + + # Mid parts equal, compare low parts (least significant) + if a_copy.low > b_copy.low: + return 1 + if a_copy.low < b_copy.low: + return -1 + + # All components are equal + return 0 diff --git a/mojoproject.toml b/mojoproject.toml index 775c7796..8ac81c91 100644 --- a/mojoproject.toml +++ b/mojoproject.toml @@ -14,15 +14,19 @@ format = "magic run mojo format ./" # compile the package package = "magic run format && magic run mojo package decimojo && cp decimojo.mojopkg ./tests/ && cp decimojo.mojopkg ./examples/ && rm decimojo.mojopkg" -p = "magic run package" +p = "clear && magic run package" -# tests -test = "magic run package && magic run mojo tests/*.mojo && magic run mojo test tests -I ." -t = "magic run test" +# debugs (run the testing files only) +debug = "magic run package && magic run mojo tests/*.mojo" +d = "clear && magic run debug" + +# tests (use the mojo testing tool) +test = "magic run package && magic run mojo test tests -I ." +t = "clear && magic run test" # before commit final = "magic run test" -f = "magic run final" +f = "clear && magic run final" [dependencies] max = ">=25.0" diff --git a/tests/test_decimal_logic.mojo b/tests/test_decimal_logic.mojo new file mode 100644 index 00000000..47f94540 --- /dev/null +++ b/tests/test_decimal_logic.mojo @@ -0,0 +1,521 @@ +""" +Test Decimal logic operations for comparison, including basic comparisons, +edge cases, special handling for zero values, and operator overloads. +""" +from decimojo import Decimal +from decimojo.logic import ( + greater, + greater_equal, + less, + less_equal, + equal, + not_equal, +) +import testing + + +fn test_equality() raises: + print("Testing decimal equality...") + + # Test case 1: Equal decimals + var a1 = Decimal("123.45") + var b1 = Decimal("123.45") + testing.assert_true(equal(a1, b1), "Equal decimals should be equal") + + # Test case 2: Equal with different scales + var a2 = Decimal("123.450") + var b2 = Decimal("123.45") + testing.assert_true( + equal(a2, b2), "Equal decimals with different scales should be equal" + ) + + # Test case 3: Different values + var a3 = Decimal("123.45") + var b3 = Decimal("123.46") + testing.assert_false( + equal(a3, b3), "Different decimals should not be equal" + ) + + # Test case 4: Zeros with different scales + var a4 = Decimal("0") + var b4 = Decimal("0.00") + testing.assert_true( + equal(a4, b4), "Zeros with different scales should be equal" + ) + + # Test case 5: Zero and negative zero + var a5 = Decimal("0") + var b5 = Decimal("-0") + testing.assert_true(equal(a5, b5), "Zero and negative zero should be equal") + + # Test case 6: Same absolute value but different signs + var a6 = Decimal("123.45") + var b6 = Decimal("-123.45") + testing.assert_false( + equal(a6, b6), + "Same absolute value but different signs should not be equal", + ) + + print("Equality tests passed!") + + +fn test_inequality() raises: + print("Testing decimal inequality...") + + # Test case 1: Equal decimals + var a1 = Decimal("123.45") + var b1 = Decimal("123.45") + testing.assert_false( + not_equal(a1, b1), "Equal decimals should not be unequal" + ) + + # Test case 2: Equal with different scales + var a2 = Decimal("123.450") + var b2 = Decimal("123.45") + testing.assert_false( + not_equal(a2, b2), + "Equal decimals with different scales should not be unequal", + ) + + # Test case 3: Different values + var a3 = Decimal("123.45") + var b3 = Decimal("123.46") + testing.assert_true( + not_equal(a3, b3), "Different decimals should be unequal" + ) + + # Test case 4: Same absolute value but different signs + var a4 = Decimal("123.45") + var b4 = Decimal("-123.45") + testing.assert_true( + not_equal(a4, b4), + "Same absolute value but different signs should be unequal", + ) + + print("Inequality tests passed!") + + +fn test_greater() raises: + print("Testing greater than comparison...") + + # Test case 1: Larger decimal + var a1 = Decimal("123.46") + var b1 = Decimal("123.45") + testing.assert_true(greater(a1, b1), "123.46 should be greater than 123.45") + testing.assert_false( + greater(b1, a1), "123.45 should not be greater than 123.46" + ) + + # Test case 2: Equal decimals + var a2 = Decimal("123.45") + var b2 = Decimal("123.45") + testing.assert_false( + greater(a2, b2), "Equal decimals should not be greater" + ) + + # Test case 3: Positive vs. negative + var a3 = Decimal("123.45") + var b3 = Decimal("-123.45") + testing.assert_true( + greater(a3, b3), "Positive should be greater than negative" + ) + testing.assert_false( + greater(b3, a3), "Negative should not be greater than positive" + ) + + # Test case 4: Negative with smaller absolute value + var a4 = Decimal("-123.45") + var b4 = Decimal("-123.46") + testing.assert_true( + greater(a4, b4), "-123.45 should be greater than -123.46" + ) + + # Test case 5: Zero vs. positive + var a5 = Decimal("0") + var b5 = Decimal("123.45") + testing.assert_false( + greater(a5, b5), "Zero should not be greater than positive" + ) + + # Test case 6: Zero vs. negative + var a6 = Decimal("0") + var b6 = Decimal("-123.45") + testing.assert_true(greater(a6, b6), "Zero should be greater than negative") + + # Test case 7: Different scales + var a7 = Decimal("123.5") + var b7 = Decimal("123.45") + testing.assert_true(greater(a7, b7), "123.5 should be greater than 123.45") + + print("Greater than tests passed!") + + +fn test_greater_equal() raises: + print("Testing greater than or equal comparison...") + + # Test case 1: Larger decimal + var a1 = Decimal("123.46") + var b1 = Decimal("123.45") + testing.assert_true( + greater_equal(a1, b1), + "123.46 should be greater than or equal to 123.45", + ) + + # Test case 2: Equal decimals + var a2 = Decimal("123.45") + var b2 = Decimal("123.45") + testing.assert_true( + greater_equal(a2, b2), "Equal decimals should be greater than or equal" + ) + + # Test case 3: Positive vs. negative + var a3 = Decimal("123.45") + var b3 = Decimal("-123.45") + testing.assert_true( + greater_equal(a3, b3), + "Positive should be greater than or equal to negative", + ) + + # Test case 4: Equal values with different scales + var a4 = Decimal("123.450") + var b4 = Decimal("123.45") + testing.assert_true( + greater_equal(a4, b4), + "Equal values with different scales should be greater than or equal", + ) + + # Test case 5: Smaller decimal + var a5 = Decimal("123.45") + var b5 = Decimal("123.46") + testing.assert_false( + greater_equal(a5, b5), + "123.45 should not be greater than or equal to 123.46", + ) + + print("Greater than or equal tests passed!") + + +fn test_less() raises: + print("Testing less than comparison...") + + # Test case 1: Smaller decimal + var a1 = Decimal("123.45") + var b1 = Decimal("123.46") + testing.assert_true(less(a1, b1), "123.45 should be less than 123.46") + + # Test case 2: Equal decimals + var a2 = Decimal("123.45") + var b2 = Decimal("123.45") + testing.assert_false(less(a2, b2), "Equal decimals should not be less") + + # Test case 3: Negative vs. positive + var a3 = Decimal("-123.45") + var b3 = Decimal("123.45") + testing.assert_true(less(a3, b3), "Negative should be less than positive") + + # Test case 4: Negative with larger absolute value + var a4 = Decimal("-123.46") + var b4 = Decimal("-123.45") + testing.assert_true(less(a4, b4), "-123.46 should be less than -123.45") + + # Test case 5: Zero vs. positive + var a5 = Decimal("0") + var b5 = Decimal("123.45") + testing.assert_true(less(a5, b5), "Zero should be less than positive") + + print("Less than tests passed!") + + +fn test_less_equal() raises: + print("Testing less than or equal comparison...") + + # Test case 1: Smaller decimal + var a1 = Decimal("123.45") + var b1 = Decimal("123.46") + testing.assert_true( + less_equal(a1, b1), "123.45 should be less than or equal to 123.46" + ) + + # Test case 2: Equal decimals + var a2 = Decimal("123.45") + var b2 = Decimal("123.45") + testing.assert_true( + less_equal(a2, b2), "Equal decimals should be less than or equal" + ) + + # Test case 3: Negative vs. positive + var a3 = Decimal("-123.45") + var b3 = Decimal("123.45") + testing.assert_true( + less_equal(a3, b3), "Negative should be less than or equal to positive" + ) + + # Test case 4: Equal values with different scales + var a4 = Decimal("123.450") + var b4 = Decimal("123.45") + testing.assert_true( + less_equal(a4, b4), + "Equal values with different scales should be less than or equal", + ) + + # Test case 5: Larger decimal + var a5 = Decimal("123.46") + var b5 = Decimal("123.45") + testing.assert_false( + less_equal(a5, b5), "123.46 should not be less than or equal to 123.45" + ) + + print("Less than or equal tests passed!") + + +fn test_zero_comparison() raises: + print("Testing zero comparison cases...") + + var zero = Decimal("0") + var pos = Decimal("0.0000000000000000001") # Very small positive + var neg = Decimal("-0.0000000000000000001") # Very small negative + var zero_scale = Decimal("0.00000") # Zero with different scale + + # Zero compared to small positive + testing.assert_false(greater(zero, pos), "Zero should not be > positive") + testing.assert_false( + greater_equal(zero, pos), "Zero should not be >= positive" + ) + testing.assert_true(less(zero, pos), "Zero should be < positive") + testing.assert_true(less_equal(zero, pos), "Zero should be <= positive") + testing.assert_false(equal(zero, pos), "Zero should not be == positive") + testing.assert_true(not_equal(zero, pos), "Zero should be != positive") + + # Positive compared to zero + testing.assert_true(greater(pos, zero), "Positive should be > zero") + testing.assert_true(greater_equal(pos, zero), "Positive should be >= zero") + testing.assert_false(less(pos, zero), "Positive should not be < zero") + testing.assert_false( + less_equal(pos, zero), "Positive should not be <= zero" + ) + testing.assert_false(equal(pos, zero), "Positive should not be == zero") + testing.assert_true(not_equal(pos, zero), "Positive should be != zero") + + # Zero compared to small negative + testing.assert_true(greater(zero, neg), "Zero should be > negative") + testing.assert_true(greater_equal(zero, neg), "Zero should be >= negative") + testing.assert_false(less(zero, neg), "Zero should not be < negative") + testing.assert_false( + less_equal(zero, neg), "Zero should not be <= negative" + ) + testing.assert_false(equal(zero, neg), "Zero should not be == negative") + testing.assert_true(not_equal(zero, neg), "Zero should be != negative") + + # Different zeros + testing.assert_false( + greater(zero, zero_scale), "Zero should not be > zero with scale" + ) + testing.assert_true( + greater_equal(zero, zero_scale), "Zero should be >= zero with scale" + ) + testing.assert_false( + less(zero, zero_scale), "Zero should not be < zero with scale" + ) + testing.assert_true( + less_equal(zero, zero_scale), "Zero should be <= zero with scale" + ) + testing.assert_true( + equal(zero, zero_scale), "Zero should be == zero with scale" + ) + testing.assert_false( + not_equal(zero, zero_scale), "Zero should not be != zero with scale" + ) + + # Negative zero + var neg_zero = Decimal("-0") + testing.assert_true( + equal(zero, neg_zero), "Zero should be == negative zero" + ) + testing.assert_false( + greater(zero, neg_zero), "Zero should not be > negative zero" + ) + testing.assert_true( + greater_equal(zero, neg_zero), "Zero should be >= negative zero" + ) + testing.assert_false( + less(zero, neg_zero), "Zero should not be < negative zero" + ) + testing.assert_true( + less_equal(zero, neg_zero), "Zero should be <= negative zero" + ) + + print("Zero comparison cases passed!") + + +fn test_edge_cases() raises: + print("Testing comparison edge cases...") + + # Test case 1: Very close values + var a1 = Decimal("1.000000000000000000000000001") + var b1 = Decimal("1.000000000000000000000000000") + testing.assert_true( + greater(a1, b1), "1.000...001 should be greater than 1.000...000" + ) + + # Test case 2: Very large values + var a2 = Decimal("79228162514264337593543950335") # MAX value + var b2 = Decimal("79228162514264337593543950334") # MAX - 1 + testing.assert_true(greater(a2, b2), "MAX should be greater than MAX-1") + + # Test case 3: Very small negatives vs very small positives + var a3 = Decimal("-0." + "0" * 27 + "1") # -0.0000...01 (1 at 28th place) + var b3 = Decimal("0." + "0" * 27 + "1") # 0.0000...01 (1 at 28th place) + testing.assert_true( + less(a3, b3), + "Very small negative should be less than very small positive", + ) + + # Test case 4: Transitivity checks + var neg_large = Decimal("-1000") + var neg_small = Decimal("-0.001") + var pos_small = Decimal("0.001") + var pos_large = Decimal("1000") + + # Transitivity: if a > b and b > c then a > c + testing.assert_true(greater(pos_large, pos_small), "1000 > 0.001") + testing.assert_true(greater(pos_small, neg_small), "0.001 > -0.001") + testing.assert_true(greater(neg_small, neg_large), "-0.001 > -1000") + testing.assert_true( + greater(pos_large, neg_large), "1000 > -1000 (transitivity)" + ) + + print("Edge case tests passed!") + + +fn test_exact_comparison() raises: + print("Testing exact comparison with precision handling...") + + # Test case 1: Scale handling with zeros + var zero1 = Decimal("0") + var zero2 = Decimal("0.0") + var zero3 = Decimal("0.00000") + + testing.assert_true(equal(zero1, zero2), "0 == 0.0") + testing.assert_true(equal(zero1, zero3), "0 == 0.00000") + testing.assert_true(equal(zero2, zero3), "0.0 == 0.00000") + + # Test case 2: Equal values with different number of trailing zeros + var d1 = Decimal("123.400") + var d2 = Decimal("123.4") + var d3 = Decimal("123.40000") + + testing.assert_true(equal(d1, d2), "123.400 == 123.4") + testing.assert_true(equal(d2, d3), "123.4 == 123.40000") + testing.assert_true(equal(d1, d3), "123.400 == 123.40000") + + # Test case 3: Numbers that appear close but are different + var e1 = Decimal("1.2") + var e2 = Decimal("1.20000001") + + testing.assert_false(equal(e1, e2), "1.2 != 1.20000001") + testing.assert_true(less(e1, e2), "1.2 < 1.20000001") + + print("Exact comparison tests passed!") + + +fn test_comparison_operators() raises: + print("Testing comparison operators...") + + # Create test values + var a = Decimal("123.45") + var b = Decimal("67.89") + var c = Decimal("123.45") # Equal to a + var d = Decimal("123.450") # Equal to a with different scale + var e = Decimal("-50.0") # Negative number + var f = Decimal("0") # Zero + var g = Decimal("-0.0") # Negative zero (equal to zero) + + # Greater than + testing.assert_true(a > b, "a > b: 123.45 should be > 67.89") + testing.assert_false(b > a, "b > a: 67.89 should not be > 123.45") + testing.assert_false(a > c, "a > c: Equal values should not be >") + testing.assert_true(a > e, "a > e: Positive should be > negative") + testing.assert_true(a > f, "a > f: Positive should be > zero") + testing.assert_true(f > e, "f > e: Zero should be > negative") + + # Less than + testing.assert_false(a < b, "a < b: 123.45 should not be < 67.89") + testing.assert_true(b < a, "b < a: 67.89 should be < 123.45") + testing.assert_false(a < c, "a < c: Equal values should not be <") + testing.assert_false( + a < d, "a < d: Equal values (diff scale) should not be <" + ) + testing.assert_false(a < e, "a < e: Positive should not be < negative") + testing.assert_true(e < a, "e < a: Negative should be < positive") + testing.assert_true(e < f, "e < f: Negative should be < zero") + testing.assert_true(f < a, "f < a: Zero should be < positive") + + # Greater than or equal + testing.assert_true(a >= b, "a >= b: 123.45 should be >= 67.89") + testing.assert_false(b >= a, "b >= a: 67.89 should not be >= 123.45") + testing.assert_true(a >= c, "a >= c: Equal values should be >=") + testing.assert_true( + a >= d, "a >= d: Equal values (diff scale) should be >=" + ) + testing.assert_true(a >= e, "a >= e: Positive should be >= negative") + testing.assert_false(e >= a, "e >= a: Negative should not be >= positive") + testing.assert_true(f >= g, "f >= g: Zero should be >= negative zero") + + # Less than or equal + testing.assert_false(a <= b, "a <= b: 123.45 should not be <= 67.89") + testing.assert_true(b <= a, "b <= a: 67.89 should be <= 123.45") + testing.assert_true(a <= c, "a <= c: Equal values should be <=") + testing.assert_true( + a <= d, "a <= d: Equal values (diff scale) should be <=" + ) + testing.assert_false(a <= e, "a <= e: Positive should not be <= negative") + testing.assert_true(e <= a, "e <= a: Negative should be <= positive") + testing.assert_true(f <= a, "f <= a: Zero should be <= positive") + testing.assert_true(g <= f, "g <= f: Negative zero should be <= zero") + + # Equality + testing.assert_false(a == b, "a == b: Different values should not be equal") + testing.assert_true(a == c, "a == c: Same value should be equal") + testing.assert_true( + a == d, "a == d: Same value with different scales should be equal" + ) + testing.assert_true( + f == g, "f == g: Zero and negative zero should be equal" + ) + + # Inequality + testing.assert_true(a != b, "a != b: Different values should be unequal") + testing.assert_false(a != c, "a != c: Same value should not be unequal") + testing.assert_false( + a != d, "a != d: Same value with different scales should not be unequal" + ) + testing.assert_true(a != e, "a != e: Different values should be unequal") + testing.assert_false( + f != g, "f != g: Zero and negative zero should not be unequal" + ) + + print("Comparison operator tests passed!") + + +fn main() raises: + print("Running decimal logic tests") + + # Basic equality tests + test_equality() + test_inequality() + + # Comparison tests + test_greater() + test_greater_equal() + test_less() + test_less_equal() + + # Zero handling and edge cases + test_zero_comparison() + test_edge_cases() + test_exact_comparison() + + # Test operator overloads + test_comparison_operators() + + print("All decimal logic tests passed!") From d96514505a07194a89c667fa6e0c428e5b1a8c3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?ZHU=20Yuhao=20=E6=9C=B1=E5=AE=87=E6=B5=A9?= Date: Tue, 4 Mar 2025 13:07:57 +0100 Subject: [PATCH 3/3] remove sqrt --- decimojo/mathematics.mojo | 142 -------------------------------------- 1 file changed, 142 deletions(-) diff --git a/decimojo/mathematics.mojo b/decimojo/mathematics.mojo index 296ae14b..fad8af80 100644 --- a/decimojo/mathematics.mojo +++ b/decimojo/mathematics.mojo @@ -129,148 +129,6 @@ fn power(base: Decimal, exponent: Int) raises -> Decimal: return power(base, Decimal(exponent)) -# fn sqrt(x: Decimal) raises -> Decimal: -# """ -# Computes the square root of a Decimal value using Newton-Raphson method. - -# Args: -# x: The Decimal value to compute the square root of. - -# Returns: -# A new Decimal containing the square root of x. - -# Raises: -# Error: If x is negative. -# """ -# # Special cases -# if x.is_negative(): -# raise Error("Cannot compute square root of negative number") - -# if x.is_zero(): -# return Decimal.ZERO() - -# if x == Decimal.ONE(): -# return Decimal.ONE() - -# # Working precision - we'll compute with extra digits and round at the end -# var working_precision = x.scale() * 2 -# working_precision = max(working_precision, UInt32(10)) # At least 10 digits - -# # Initial guess - a good guess helps converge faster -# # For numbers near 1, use the number itself -# # For very small or large numbers, scale appropriately -# var guess: Decimal -# var exponent = len(x.coefficient()) - x.scale() - -# if exponent >= 0 and exponent <= 3: -# # For numbers between 0.1 and 1000, start with x/2 + 0.5 -# guess = (x / Decimal("2")) + Decimal("0.5") -# else: -# # For larger/smaller numbers, make a smarter guess -# # This scales based on the magnitude of the number -# var shift = exponent / 2 -# if exponent % 2 != 0: -# # For odd exponents, adjust -# shift = (exponent + 1) / 2 - -# # Use an approximation based on the exponent -# if exponent > 0: -# guess = Decimal("10") ** shift -# else: -# guess = Decimal("0.1") ** (-shift) - -# # Newton-Raphson iterations -# # x_n+1 = (x_n + S/x_n) / 2 -# var prev_guess = Decimal.ZERO() -# var iteration_count = 0 -# var max_iterations = 100 # Prevent infinite loops - -# while guess != prev_guess and iteration_count < max_iterations: -# prev_guess = guess -# guess = (guess + (x / guess)) / Decimal("2") -# iteration_count += 1 - -# # Round to appropriate precision - typically half the working precision -# var result_precision = x.scale() -# if result_precision % 2 == 1: -# # For odd scales, add 1 to ensure proper rounding -# result_precision += 1 - -# # The result scale should be approximately half the input scale -# result_precision = result_precision / 2 - -# # Format to the appropriate number of decimal places -# var result_str = String(guess) -# var rounded_result = Decimal(result_str) - -# return rounded_result - - -# # Additional helper function to compute nth root -# fn root(x: Decimal, n: Int) raises -> Decimal: -# """ -# Computes the nth root of x using Newton's method. - -# Args: -# x: The base value -# n: The root to compute (must be positive) - -# Returns: -# The nth root of x as a Decimal - -# Raises: -# Error: If x is negative and n is even, or if n is not positive -# """ -# # Special cases -# if n <= 0: -# raise Error("Root index must be positive") - -# if n == 1: -# return x - -# if n == 2: -# return sqrt(x) - -# if x.is_zero(): -# return Decimal.ZERO() - -# if x == Decimal.ONE(): -# return Decimal.ONE() - -# if x.is_negative() and n % 2 == 0: -# raise Error("Cannot compute even root of negative number") - -# # The sign of the result -# var result_is_negative = x.is_negative() and n % 2 == 1 -# var abs_x = x.abs() - -# # Newton's method for nth root: -# # x_k+1 = ((n-1)*x_k + num/(x_k^(n-1)))/n - -# # Initial guess -# var guess = Decimal.ONE() -# if abs_x > Decimal.ONE(): -# guess = abs_x / Decimal(n) # Simple initial approximation - -# var prev_guess = Decimal.ZERO() -# var n_minus_1 = Decimal(n - 1) -# var n_decimal = Decimal(n) -# var max_iterations = 100 -# var iteration_count = 0 - -# while guess != prev_guess and iteration_count < max_iterations: -# prev_guess = guess - -# # Compute x_k^(n-1) -# var guess_pow_n_minus_1 = power(guess, n - 1) - -# # Next approximation -# guess = (n_minus_1 * guess + abs_x / guess_pow_n_minus_1) / n_decimal -# iteration_count += 1 - -# return -guess if result_is_negative else guess - - # ===------------------------------------------------------------------------===# # Rounding # ===------------------------------------------------------------------------===#