Skip to content

Commit d9bb737

Browse files
authored
[integer] Improve BigUInt subtraction with SIMD (#102)
This pull request introduces enhancements to the BigUInt benchmarks and arithmetic operations, focusing on subtraction benchmarks, SIMD optimizations, and new arithmetic methods. The changes improve performance, add comprehensive benchmarking capabilities, and refine the implementation of subtraction and addition operations. ### SIMD Optimizations for Arithmetic Operations: * Introduced SIMD-based subtraction methods (`subtract_simd`) for improved performance, replacing traditional methods in `subtract` and `subtract_inplace`. (`src/decimojo/biguint/arithmetics.mojo`) [[1]](diffhunk://#diff-95a5c66e5957fff2a2a8a2710ffb11a4ad7d0bddaf3bc11d075d62ddaa8b915cR513-R587) [[2]](diffhunk://#diff-95a5c66e5957fff2a2a8a2710ffb11a4ad7d0bddaf3bc11d075d62ddaa8b915cL500-R660) * Added SIMD-based addition methods (`add_simd`) and optimized `add_inplace` with SIMD operations. (`src/decimojo/biguint/arithmetics.mojo`) [[1]](diffhunk://#diff-95a5c66e5957fff2a2a8a2710ffb11a4ad7d0bddaf3bc11d075d62ddaa8b915cR36-R37) [[2]](diffhunk://#diff-95a5c66e5957fff2a2a8a2710ffb11a4ad7d0bddaf3bc11d075d62ddaa8b915cR367-R370) ### New Arithmetic Methods: * Added new methods for addition (`add_slices`) and multiplication (`multiply_inplace_by_uint32`), expanding functionality for BigUInt arithmetic. (`src/decimojo/biguint/arithmetics.mojo`) [[1]](diffhunk://#diff-95a5c66e5957fff2a2a8a2710ffb11a4ad7d0bddaf3bc11d075d62ddaa8b915cR36-R37) [[2]](diffhunk://#diff-95a5c66e5957fff2a2a8a2710ffb11a4ad7d0bddaf3bc11d075d62ddaa8b915cR49) * Implemented `normalize_carries` to handle carry propagation efficiently during arithmetic operations. (`src/decimojo/biguint/arithmetics.mojo`) [[1]](diffhunk://#diff-95a5c66e5957fff2a2a8a2710ffb11a4ad7d0bddaf3bc11d075d62ddaa8b915cR68) [[2]](diffhunk://#diff-95a5c66e5957fff2a2a8a2710ffb11a4ad7d0bddaf3bc11d075d62ddaa8b915cL1706-R1826) ### Enhancements to BigUInt Benchmarks: * Added `bench_biguint_subtraction` module for benchmarking BigUInt subtraction against Python's `int`, including detailed logging and performance comparison across various test cases. (`benches/biguint/bench_biguint_subtraction.mojo`) * Updated `benches/biguint/bench.mojo` to include subtraction benchmarks in the menu and execution logic. (`benches/biguint/bench.mojo`) [[1]](diffhunk://#diff-34a58f6f5c1b673b01199db91a3b182f1f345caf0797cbcacadd978e96e4cc10R2) [[2]](diffhunk://#diff-34a58f6f5c1b673b01199db91a3b182f1f345caf0797cbcacadd978e96e4cc10R15) [[3]](diffhunk://#diff-34a58f6f5c1b673b01199db91a3b182f1f345caf0797cbcacadd978e96e4cc10R27-R28) ### Bug Fixes and Refinements: * Fixed type mismatches in `multiply_slices` by ensuring consistent casting of `BigUInt.BASE` values. (`src/decimojo/biguint/arithmetics.mojo`) * Corrected a typo in the description of `normalize_carries` ("crray" → "carry"). (`src/decimojo/biguint/arithmetics.mojo`)
1 parent 88e9f7b commit d9bb737

File tree

3 files changed

+513
-57
lines changed

3 files changed

+513
-57
lines changed

benches/biguint/bench.mojo

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from bench_biguint_add import main as bench_add
2+
from bench_biguint_subtraction import main as bench_subtraction
23
from bench_biguint_multiply import main as bench_multiply
34
from bench_biguint_truncate_divide import main as bench_truncate_divide
45
from bench_biguint_from_string import main as bench_from_string
@@ -11,6 +12,7 @@ fn main() raises:
1112
This is the BigUInt Benchmarks
1213
=========================================
1314
add: Add
15+
sub: Subtract
1416
mul: Multiply
1517
div: Truncate divide (//)
1618
fromstr: From string
@@ -22,6 +24,8 @@ q: Exit
2224
var command = input("Type name of bench you want to run: ")
2325
if command == "add":
2426
bench_add()
27+
elif command == "sub":
28+
bench_subtraction()
2529
elif command == "mul":
2630
bench_multiply()
2731
elif command == "div":
Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
"""
2+
Comprehensive benchmarks for BigUInt subtraction.
3+
Compares performance against Python's built-in int with diverse test cases.
4+
"""
5+
6+
from decimojo.biguint.biguint import BigUInt
7+
import decimojo.biguint.arithmetics
8+
from python import Python, PythonObject
9+
from time import perf_counter_ns
10+
import time
11+
import os
12+
from collections import List
13+
14+
15+
fn open_log_file() raises -> PythonObject:
16+
"""
17+
Creates and opens a log file with a timestamp in the filename.
18+
19+
Returns:
20+
A file object opened for writing.
21+
"""
22+
var python = Python.import_module("builtins")
23+
var datetime = Python.import_module("datetime")
24+
var pysys = Python.import_module("sys")
25+
pysys.set_int_max_str_digits(1000000)
26+
27+
# Create logs directory if it doesn't exist
28+
var log_dir = "./logs"
29+
if not os.path.exists(log_dir):
30+
os.makedirs(log_dir)
31+
32+
# Generate a timestamp for the filename
33+
var timestamp = String(datetime.datetime.now().isoformat())
34+
var log_filename = (
35+
log_dir + "/benchmark_biguint_subtraction_" + timestamp + ".log"
36+
)
37+
38+
print("Saving benchmark results to:", log_filename)
39+
return python.open(log_filename, "w")
40+
41+
42+
fn log_print(msg: String, log_file: PythonObject) raises:
43+
"""
44+
Prints a message to both the console and the log file.
45+
46+
Args:
47+
msg: The message to print.
48+
log_file: The file object to write to.
49+
"""
50+
print(msg)
51+
log_file.write(msg + "\n")
52+
log_file.flush() # Ensure the message is written immediately
53+
54+
55+
fn run_benchmark_subtraction(
56+
name: String,
57+
value1: String,
58+
value2: String,
59+
iterations: Int,
60+
log_file: PythonObject,
61+
mut speedup_factors: List[Float64],
62+
) raises:
63+
"""
64+
Run a benchmark comparing Mojo BigUInt subtraction with Python int subtraction.
65+
66+
Args:
67+
name: Name of the benchmark case.
68+
value1: String representation of first operand.
69+
value2: String representation of second operand.
70+
iterations: Number of iterations to run.
71+
log_file: File object for logging results.
72+
speedup_factors: Mojo List to store speedup factors for averaging.
73+
"""
74+
log_print("\nBenchmark: " + name, log_file)
75+
log_print("First operand: " + value1, log_file)
76+
log_print("Second operand: " + value2, log_file)
77+
78+
# Set up Mojo and Python values
79+
var mojo_value1 = BigUInt(value1)
80+
var mojo_value2 = BigUInt(value2)
81+
var py = Python.import_module("builtins")
82+
var py_value1 = py.int(value1)
83+
var py_value2 = py.int(value2)
84+
85+
# Execute the operations once to verify correctness
86+
var mojo_result = mojo_value1 - mojo_value2
87+
var py_result = py_value1 - py_value2
88+
89+
# Display results for verification
90+
log_print("Mojo result: " + String(mojo_result), log_file)
91+
log_print("Python result: " + String(py_result), log_file)
92+
93+
# Benchmark Mojo implementation
94+
var t0 = perf_counter_ns()
95+
for _ in range(iterations):
96+
_ = mojo_value1 - mojo_value2
97+
var mojo_time = (perf_counter_ns() - t0) / iterations
98+
if mojo_time == 0:
99+
mojo_time = 1 # Prevent division by zero
100+
101+
# Benchmark Python implementation
102+
t0 = perf_counter_ns()
103+
for _ in range(iterations):
104+
_ = py_value1 - py_value2
105+
var python_time = (perf_counter_ns() - t0) / iterations
106+
107+
# Calculate speedup factor
108+
var speedup = python_time / mojo_time
109+
speedup_factors.append(Float64(speedup))
110+
111+
# Print results with speedup comparison
112+
log_print(
113+
"Mojo subtraction: " + String(mojo_time) + " ns per iteration",
114+
log_file,
115+
)
116+
log_print(
117+
"Python subtraction: " + String(python_time) + " ns per iteration",
118+
log_file,
119+
)
120+
log_print("Speedup factor: " + String(speedup), log_file)
121+
122+
123+
fn main() raises:
124+
# Open log file
125+
var log_file = open_log_file()
126+
var datetime = Python.import_module("datetime")
127+
128+
# Create a Mojo List to store speedup factors for averaging later
129+
var speedup_factors = List[Float64]()
130+
131+
# Display benchmark header with system information
132+
log_print("=== DeciMojo BigUInt Subtraction Benchmark ===", log_file)
133+
log_print("Time: " + String(datetime.datetime.now().isoformat()), log_file)
134+
135+
# Try to get system info
136+
try:
137+
var platform = Python.import_module("platform")
138+
log_print(
139+
"System: "
140+
+ String(platform.system())
141+
+ " "
142+
+ String(platform.release()),
143+
log_file,
144+
)
145+
log_print("Processor: " + String(platform.processor()), log_file)
146+
log_print(
147+
"Python version: " + String(platform.python_version()), log_file
148+
)
149+
except:
150+
log_print("Could not retrieve system information", log_file)
151+
152+
var iterations = 1000
153+
154+
# Define benchmark cases
155+
log_print(
156+
"\nRunning subtraction benchmarks with "
157+
+ String(iterations)
158+
+ " iterations each",
159+
log_file,
160+
)
161+
162+
# Case 26: Subtraction with 2 words - 1 word
163+
run_benchmark_subtraction(
164+
"Subtraction with 2 words - 1 word",
165+
"123456789" * 2,
166+
"987654321",
167+
iterations,
168+
log_file,
169+
speedup_factors,
170+
)
171+
172+
# Case 27: Subtraction with 4 words - 2 words
173+
run_benchmark_subtraction(
174+
"Subtraction with 4 words - 2 words",
175+
"123456789" * 4,
176+
"987654321" * 2,
177+
iterations,
178+
log_file,
179+
speedup_factors,
180+
)
181+
182+
# Case 28: Subtraction with 8 words - 4 words
183+
run_benchmark_subtraction(
184+
"Subtraction with 8 words - 4 words",
185+
"123456789" * 8,
186+
"987654321" * 4,
187+
iterations,
188+
log_file,
189+
speedup_factors,
190+
)
191+
192+
# Case 29: Subtraction with 16 words - 8 words
193+
run_benchmark_subtraction(
194+
"Subtraction with 16 words - 8 words",
195+
"123456789" * 16,
196+
"987654321" * 8,
197+
iterations,
198+
log_file,
199+
speedup_factors,
200+
)
201+
202+
# Case 30: Subtraction with 32 words - 16 words
203+
run_benchmark_subtraction(
204+
"Subtraction with 32 words - 16 words",
205+
"123456789" * 32,
206+
"987654321" * 16,
207+
iterations,
208+
log_file,
209+
speedup_factors,
210+
)
211+
212+
# Case 31: Subtraction with 64 words - 32 words
213+
run_benchmark_subtraction(
214+
"Subtraction with 64 words - 32 words",
215+
"123456789" * 64,
216+
"987654321" * 32,
217+
iterations,
218+
log_file,
219+
speedup_factors,
220+
)
221+
222+
# Case 32: Subtraction with 256 words - 128 words
223+
run_benchmark_subtraction(
224+
"Subtraction with 256 words - 128 words",
225+
"123456789" * 256,
226+
"987654321" * 128,
227+
iterations,
228+
log_file,
229+
speedup_factors,
230+
)
231+
232+
# Case 33: Subtraction with 1024 words - 512 words
233+
run_benchmark_subtraction(
234+
"Subtraction with 1024 words - 512 words",
235+
"123456789" * 1024,
236+
"987654321" * 512,
237+
iterations,
238+
log_file,
239+
speedup_factors,
240+
)
241+
242+
# Case 34: Subtraction with 4096 words - 2048 words
243+
run_benchmark_subtraction(
244+
"Subtraction with 4096 words - 2048 words",
245+
"123456789" * 4096,
246+
"987654321" * 2048,
247+
iterations,
248+
log_file,
249+
speedup_factors,
250+
)
251+
252+
# Case 35: Subtraction with 16384 words - 8192 words
253+
run_benchmark_subtraction(
254+
"Subtraction with 16384 words - 8192 words",
255+
"123456789" * 16384,
256+
"987654321" * 8192,
257+
iterations,
258+
log_file,
259+
speedup_factors,
260+
)
261+
262+
# Case 36: Subtraction with 32768 words - 16384 words
263+
run_benchmark_subtraction(
264+
"Subtraction with 32768 words - 16384 words",
265+
"123456789" * 32768,
266+
"987654321" * 16384,
267+
iterations,
268+
log_file,
269+
speedup_factors,
270+
)
271+
272+
# Calculate average speedup factor
273+
var sum_speedup: Float64 = 0.0
274+
for i in range(len(speedup_factors)):
275+
sum_speedup += speedup_factors[i]
276+
var average_speedup = sum_speedup / Float64(len(speedup_factors))
277+
278+
# Display summary
279+
log_print("\n=== BigUInt Subtraction Benchmark Summary ===", log_file)
280+
log_print("Benchmarked: different subtraction cases", log_file)
281+
log_print(
282+
"Each case ran: " + String(iterations) + " iterations", log_file
283+
)
284+
log_print("Average speedup: " + String(average_speedup) + "×", log_file)
285+
286+
# List all speedup factors
287+
log_print("\nIndividual speedup factors:", log_file)
288+
for i in range(len(speedup_factors)):
289+
log_print(
290+
String("Case {}: {}×").format(i + 1, round(speedup_factors[i], 2)),
291+
log_file,
292+
)
293+
294+
# Close the log file
295+
log_file.close()
296+
print("Benchmark completed. Log file closed.")

0 commit comments

Comments
 (0)