Skip to content

Commit d30418e

Browse files
authored
[decimal] Implement factorial() and exp() function (#30)
1 parent 07ed6bb commit d30418e

18 files changed

+1969
-168
lines changed

benches/bench_exp.mojo

Lines changed: 366 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,366 @@
1+
"""
2+
Comprehensive benchmarks for Decimal exponential function (exp).
3+
Compares performance against Python's decimal module with 20 diverse test cases.
4+
"""
5+
6+
from decimojo.prelude import dm, Decimal, RoundingMode
7+
from python import Python, PythonObject
8+
from time import perf_counter_ns
9+
import time
10+
import os
11+
from collections import List
12+
13+
14+
fn open_log_file() raises -> PythonObject:
15+
"""
16+
Creates and opens a log file with a timestamp in the filename.
17+
18+
Returns:
19+
A file object opened for writing.
20+
"""
21+
var python = Python.import_module("builtins")
22+
var datetime = Python.import_module("datetime")
23+
24+
# Create logs directory if it doesn't exist
25+
var log_dir = "./logs"
26+
if not os.path.exists(log_dir):
27+
os.makedirs(log_dir)
28+
29+
# Generate a timestamp for the filename
30+
var timestamp = String(datetime.datetime.now().isoformat())
31+
var log_filename = log_dir + "/benchmark_exp_" + timestamp + ".log"
32+
33+
print("Saving benchmark results to:", log_filename)
34+
return python.open(log_filename, "w")
35+
36+
37+
fn log_print(msg: String, log_file: PythonObject) raises:
38+
"""
39+
Prints a message to both the console and the log file.
40+
41+
Args:
42+
msg: The message to print.
43+
log_file: The file object to write to.
44+
"""
45+
print(msg)
46+
log_file.write(msg + "\n")
47+
log_file.flush() # Ensure the message is written immediately
48+
49+
50+
fn run_benchmark(
51+
name: String,
52+
input_value: String,
53+
iterations: Int,
54+
log_file: PythonObject,
55+
mut speedup_factors: List[Float64],
56+
) raises:
57+
"""
58+
Run a benchmark comparing Mojo Decimal exp with Python Decimal exp.
59+
60+
Args:
61+
name: Name of the benchmark case.
62+
input_value: String representation of value for exp(x).
63+
iterations: Number of iterations to run.
64+
log_file: File object for logging results.
65+
speedup_factors: Mojo List to store speedup factors for averaging.
66+
"""
67+
log_print("\nBenchmark: " + name, log_file)
68+
log_print("Input value: " + input_value, log_file)
69+
70+
# Set up Mojo and Python values
71+
var mojo_decimal = Decimal(input_value)
72+
var pydecimal = Python.import_module("decimal")
73+
var py_decimal = pydecimal.Decimal(input_value)
74+
var py_math = Python.import_module("math")
75+
76+
# Execute the operations once to verify correctness
77+
var mojo_result = dm.exponential.exp(mojo_decimal)
78+
var py_result = py_decimal.exp()
79+
80+
# Display results for verification
81+
log_print("Mojo result: " + String(mojo_result), log_file)
82+
log_print("Python result: " + String(py_result), log_file)
83+
84+
# Benchmark Mojo implementation
85+
var t0 = perf_counter_ns()
86+
for _ in range(iterations):
87+
_ = dm.exponential.exp(mojo_decimal)
88+
var mojo_time = (perf_counter_ns() - t0) / iterations
89+
if mojo_time == 0:
90+
mojo_time = 1 # Prevent division by zero
91+
92+
# Benchmark Python implementation
93+
t0 = perf_counter_ns()
94+
for _ in range(iterations):
95+
_ = py_decimal.exp()
96+
var python_time = (perf_counter_ns() - t0) / iterations
97+
98+
# Calculate speedup factor
99+
var speedup = python_time / mojo_time
100+
speedup_factors.append(Float64(speedup))
101+
102+
# Print results with speedup comparison
103+
log_print(
104+
"Mojo exp(): " + String(mojo_time) + " ns per iteration",
105+
log_file,
106+
)
107+
log_print(
108+
"Python exp(): " + String(python_time) + " ns per iteration",
109+
log_file,
110+
)
111+
log_print("Speedup factor: " + String(speedup), log_file)
112+
113+
114+
fn main() raises:
115+
# Open log file
116+
var log_file = open_log_file()
117+
var datetime = Python.import_module("datetime")
118+
119+
# Create a Mojo List to store speedup factors for averaging later
120+
var speedup_factors = List[Float64]()
121+
122+
# Display benchmark header with system information
123+
log_print("=== DeciMojo Exponential Function (exp) Benchmark ===", log_file)
124+
log_print("Time: " + String(datetime.datetime.now().isoformat()), log_file)
125+
126+
# Try to get system info
127+
try:
128+
var platform = Python.import_module("platform")
129+
log_print(
130+
"System: "
131+
+ String(platform.system())
132+
+ " "
133+
+ String(platform.release()),
134+
log_file,
135+
)
136+
log_print("Processor: " + String(platform.processor()), log_file)
137+
log_print(
138+
"Python version: " + String(platform.python_version()), log_file
139+
)
140+
except:
141+
log_print("Could not retrieve system information", log_file)
142+
143+
var iterations = 100
144+
var pydecimal = Python().import_module("decimal")
145+
146+
# Set Python decimal precision to match Mojo's
147+
pydecimal.getcontext().prec = 28
148+
log_print(
149+
"Python decimal precision: " + String(pydecimal.getcontext().prec),
150+
log_file,
151+
)
152+
log_print("Mojo decimal precision: " + String(Decimal.MAX_SCALE), log_file)
153+
154+
# Define benchmark cases
155+
log_print(
156+
"\nRunning exponential function benchmarks with "
157+
+ String(iterations)
158+
+ " iterations each",
159+
log_file,
160+
)
161+
162+
# Case 1: exp(0) = 1
163+
run_benchmark(
164+
"exp(0) = 1",
165+
"0",
166+
iterations,
167+
log_file,
168+
speedup_factors,
169+
)
170+
171+
# Case 2: exp(1) ≈ e
172+
run_benchmark(
173+
"exp(1) ≈ e",
174+
"1",
175+
iterations,
176+
log_file,
177+
speedup_factors,
178+
)
179+
180+
# Case 3: exp(2) ≈ 7.389...
181+
run_benchmark(
182+
"exp(2)",
183+
"2",
184+
iterations,
185+
log_file,
186+
speedup_factors,
187+
)
188+
189+
# Case 4: exp(-1) = 1/e
190+
run_benchmark(
191+
"exp(-1) = 1/e",
192+
"-1",
193+
iterations,
194+
log_file,
195+
speedup_factors,
196+
)
197+
198+
# Case 5: exp(0.5) ≈ sqrt(e)
199+
run_benchmark(
200+
"exp(0.5) ≈ sqrt(e)",
201+
"0.5",
202+
iterations,
203+
log_file,
204+
speedup_factors,
205+
)
206+
207+
# Case 6: exp(-0.5) ≈ 1/sqrt(e)
208+
run_benchmark(
209+
"exp(-0.5) ≈ 1/sqrt(e)",
210+
"-0.5",
211+
iterations,
212+
log_file,
213+
speedup_factors,
214+
)
215+
216+
# Case 7: exp with small positive value
217+
run_benchmark(
218+
"Small positive value",
219+
"0.0001",
220+
iterations,
221+
log_file,
222+
speedup_factors,
223+
)
224+
225+
# Case 8: exp with very small positive value
226+
run_benchmark(
227+
"Very small positive value",
228+
"0.000000001",
229+
iterations,
230+
log_file,
231+
speedup_factors,
232+
)
233+
234+
# Case 9: exp with small negative value
235+
run_benchmark(
236+
"Small negative value",
237+
"-0.0001",
238+
iterations,
239+
log_file,
240+
speedup_factors,
241+
)
242+
243+
# Case 10: exp with very small negative value
244+
run_benchmark(
245+
"Very small negative value",
246+
"-0.000000001",
247+
iterations,
248+
log_file,
249+
speedup_factors,
250+
)
251+
252+
# Case 11: exp with moderate value (e^3)
253+
run_benchmark(
254+
"Moderate value (e^3)",
255+
"3",
256+
iterations,
257+
log_file,
258+
speedup_factors,
259+
)
260+
261+
# Case 12: exp with moderate negative value (e^-3)
262+
run_benchmark(
263+
"Moderate negative value (e^-3)",
264+
"-3",
265+
iterations,
266+
log_file,
267+
speedup_factors,
268+
)
269+
270+
# Case 13: exp with large value (e^10)
271+
run_benchmark(
272+
"Large value (e^10)",
273+
"10",
274+
iterations,
275+
log_file,
276+
speedup_factors,
277+
)
278+
279+
# Case 14: exp with large negative value (e^-10)
280+
run_benchmark(
281+
"Large negative value (e^-10)",
282+
"-10",
283+
iterations,
284+
log_file,
285+
speedup_factors,
286+
)
287+
288+
# Case 15: exp with Pi
289+
run_benchmark(
290+
"exp(π)",
291+
"3.14159265358979323846",
292+
iterations,
293+
log_file,
294+
speedup_factors,
295+
)
296+
297+
# Case 16: exp with high precision input
298+
run_benchmark(
299+
"High precision input",
300+
"1.234567890123456789",
301+
iterations,
302+
log_file,
303+
speedup_factors,
304+
)
305+
306+
# Case 17: exp with fractional value
307+
run_benchmark(
308+
"Fractional value (e^1.5)",
309+
"1.5",
310+
iterations,
311+
log_file,
312+
speedup_factors,
313+
)
314+
315+
# Case 18: exp with negative fractional value
316+
run_benchmark(
317+
"Negative fractional value (e^-1.5)",
318+
"-1.5",
319+
iterations,
320+
log_file,
321+
speedup_factors,
322+
)
323+
324+
# Case 19: exp with approximate e value
325+
run_benchmark(
326+
"Approximate e value",
327+
"2.718281828459045",
328+
iterations,
329+
log_file,
330+
speedup_factors,
331+
)
332+
333+
# Case 20: exp with larger value (e^15)
334+
run_benchmark(
335+
"Larger value (e^15)",
336+
"15",
337+
iterations,
338+
log_file,
339+
speedup_factors,
340+
)
341+
342+
# Calculate average speedup factor
343+
var sum_speedup: Float64 = 0.0
344+
for i in range(len(speedup_factors)):
345+
sum_speedup += speedup_factors[i]
346+
var average_speedup = sum_speedup / Float64(len(speedup_factors))
347+
348+
# Display summary
349+
log_print("\n=== Exponential Function Benchmark Summary ===", log_file)
350+
log_print("Benchmarked: 20 different exp() cases", log_file)
351+
log_print(
352+
"Each case ran: " + String(iterations) + " iterations", log_file
353+
)
354+
log_print("Average speedup: " + String(average_speedup) + "×", log_file)
355+
356+
# List all speedup factors
357+
log_print("\nIndividual speedup factors:", log_file)
358+
for i in range(len(speedup_factors)):
359+
log_print(
360+
String("Case {}: {}×").format(i + 1, round(speedup_factors[i], 2)),
361+
log_file,
362+
)
363+
364+
# Close the log file
365+
log_file.close()
366+
print("Benchmark completed. Log file closed.")

benches/bench_multiply.mojo

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,9 +297,24 @@ fn main() raises:
297297
log_file,
298298
)
299299

