Skip to content

Commit 15e28d5

Browse files
authored
[fix] Fix bug in multiply() due to implicit type conversion (#39)
This pull request fixes a bug in `multiply()` due to implicit type conversion. More specifically, ```UInt128(a) * 10 ** 28``` would implicitly convert the value into Int. Thus, all "power to 10" are replaced by the function `power_of_10()`.
1 parent 782f5f1 commit 15e28d5

File tree

3 files changed

+80
-87
lines changed

3 files changed

+80
-87
lines changed

src/decimojo/arithmetics.mojo

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -495,8 +495,6 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal:
495495

496496
# SPECIAL CASE: Both operands are true integers
497497
if x1_scale == 0 and x2_scale == 0:
498-
# print("DEBUG: Both operands are true integers")
499-
# print("DEBUG: combined_num_bits: ", combined_num_bits)
500498
# Small integers, use UInt64 multiplication
501499
if combined_num_bits <= 64:
502500
var prod: UInt64 = UInt64(x1_coef) * UInt64(x2_coef)
@@ -526,8 +524,12 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal:
526524
# SPECIAL CASE: Both operands are integers but with scales
527525
# Examples: 123.0 * 456.00
528526
if x1.is_integer() and x2.is_integer():
529-
var x1_integral_part = x1_coef // (UInt128(10) ** UInt128(x1_scale))
530-
var x2_integral_part = x2_coef // (UInt128(10) ** UInt128(x2_scale))
527+
var x1_integral_part = x1_coef // decimojo.utility.power_of_10[
528+
DType.uint128
529+
](x1_scale)
530+
var x2_integral_part = x2_coef // decimojo.utility.power_of_10[
531+
DType.uint128
532+
](x2_scale)
531533
var prod: UInt256 = UInt256(x1_integral_part) * UInt256(
532534
x2_integral_part
533535
)
@@ -538,8 +540,11 @@ fn multiply(x1: Decimal, x2: Decimal) raises -> Decimal:
538540
var final_scale = min(
539541
Decimal.MAX_NUM_DIGITS - num_digits, combined_scale
540542
)
541-
# Scale up before it overflows
542-
prod = prod * 10**final_scale
543+
# Scale up by adding trailing zeros
544+
prod = prod * decimojo.utility.power_of_10[DType.uint256](
545+
final_scale
546+
)
547+
# If it overflows, remove the last zero
543548
if prod > Decimal.MAX_AS_UINT256:
544549
prod = prod // 10
545550
final_scale -= 1
@@ -719,11 +724,6 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal:
719724
Error: If x2 is zero.
720725
"""
721726

722-
# print("----------------------------------------")
723-
# print("DEBUG divide()")
724-
# print("DEBUG: x1", x1)
725-
# print("DEBUG: x2", x2)
726-
727727
# Treatment for special cases
728728
# 對各類特殊情況進行處理
729729

@@ -772,12 +772,6 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal:
772772

773773
# diff_scale < 0, then times 10 ** (-diff_scale)
774774
else:
775-
# print("DEBUG: x1_coef", x1_coef)
776-
# print("DEBUG: x1_scale", x1_scale)
777-
# print("DEBUG: x2_coef", x2_coef)
778-
# print("DEBUG: x2_scale", x2_scale)
779-
# print("DEBUG: diff_scale", diff_scale)
780-
781775
# If the result can be stored in UInt128
782776
if (
783777
decimojo.utility.number_of_digits(x1_coef) - diff_scale
@@ -789,7 +783,6 @@ fn true_divide(x1: Decimal, x2: Decimal) raises -> Decimal:
789783
# If the result should be stored in UInt256
790784
else:
791785
var quot = UInt256(x1_coef) * UInt256(10) ** (-diff_scale)
792-
# print("DEBUG: quot", quot)
793786
if quot > Decimal.MAX_AS_UINT256:
794787
raise Error("Error in `true_divide()`: Decimal overflow")
795788
else:

src/decimojo/decimal.mojo

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1045,7 +1045,7 @@ struct Decimal(
10451045

10461046
# Otherwise, get the integer part by dividing by 10^scale
10471047
else:
1048-
res = self.coefficient() // 10 ** UInt128(self.scale())
1048+
res = self.coefficient() // UInt128(10) ** UInt128(self.scale())
10491049

10501050
return res
10511051

src/decimojo/utility.mojo

Lines changed: 68 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -633,74 +633,74 @@ fn number_of_bits[dtype: DType, //](owned value: Scalar[dtype]) -> Int:
633633
# ===----------------------------------------------------------------------=== #
634634

635635

636-
# Module-level cache for powers of 10
637-
var _power_of_10_as_uint128_cache = List[UInt128]()
638-
var _power_of_10_as_uint256_cache = List[UInt256]()
639-
640-
641-
# Initialize with the first value
642-
@always_inline
643-
fn _init_power_of_10_as_uint128_cache():
644-
if len(_power_of_10_as_uint128_cache) == 0:
645-
_power_of_10_as_uint128_cache.append(1) # 10^0 = 1
646-
647-
648-
@always_inline
649-
fn _init_power_of_10_as_uint256_cache():
650-
if len(_power_of_10_as_uint256_cache) == 0:
651-
_power_of_10_as_uint256_cache.append(1) # 10^0 = 1
652-
653-
654-
@always_inline
655-
fn power_of_10_as_uint128(n: Int) raises -> UInt128:
656-
"""
657-
Returns 10^n using cached values when available.
658-
"""
659-
660-
# Check for negative exponent
661-
if n < 0:
662-
raise Error(
663-
"power_of_10() requires non-negative exponent, got {}".format(n)
664-
)
665-
666-
# Initialize cache if needed
667-
if len(_power_of_10_as_uint128_cache) == 0:
668-
_init_power_of_10_as_uint128_cache()
669-
670-
# Extend cache if needed
671-
while len(_power_of_10_as_uint128_cache) <= n:
672-
var next_power = _power_of_10_as_uint128_cache[
673-
len(_power_of_10_as_uint128_cache) - 1
674-
] * 10
675-
_power_of_10_as_uint128_cache.append(next_power)
676-
677-
return _power_of_10_as_uint128_cache[n]
678-
679-
680-
@always_inline
681-
fn power_of_10_as_uint256(n: Int) raises -> UInt256:
682-
"""
683-
Returns 10^n using cached values when available.
684-
"""
685-
686-
# Check for negative exponent
687-
if n < 0:
688-
raise Error(
689-
"power_of_10() requires non-negative exponent, got {}".format(n)
690-
)
691-
692-
# Initialize cache if needed
693-
if len(_power_of_10_as_uint256_cache) == 0:
694-
_init_power_of_10_as_uint256_cache()
695-
696-
# Extend cache if needed
697-
while len(_power_of_10_as_uint256_cache) <= n:
698-
var next_power = _power_of_10_as_uint256_cache[
699-
len(_power_of_10_as_uint256_cache) - 1
700-
] * 10
701-
_power_of_10_as_uint256_cache.append(next_power)
702-
703-
return _power_of_10_as_uint256_cache[n]
636+
# # Module-level cache for powers of 10
637+
# var _power_of_10_as_uint128_cache = List[UInt128]()
638+
# var _power_of_10_as_uint256_cache = List[UInt256]()
639+
640+
641+
# # Initialize with the first value
642+
# @always_inline
643+
# fn _init_power_of_10_as_uint128_cache():
644+
# if len(_power_of_10_as_uint128_cache) == 0:
645+
# _power_of_10_as_uint128_cache.append(1) # 10^0 = 1
646+
647+
648+
# @always_inline
649+
# fn _init_power_of_10_as_uint256_cache():
650+
# if len(_power_of_10_as_uint256_cache) == 0:
651+
# _power_of_10_as_uint256_cache.append(1) # 10^0 = 1
652+
653+
654+
# @always_inline
655+
# fn power_of_10_as_uint128(n: Int) raises -> UInt128:
656+
# """
657+
# Returns 10^n using cached values when available.
658+
# """
659+
660+
# # Check for negative exponent
661+
# if n < 0:
662+
# raise Error(
663+
# "power_of_10() requires non-negative exponent, got {}".format(n)
664+
# )
665+
666+
# # Initialize cache if needed
667+
# if len(_power_of_10_as_uint128_cache) == 0:
668+
# _init_power_of_10_as_uint128_cache()
669+
670+
# # Extend cache if needed
671+
# while len(_power_of_10_as_uint128_cache) <= n:
672+
# var next_power = _power_of_10_as_uint128_cache[
673+
# len(_power_of_10_as_uint128_cache) - 1
674+
# ] * 10
675+
# _power_of_10_as_uint128_cache.append(next_power)
676+
677+
# return _power_of_10_as_uint128_cache[n]
678+
679+
680+
# @always_inline
681+
# fn power_of_10_as_uint256(n: Int) raises -> UInt256:
682+
# """
683+
# Returns 10^n using cached values when available.
684+
# """
685+
686+
# # Check for negative exponent
687+
# if n < 0:
688+
# raise Error(
689+
# "power_of_10() requires non-negative exponent, got {}".format(n)
690+
# )
691+
692+
# # Initialize cache if needed
693+
# if len(_power_of_10_as_uint256_cache) == 0:
694+
# _init_power_of_10_as_uint256_cache()
695+
696+
# # Extend cache if needed
697+
# while len(_power_of_10_as_uint256_cache) <= n:
698+
# var next_power = _power_of_10_as_uint256_cache[
699+
# len(_power_of_10_as_uint256_cache) - 1
700+
# ] * 10
701+
# _power_of_10_as_uint256_cache.append(next_power)
702+
703+
# return _power_of_10_as_uint256_cache[n]
704704

705705

706706
@always_inline

0 commit comments

Comments
 (0)