Skip to content

Commit 88e9f7b

Browse files
authored
[integer] Optimize BigUInt addition and subtraction with SIMD and early stop tricks (#101)
This pull request introduces significant updates to the `bench_biguint` benchmarking suite, optimizes arithmetic operations in the `BigUInt` and `BigDecimal` modules, and refactors method names for clarity and consistency. Additionally, new benchmarking cases and constants are added to improve performance testing and support for larger numbers. 1. Use SIMD to accelerate BigUInt addition and in-place addition. The speed gain is 2x to 4x for large numbers. 2. Refine the BigUInt subtraction and in-place addition with some tricks on carry so that floor_divide and modulo are replaced by addition and subtraction. 3. Use a trick to first do a parallel addition word-by-word, and then normalize the carries in a single loop. ### Arithmetic Optimizations: * [`src/decimojo/bigdecimal/arithmetics.mojo`](diffhunk://#diff-f79534f4e7fdd891932ce9d015c50bd3c8a72c4a1689f0cb55524490ffc0458dL73-R74): Refactored methods to replace `scale_up_by_power_of_10` and `scale_down_by_power_of_10` with `multiply_by_power_of_ten` and `floor_divide_by_power_of_ten`, improving naming consistency and clarity. [[1]](diffhunk://#diff-f79534f4e7fdd891932ce9d015c50bd3c8a72c4a1689f0cb55524490ffc0458dL73-R74) [[2]](diffhunk://#diff-f79534f4e7fdd891932ce9d015c50bd3c8a72c4a1689f0cb55524490ffc0458dL304-R304) [[3]](diffhunk://#diff-f79534f4e7fdd891932ce9d015c50bd3c8a72c4a1689f0cb55524490ffc0458dL440-R445) * [`src/decimojo/biguint/biguint.mojo`](diffhunk://#diff-f9432b9b2671643af91201f9e3f011551a3d3b0e6d7b256d0d4569f5ae59848aL1003-L1010): Removed redundant `add_inplace_by_1` method and replaced it with a more general `add_inplace_by_uint32` for optimized addition operations. 
[[1]](diffhunk://#diff-f9432b9b2671643af91201f9e3f011551a3d3b0e6d7b256d0d4569f5ae59848aL1003-L1010) [[2]](diffhunk://#diff-f9432b9b2671643af91201f9e3f011551a3d3b0e6d7b256d0d4569f5ae59848aL1437-R1439) ### Refactoring and Enhancements: * [`src/decimojo/biguint/biguint.mojo`](diffhunk://#diff-f9432b9b2671643af91201f9e3f011551a3d3b0e6d7b256d0d4569f5ae59848aR74-R75): Renamed methods (e.g., `scale_up_by_power_of_10` → `multiply_by_power_of_ten`) for consistency across the codebase and introduced `VECTOR_WIDTH` constant for SIMD-based arithmetic optimizations. [[1]](diffhunk://#diff-f9432b9b2671643af91201f9e3f011551a3d3b0e6d7b256d0d4569f5ae59848aR74-R75) [[2]](diffhunk://#diff-f9432b9b2671643af91201f9e3f011551a3d3b0e6d7b256d0d4569f5ae59848aL1070-R1095) * [`src/decimojo/bigdecimal/comparison.mojo`](diffhunk://#diff-04237ffa697ff22a4879812f65a72c23bc5d3e183b58f11e437c94836bd43da3L66-R70): Updated comparison logic to use the newly renamed `multiply_by_power_of_ten` method for scaling coefficients. ### Benchmarking Updates: * [`benches/biguint/bench_biguint_add.mojo`](diffhunk://#diff-967ad165864a3f276ee27b8eca0721f132d904f71ffb3da60003a75aec8837efR460-R509): Added five new addition benchmark cases for larger word sizes (e.g., 4096 words + 2048 words) to test scalability. * [`benches/biguint/bench_biguint_multiply.mojo`](diffhunk://#diff-3fba3fe441d30e17e77d7e18b33b2508452b08f07e7af177d413c08b5b5c88c2L463-R584): Expanded multiplication benchmarks to include 12 new cases with varying word sizes, introducing reduced iterations for very large numbers to optimize runtime. * [`benches/biguint/bench_biguint_multiply_complexity.mojo`](diffhunk://#diff-d0d1723b5108046f6dc332ce4cf856979f00576061017e71660d38bcd536b31fL132-R132): Adjusted test sizes to start from 8 words instead of 32 and updated iteration logic for benchmarking complexity. 
[[1]](diffhunk://#diff-d0d1723b5108046f6dc332ce4cf856979f00576061017e71660d38bcd536b31fL132-R132) [[2]](diffhunk://#diff-d0d1723b5108046f6dc332ce4cf856979f00576061017e71660d38bcd536b31fL141-R144) These changes collectively enhance the code's readability, scalability, and performance, especially for operations involving large numbers and benchmarking scenarios.
1 parent dced4f9 commit 88e9f7b

9 files changed

Lines changed: 674 additions & 293 deletions

File tree

benches/biguint/bench_biguint_add.mojo

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ fn open_log_file() raises -> PythonObject:
2121
"""
2222
var python = Python.import_module("builtins")
2323
var datetime = Python.import_module("datetime")
24+
var pysys = Python.import_module("sys")
25+
pysys.set_int_max_str_digits(1000000)
2426

2527
# Create logs directory if it doesn't exist
2628
var log_dir = "./logs"
@@ -455,6 +457,56 @@ fn main() raises:
455457
speedup_factors,
456458
)
457459

460+
# Case 31: Addition with 64 words + 32 words
461+
run_benchmark_add(
462+
"Addition with 64 words + 32 words",
463+
"123456789" * 64,
464+
"987654321" * 32,
465+
iterations,
466+
log_file,
467+
speedup_factors,
468+
)
469+
470+
# Case 32: Addition with 256 words + 128 words
471+
run_benchmark_add(
472+
"Addition with 256 words + 128 words",
473+
"123456789" * 256,
474+
"987654321" * 128,
475+
iterations,
476+
log_file,
477+
speedup_factors,
478+
)
479+
480+
# Case 33: Addition with 1024 words + 512 words
481+
run_benchmark_add(
482+
"Addition with 1024 words + 512 words",
483+
"123456789" * 1024,
484+
"987654321" * 512,
485+
iterations,
486+
log_file,
487+
speedup_factors,
488+
)
489+
490+
# Case 34: Addition with 4096 words + 2048 words
491+
run_benchmark_add(
492+
"Addition with 4096 words + 2048 words",
493+
"123456789" * 4096,
494+
"987654321" * 2048,
495+
iterations,
496+
log_file,
497+
speedup_factors,
498+
)
499+
500+
# Case 35: Addition with 16384 words + 8192 words
501+
run_benchmark_add(
502+
"Addition with 16384 words + 8192 words",
503+
"123456789" * 16384,
504+
"987654321" * 8192,
505+
iterations,
506+
log_file,
507+
speedup_factors,
508+
)
509+
458510
# Calculate average speedup factor
459511
var sum_speedup: Float64 = 0.0
460512
for i in range(len(speedup_factors)):

benches/biguint/bench_biguint_multiply.mojo

Lines changed: 110 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,9 @@ fn main() raises:
150150
log_print("Could not retrieve system information", log_file)
151151

152152
# Use fewer iterations for multiplication as it's more compute-intensive
153+
# For large numbers, we reduce iterations to avoid long runtimes
153154
var iterations = 100
155+
var iterations_large = 20
154156

155157
# Define benchmark cases
156158
log_print(
@@ -460,26 +462,126 @@ fn main() raises:
460462
speedup_factors,
461463
)
462464

463-
# Case 31: Very, very large numbers multiplication
465+
# Case 31: 2 words * 2 words multiplication
464466
run_benchmark_multiply(
465-
"Extreme large numbers multiplication (9000 digits * 9000 digits)",
466-
"123456789" * 1000, # 9000 digits
467-
"987654321" * 1000, # 9000 digits
467+
"2 words * 2 words multiplication",
468+
"123456789" * 2,
469+
"987654321" * 2,
468470
iterations,
469471
log_file,
470472
speedup_factors,
471473
)
472474

473-
# Case 32: Extremely large numbers multiplication
475+
# Case 32: 4 words * 4 words multiplication
474476
run_benchmark_multiply(
475-
"Extreme large numbers multiplication (36000 digits * 36000 digits)",
476-
"123456789" * 4000, # 36000 digits
477-
"987654321" * 4000, # 36000 digits
477+
"4 words * 4 words multiplication",
478+
"123456789" * 4,
479+
"987654321" * 4,
478480
iterations,
479481
log_file,
480482
speedup_factors,
481483
)
482484

485+
# Case 33: 8 words * 8 words multiplication
486+
run_benchmark_multiply(
487+
"8 words * 8 words multiplication",
488+
"123456789" * 8,
489+
"987654321" * 8,
490+
iterations,
491+
log_file,
492+
speedup_factors,
493+
)
494+
495+
# Case 34: 16 words * 16 words multiplication
496+
run_benchmark_multiply(
497+
"16 words * 16 words multiplication",
498+
"123456789" * 16,
499+
"987654321" * 16,
500+
iterations,
501+
log_file,
502+
speedup_factors,
503+
)
504+
505+
# Case 35: 32 words * 32 words multiplication
506+
run_benchmark_multiply(
507+
"32 words * 32 words multiplication",
508+
"123456789" * 32,
509+
"987654321" * 32,
510+
iterations,
511+
log_file,
512+
speedup_factors,
513+
)
514+
515+
# Case 36: 64 words * 64 words multiplication
516+
run_benchmark_multiply(
517+
"64 words * 64 words multiplication",
518+
"123456789" * 64,
519+
"987654321" * 64,
520+
iterations,
521+
log_file,
522+
speedup_factors,
523+
)
524+
525+
# Case 37: 128 words * 128 words multiplication
526+
run_benchmark_multiply(
527+
"128 words * 128 words multiplication",
528+
"123456789" * 128,
529+
"987654321" * 128,
530+
iterations_large,
531+
log_file,
532+
speedup_factors,
533+
)
534+
535+
# Case 38: 256 words * 256 words multiplication
536+
run_benchmark_multiply(
537+
"256 words * 256 words multiplication",
538+
"123456789" * 256,
539+
"987654321" * 256,
540+
iterations_large,
541+
log_file,
542+
speedup_factors,
543+
)
544+
545+
# Case 39: 512 words * 512 words multiplication
546+
run_benchmark_multiply(
547+
"512 words * 512 words multiplication",
548+
"123456789" * 512,
549+
"987654321" * 512,
550+
iterations_large,
551+
log_file,
552+
speedup_factors,
553+
)
554+
555+
# Case 40: 1024 words * 1024 words multiplication
556+
run_benchmark_multiply(
557+
"1024 words * 1024 words multiplication",
558+
"123456789" * 1024,
559+
"987654321" * 1024,
560+
iterations_large,
561+
log_file,
562+
speedup_factors,
563+
)
564+
565+
# Case 41: 2048 words * 2048 words multiplication
566+
run_benchmark_multiply(
567+
"2048 words * 2048 words multiplication",
568+
"123456789" * 2048,
569+
"987654321" * 2048,
570+
iterations_large,
571+
log_file,
572+
speedup_factors,
573+
)
574+
575+
# Case 42: 4096 words * 4096 words multiplication
576+
run_benchmark_multiply(
577+
"4096 words * 4096 words multiplication",
578+
"123456789" * 4096,
579+
"987654321" * 4096,
580+
iterations_large,
581+
log_file,
582+
speedup_factors,
583+
)
584+
483585
# Calculate average speedup factor
484586
var sum_speedup: Float64 = 0.0
485587
for i in range(len(speedup_factors)):

benches/biguint/bench_biguint_multiply_complexity.mojo

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ fn main() raises:
129129

130130
log_print("", log_file)
131131
log_print(
132-
"Testing word sizes from 32 to 262144 words (powers of 2)", log_file
132+
"Testing word sizes from 8 to 262144 words (powers of 2)", log_file
133133
)
134134
log_print("Each test uses 5 iterations for averaging", log_file)
135135
log_print(
@@ -138,8 +138,10 @@ fn main() raises:
138138
)
139139
log_print("", log_file)
140140

141-
# Test sizes: powers of 2 from 32 to 262144
141+
# Test sizes: powers of 2 from 8 to 262144
142142
var test_sizes = List[Int]()
143+
test_sizes.append(8)
144+
test_sizes.append(16)
143145
test_sizes.append(32)
144146
test_sizes.append(64)
145147
test_sizes.append(128)

src/decimojo/bigdecimal/arithmetics.mojo

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ fn add(x1: BigDecimal, x2: BigDecimal) raises -> BigDecimal:
7070
return x1.extend_precision(scale_factor1)
7171

7272
# Scale coefficients to match
73-
var coef1 = x1.coefficient.scale_up_by_power_of_10(scale_factor1)
74-
var coef2 = x2.coefficient.scale_up_by_power_of_10(scale_factor2)
73+
var coef1 = x1.coefficient.multiply_by_power_of_ten(scale_factor1)
74+
var coef2 = x2.coefficient.multiply_by_power_of_ten(scale_factor2)
7575

7676
# Handle addition based on signs
7777
if x1.sign == x2.sign:
@@ -135,8 +135,8 @@ fn subtract(x1: BigDecimal, x2: BigDecimal) raises -> BigDecimal:
135135
return result^
136136

137137
# Scale coefficients to match
138-
var coef1 = x1.coefficient.scale_up_by_power_of_10(scale_factor1)
139-
var coef2 = x2.coefficient.scale_up_by_power_of_10(scale_factor2)
138+
var coef1 = x1.coefficient.multiply_by_power_of_ten(scale_factor1)
139+
var coef2 = x2.coefficient.multiply_by_power_of_ten(scale_factor2)
140140

141141
# Handle subtraction based on signs
142142
if x1.sign != x2.sign:
@@ -278,7 +278,7 @@ fn true_divide(
278278
# Scale up the dividend to ensure sufficient precision
279279
var scaled_x1 = x1.coefficient
280280
if additional_digits > 0:
281-
scaled_x1.scale_up_inplace_by_power_of_10(additional_digits)
281+
scaled_x1.multiply_inplace_by_power_of_ten(additional_digits)
282282

283283
# Perform division
284284
var quotient: BigUInt
@@ -301,7 +301,7 @@ fn true_divide(
301301
if is_exact:
302302
var num_trailing_zeros = quotient.number_of_trailing_zeros()
303303
if num_trailing_zeros > 0:
304-
quotient = quotient.scale_down_by_power_of_10(num_trailing_zeros)
304+
quotient = quotient.floor_divide_by_power_of_ten(num_trailing_zeros)
305305
result_scale -= num_trailing_zeros
306306
# Recalculate digits after removing trailing zeros
307307
result_digits = quotient.number_of_digits()
@@ -382,7 +382,7 @@ fn true_divide_inexact(
382382
# Scale up the dividend to ensure sufficient precision
383383
var scaled_x1 = x1.coefficient
384384
if buffer_digits > 0:
385-
scaled_x1.scale_up_inplace_by_power_of_10(buffer_digits)
385+
scaled_x1.multiply_inplace_by_power_of_ten(buffer_digits)
386386

387387
# Perform division
388388
var quotient: BigUInt = scaled_x1 // x2.coefficient
@@ -437,12 +437,12 @@ fn truncate_divide(x1: BigDecimal, x2: BigDecimal) raises -> BigDecimal:
437437
# If scale_diff is positive, we need to scale up the dividend
438438
# If scale_diff is negative, we need to scale up the divisor
439439
if scale_diff > 0:
440-
var divisor = x2.coefficient.scale_up_by_power_of_10(scale_diff)
440+
var divisor = x2.coefficient.multiply_by_power_of_ten(scale_diff)
441441
var quotient = x1.coefficient.truncate_divide(divisor)
442442
return BigDecimal(quotient^, 0, x1.sign != x2.sign)
443443

444444
else: # scale_diff < 0
445-
var dividend = x1.coefficient.scale_up_by_power_of_10(-scale_diff)
445+
var dividend = x1.coefficient.multiply_by_power_of_ten(-scale_diff)
446446
var quotient = dividend.truncate_divide(x2.coefficient)
447447
return BigDecimal(quotient^, 0, x1.sign != x2.sign)
448448

src/decimojo/bigdecimal/comparison.mojo

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,11 @@ fn compare_absolute(x1: BigDecimal, x2: BigDecimal) -> Int8:
6363

6464
if scale_diff > 0:
6565
# x1 has larger scale (more decimal places)
66-
var scaled_x2 = x2.coefficient.scale_up_by_power_of_10(scale_diff)
66+
var scaled_x2 = x2.coefficient.multiply_by_power_of_ten(scale_diff)
6767
return x1.coefficient.compare(scaled_x2^)
6868
else:
6969
# x2 has larger scale (more decimal places)
70-
var scaled_x1 = x1.coefficient.scale_up_by_power_of_10(-scale_diff)
70+
var scaled_x1 = x1.coefficient.multiply_by_power_of_ten(-scale_diff)
7171
return scaled_x1.compare(x2.coefficient)
7272

7373

src/decimojo/bigdecimal/exponential.mojo

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,13 @@ fn integer_power(
156156
var abs_exp = abs(exponent)
157157
var exp_value: BigUInt
158158
if abs_exp.scale > 0:
159-
exp_value = abs_exp.coefficient.scale_down_by_power_of_10(abs_exp.scale)
159+
exp_value = abs_exp.coefficient.floor_divide_by_power_of_ten(
160+
abs_exp.scale
161+
)
160162
elif abs_exp.scale == 0:
161163
exp_value = abs_exp.coefficient
162164
else:
163-
exp_value = abs_exp.coefficient.scale_up_by_power_of_10(-abs_exp.scale)
165+
exp_value = abs_exp.coefficient.multiply_by_power_of_ten(-abs_exp.scale)
164166

165167
var result = BigDecimal(BigUInt.ONE, 0, False)
166168
var current_power = base
@@ -357,7 +359,7 @@ fn integer_root(
357359
# Convert n to integer to check odd/even
358360
var n_uint: BigUInt
359361
if n.scale > 0:
360-
n_uint = n.coefficient.scale_down_by_power_of_10(n.scale)
362+
n_uint = n.coefficient.floor_divide_by_power_of_ten(n.scale)
361363
else: # n.scale <= 0
362364
n_uint = n.coefficient
363365

src/decimojo/bigint/bigint.mojo

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -523,8 +523,10 @@ struct BigInt(Absable, IntableRaising, Representable, Stringable, Writable):
523523
@always_inline
524524
fn __iadd__(mut self, other: Int) raises:
525525
# Fast path: `self` is non-negative and `other` is in the range 0..=999_999_999
526-
if (self >= 0) and (other == 1):
527-
self.magnitude.add_inplace_by_1()
526+
if (self >= 0) and (other >= 0) and (other <= 999_999_999):
527+
decimojo.biguint.arithmetics.add_inplace_by_uint32(
528+
self.magnitude, UInt32(other)
529+
)
528530
else:
529531
decimojo.bigint.arithmetics.add_inplace(self, other)
530532

0 commit comments

Comments
 (0)