300+
# Case 11: Decimal multiplication with many digits after the decimal point
301+
var case11_a_mojo = Decimal.E()
302+
var case11_b_mojo = Decimal.E05()
303+
var case11_a_py = pydecimal.Decimal("1").exp()
304+
var case11_b_py = pydecimal.Decimal("0.5").exp()
305+
run_benchmark(
306+
"e * e^0.5",
307+
case11_a_mojo,
308+
case11_b_mojo,
309+
case11_a_py,
310+
case11_b_py,
311+
iterations,
312+
log_file,
313+
)
314+
300315
# Display summary
301316
log_print("\n=== Multiplication Benchmark Summary ===", log_file)
302-
log_print("Benchmarked: 10 different multiplication cases", log_file)
317+
log_print("Benchmarked: 11 different multiplication cases", log_file)
303318
log_print(
304319
"Each case ran: " + String(iterations) + " iterations", log_file
305320
)

docs/todo.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# TODO
2+
3+
This is a to-do list for Yuhao's personal use.
4+
5+
- The `exp()` function performs slower than Python's counterpart in specific cases. Detailed investigation reveals the bottleneck stems from multiplication operations between decimals with significant fractional components. These operations currently rely on UInt256 arithmetic, which introduces performance overhead. Optimization of the `multiply()` function is required to address these performance bottlenecks, particularly for high-precision decimal multiplication with many digits after the decimal point.

0 commit comments

Comments
 (0)