From 53b9596f6619f784a57abba138ff14232cfcf437 Mon Sep 17 00:00:00 2001 From: shivasankar Date: Sun, 8 Mar 2026 17:31:42 +0900 Subject: [PATCH 1/5] fix format --- benchmark/bench_derivative.mojo | 4 +++- scijo/fft/fastfourier.mojo | 5 +---- tests/test_differentiate.mojo | 13 ++++++++++--- tests/test_fft.mojo | 1 + tests/test_interpolate.mojo | 9 +++++---- tests/test_quad.mojo | 29 ++++++++++++++++++++++------- tests/test_trapezoid.mojo | 1 + 7 files changed, 43 insertions(+), 19 deletions(-) diff --git a/benchmark/bench_derivative.mojo b/benchmark/bench_derivative.mojo index 102f6f7..05605ac 100644 --- a/benchmark/bench_derivative.mojo +++ b/benchmark/bench_derivative.mojo @@ -91,7 +91,9 @@ fn poly_func[ fn sin_func[ dtype: DType -](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): +](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype +] where dtype.is_floating_point(): return sin(x) diff --git a/scijo/fft/fastfourier.mojo b/scijo/fft/fastfourier.mojo index 17838e8..bdc9eb2 100644 --- a/scijo/fft/fastfourier.mojo +++ b/scijo/fft/fastfourier.mojo @@ -144,10 +144,7 @@ fn _ifft_unnormalized[ for k in range(half_size): var angle = ( - 2.0 - * Constants.pi - * Scalar[dtype.dtype](k) - / Scalar[dtype.dtype](n) + 2.0 * Constants.pi * Scalar[dtype.dtype](k) / Scalar[dtype.dtype](n) ) var twiddle = ComplexSIMD[dtype]( cos(angle).cast[dtype.dtype](), sin(angle).cast[dtype.dtype]() diff --git a/tests/test_differentiate.mojo b/tests/test_differentiate.mojo index c17457c..91b115a 100644 --- a/tests/test_differentiate.mojo +++ b/tests/test_differentiate.mojo @@ -43,7 +43,9 @@ fn cubic_function[ fn sin_function[ dtype: DType -](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): +](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype +] where dtype.is_floating_point(): """ F(x) = sin(x), f'(x) = cos(x). """ @@ -52,7 +54,9 @@ fn sin_function[ fn cos_function[ dtype: DType -](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): +](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype +] where dtype.is_floating_point(): """ F(x) = cos(x), f'(x) = -sin(x). """ @@ -61,7 +65,9 @@ fn cos_function[ fn exp_function[ dtype: DType -](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): +](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype +] where dtype.is_floating_point(): """ F(x) = e^x, f'(x) = e^x. """ @@ -384,5 +390,6 @@ fn test_step_size_parameters() raises: msg="Result should be consistent across step factors", ) + def main(): TestSuite.discover_tests[__functions_in_module()]().run() diff --git a/tests/test_fft.mojo b/tests/test_fft.mojo index 46ad578..22e57fb 100644 --- a/tests/test_fft.mojo +++ b/tests/test_fft.mojo @@ -263,5 +263,6 @@ fn test_error_conditions() raises: except: print("Non-power-of-2 error handling - PASSED") + def main(): TestSuite.discover_tests[__functions_in_module()]().run() diff --git a/tests/test_interpolate.mojo b/tests/test_interpolate.mojo index aa79245..cb25dee 100644 --- a/tests/test_interpolate.mojo +++ b/tests/test_interpolate.mojo @@ -313,9 +313,7 @@ fn test_edge_cases() raises: ) # Test with very small intervals - var x_small = nm.array[nm.f64]( - [0.0, 1e-10, 2e-10], [3] - ) + var x_small = nm.array[nm.f64]([0.0, 1e-10, 2e-10], [3]) var y_small = nm.array[nm.f64]([0.0, 1.0, 2.0], [3]) var interp_small = LinearInterpolator(x_small, y_small) @@ -327,7 +325,9 @@ fn test_edge_cases() raises: var small_test_point = 1.5e-10 var result_small = interp_small(small_test_point) - var scipy_small = Float64(py=py_interp_small(PythonObject(small_test_point))) + var scipy_small = Float64( + py=py_interp_small(PythonObject(small_test_point)) + ) assert_almost_equal( result_small, scipy_small, @@ -516,5 +516,6 @@ fn test_performance_comparison() raises: msg="Large dataset interpolation should match SciPy", ) + def main(): TestSuite.discover_tests[__functions_in_module()]().run() diff --git a/tests/test_quad.mojo b/tests/test_quad.mojo index 4fc96ca..da13f41 100644 --- a/tests/test_quad.mojo +++ b/tests/test_quad.mojo @@ -56,7 +56,9 @@ def test_quad_trigonometric(): # Test ∫sin(x) dx from 0 to π = 2 fn sine[ dtype: DType - ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype + ] where dtype.is_floating_point(): return sin(x) var result = quad[sj.f64, sine](0.0, pi, None) @@ -66,7 +68,9 @@ def test_quad_trigonometric(): # Test ∫cos(x) dx from 0 to π/2 = 1 fn cosine[ dtype: DType - ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype + ] where dtype.is_floating_point(): return cos(x) var result2 = quad[sj.f64, cosine](0.0, pi / 2.0, None) @@ -76,7 +80,9 @@ def test_quad_trigonometric(): # Test ∫sin²(x) dx from 0 to π = π/2 fn sin_squared[ dtype: DType - ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype + ] where dtype.is_floating_point(): var s = sin(x) return s * s @@ -91,7 +97,9 @@ def test_quad_exponential(): # Test ∫e^x dx from 0 to 1 = e - 1 fn exponential[ dtype: DType - ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype + ] where dtype.is_floating_point(): return exp(x) var result = quad[sj.f64, exponential](0.0, 1.0, None) @@ -102,7 +110,9 @@ def test_quad_exponential(): # Test ∫e^(-x) dx from 0 to ∞ ≈ 1 (using large upper bound) fn exp_decay[ dtype: DType - ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype + ] where dtype.is_floating_point(): return exp(-x) var result2 = quad[sj.f64, exp_decay]( @@ -166,7 +176,9 @@ def test_quad_difficult_integrands(): # Using finite bounds that approximate infinity fn gaussian[ dtype: DType - ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype + ] where dtype.is_floating_point(): return exp(-x * x) var result = quad[sj.f64, gaussian](-5.0, 5.0, None, epsrel=1e-8) @@ -176,7 +188,9 @@ def test_quad_difficult_integrands(): # Test oscillatory function sin(x)/x near origin (needs careful handling) fn sinc_like[ dtype: DType - ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype + ] where dtype.is_floating_point(): if abs(x) < 1e-10: return 1.0 # limit as x→0 of sin(x)/x = 1 return sin(x) / x @@ -245,5 +259,6 @@ def test_quad_difficult_integrands(): # # This should integrate to near zero due to oscillation # assert_almost_equal(result2.integral, 0.0, atol=1e-6) + def main(): TestSuite.discover_tests[__functions_in_module()]().run() diff --git a/tests/test_trapezoid.mojo b/tests/test_trapezoid.mojo index 6a76e95..1944532 100644 --- a/tests/test_trapezoid.mojo +++ b/tests/test_trapezoid.mojo @@ -187,5 +187,6 @@ fn test_error_conditions() raises: except: pass + def main(): TestSuite.discover_tests[__functions_in_module()]().run() From dacc86890d01bb25c390198f3d9a1033423ddebc Mon Sep 17 00:00:00 2001 From: shivasankar Date: Fri, 13 Mar 2026 14:29:46 +0900 Subject: [PATCH 2/5] clean up optimzie --- pixi.toml | 6 +-- scijo/integrate/fixed_sample.mojo | 23 +--------- scijo/optimize/root_scalar.mojo | 72 +++++++++++++++++++------------ scijo/optimize/utility.mojo | 3 ++ 4 files changed, 52 insertions(+), 52 deletions(-) diff --git a/pixi.toml b/pixi.toml index 9be9db7..3148618 100644 --- a/pixi.toml +++ b/pixi.toml @@ -42,12 +42,12 @@ backend = { name = "pixi-build-mojo", version = "0.*"} name = "scijo" [package.host-dependencies] -modular = ">=26.2.0.dev2026022717,<27" +modular = ">=26.2.0.dev2026030605,<27" [package.build-dependencies] -modular = ">=26.2.0.dev2026022717,<27" +modular = ">=26.2.0.dev2026030605,<27" numojo = { git = "https://github.com/shivasankarka/NuMojo.git", branch = "pre-0.9"} [package.run-dependencies] -modular = ">=26.2.0.dev2026022717,<27" +modular = ">=26.2.0.dev2026030605,<27" numojo = { git = "https://github.com/shivasankarka/NuMojo.git", branch = "pre-0.9"} diff --git a/scijo/integrate/fixed_sample.mojo b/scijo/integrate/fixed_sample.mojo index fdc9f38..d046961 100644 --- a/scijo/integrate/fixed_sample.mojo +++ b/scijo/integrate/fixed_sample.mojo @@ -12,7 +12,6 @@ Includes the composite trapezoidal rule, Simpson's rule, and Romberg integration """ from numojo.core.ndarray import NDArray, NDArrayShape -from numojo.core.error import NumojoError import numojo as nm @@ -47,28 +46,10 @@ fn trapezoid[ """ if y.ndim != 1: - raise Error( - NumojoError( - category="shape", - message=String( - "Expected y to be 1-D, received ndim={}. Pass a 1-D NDArray" - " for y (e.g. shape (N,))." - ).format(y.ndim), - location="trapezoid(y, dx=1.0)", - ) - ) + raise Error(t"Scijo [trapezoid]: Expected y to be 1-D array, received ndim={y.ndim}.") if y.size == 0: - raise Error( - NumojoError( - category="value", - message=( - "Cannot integrate over an empty array. Provide a non-empty" - " array for y." - ), - location="trapezoid(y, dx=1.0)", - ) - ) + raise Error(t"Scijo [trapezoid]: y.size = 0, Cannot interage over an empty array.") if y.size == 1: return Scalar[dtype](0.0) diff --git a/scijo/optimize/root_scalar.mojo b/scijo/optimize/root_scalar.mojo index 848f336..e6b2e47 100644 --- a/scijo/optimize/root_scalar.mojo +++ b/scijo/optimize/root_scalar.mojo @@ -14,6 +14,9 @@ methods (secant). # TODO: check if we are using the right tolerance conditions in all methods. +# ===----------------------------------------------------------------------=== # +# Root scalar +# ===----------------------------------------------------------------------=== # fn root_scalar[ dtype: DType, @@ -27,6 +30,7 @@ fn root_scalar[ dtype ] ] = None, + *, method: String = "bisect", ]( args: Optional[List[Scalar[dtype]]] = None, @@ -36,7 +40,6 @@ fn root_scalar[ xtol: Scalar[dtype] = 1e-8, rtol: Scalar[dtype] = 1e-8, maxiter: Int = 100, - # options: SolverOptions ) raises -> Scalar[dtype]: """Finds a root of a scalar function using the specified method. @@ -55,32 +58,39 @@ fn root_scalar[ rtol: Relative tolerance for convergence. maxiter: Maximum number of iterations. - Returns: - The approximate root as a Scalar[dtype]. - Raises: Error: If required inputs for the chosen method are missing or invalid. - """ + Returns: + The approximate root of the function. + """ @parameter - if method == "newton" and fprime: + if method == "newton": + if not fprime: + raise Error( + "Scijo [root_scalar]: Derivative fprime must be provided for Newton's method." + ) return newton[dtype, f, fprime.value()](args, x0, xtol, rtol, maxiter) elif method == "bisect": if not bracket: - raise Error("Bracket must be provided for bisection method.") + raise Error("Scijo [root_scalar]: Bracket must be provided for bisection method.") return bisect[dtype, f](args, bracket.value(), xtol, rtol, maxiter) elif method == "secant": if not (x0 and x1): raise Error( - "Initial guesses x0 and x1 must be provided for secant method." + "Scijo [root_scalar]: Initial guesses x0 and x1 must be provided for secant method." ) return secant[dtype, f]( args, x0.value(), x1.value(), xtol, rtol, maxiter ) else: - raise Error("Unsupported method: " + String(method)) + raise Error("Scijo [root_scalar]: Unsupported method: " + String(method)) +# ===----------------------------------------------------------------------=== # +# Root scalar methods +# ===----------------------------------------------------------------------=== # + fn newton[ dtype: DType, f: fn[dtype: DType]( @@ -115,17 +125,17 @@ fn newton[ rtol: Relative tolerance for convergence. maxiter: Maximum number of iterations. - Returns: - The approximate root as a Scalar[dtype]. - Raises: Error: If x0 is not provided or the derivative is zero at any step. + + Returns: + The approximate root as a Scalar[dtype]. """ var xn: Scalar[dtype] if x0: xn = x0.value() else: - raise Error("Initial guess x0 must be provided for Newton's method.") + raise Error("Scijo [newton]: Initial guess x0 must be provided for Newton's method.") for _ in range(maxiter): var fx = f(xn, args) @@ -133,7 +143,7 @@ fn newton[ if fpx == 0: raise Error( - "Derivative is zero. Newton-Raphson step would divide by zero." + "Scijo [newton]: Derivative is zero. Newton-Raphson step would divide by zero." ) var delta = fx / fpx @@ -181,11 +191,11 @@ fn bisect[ rtol: Relative tolerance for convergence. maxiter: Maximum number of iterations. - Returns: - The approximate root as a Scalar[dtype]. - Raises: Error: If f(a) and f(b) do not have opposite signs. + + Returns: + The approximate root as a Scalar[dtype]. """ var a: Scalar[dtype] = bracket[0] var b: Scalar[dtype] = bracket[1] @@ -200,7 +210,7 @@ fn bisect[ if fa * fb > 0: raise Error( - "f(a) and f(b) must have opposite signs (bracket does not enclose a" + "Scijo [newton]: f(a) and f(b) must have opposite signs (bracket does not enclose a" " root)." ) @@ -212,16 +222,15 @@ fn bisect[ return c var tol_x = max(xtol, rtol * abs(c)) - var half_width = (b - a) / 2 + var half_width = abs(b - a) / 2 if half_width <= tol_x: return c if fa * fc < 0: - # Root is in [a, c] b = c else: - # Root is in [c, b] a = c + fa = fc return (a + b) / 2 @@ -238,7 +247,7 @@ fn secant[ xtol: Scalar[dtype] = 1e-8, rtol: Scalar[dtype] = 1e-8, maxiter: Int = 100, -) -> Scalar[dtype]: +) raises -> Scalar[dtype]: """Finds a root using the secant method. A derivative-free method that approximates the derivative using finite @@ -256,6 +265,9 @@ fn secant[ rtol: Relative tolerance for convergence. maxiter: Maximum number of iterations. + Raises: + Error: If zero slope is encountered. + Returns: The approximate root as a Scalar[dtype]. """ @@ -263,12 +275,16 @@ fn secant[ var b: Scalar[dtype] = x1 for _ in range(maxiter): - var f0 = f[dtype](a, args) - var f1 = f[dtype](b, args) + var f0 = f(a, args) + var f1 = f(b, args) + + var denom = f1 - f0 + if denom == 0: + raise Error("Scijo [newton]: Secant method encountered zero slope (f1 - f0 == 0).") - var xn = b - (f1 * (b - a)) / (f1 - f0) + var xn = b - (f1 * (b - a)) / denom - var fxn = f[dtype](xn, args) + var fxn = f(xn, args) var tol_x = max(xtol, rtol * abs(xn)) var tol_f = max(xtol, rtol * abs(fxn)) @@ -278,7 +294,7 @@ fn secant[ if abs(xn - b) <= tol_x: return xn + a = b b = xn - a = x1 - return x1 + return b diff --git a/scijo/optimize/utility.mojo b/scijo/optimize/utility.mojo index 3c5ccd9..8cd7eb3 100644 --- a/scijo/optimize/utility.mojo +++ b/scijo/optimize/utility.mojo @@ -10,6 +10,9 @@ Data structures for returning results from optimization and root-finding routines. """ +# ===----------------------------------------------------------------------=== # +# RootResults +# ===----------------------------------------------------------------------=== # struct RootResults[dtype: DType = DType.float64](): """Result structure for scalar root-finding operations. From 6b94d37b8847fb20facaec35c52fe7609e4d03bc Mon Sep 17 00:00:00 2001 From: shivasankar Date: Fri, 13 Mar 2026 14:31:19 +0900 Subject: [PATCH 3/5] fix format --- scijo/integrate/fixed_sample.mojo | 10 +++++++-- scijo/optimize/root_scalar.mojo | 35 +++++++++++++++++++++++-------- scijo/optimize/utility.mojo | 1 + 3 files changed, 35 insertions(+), 11 deletions(-) diff --git a/scijo/integrate/fixed_sample.mojo b/scijo/integrate/fixed_sample.mojo index d046961..3341c8d 100644 --- a/scijo/integrate/fixed_sample.mojo +++ b/scijo/integrate/fixed_sample.mojo @@ -46,10 +46,16 @@ fn trapezoid[ """ if y.ndim != 1: - raise Error(t"Scijo [trapezoid]: Expected y to be 1-D array, received ndim={y.ndim}.") + raise Error( + "Scijo [trapezoid]: Expected y to be 1-D array, received" + t" ndim={y.ndim}." + ) if y.size == 0: - raise Error(t"Scijo [trapezoid]: y.size = 0, Cannot interage over an empty array.") + raise Error( + t"Scijo [trapezoid]: y.size = 0, Cannot interage over an empty" + t" array." + ) if y.size == 1: return Scalar[dtype](0.0) diff --git a/scijo/optimize/root_scalar.mojo b/scijo/optimize/root_scalar.mojo index e6b2e47..eda4365 100644 --- a/scijo/optimize/root_scalar.mojo +++ b/scijo/optimize/root_scalar.mojo @@ -18,6 +18,7 @@ methods (secant). # Root scalar # ===----------------------------------------------------------------------=== # + fn root_scalar[ dtype: DType, f: fn[dtype: DType]( @@ -64,33 +65,42 @@ fn root_scalar[ Returns: The approximate root of the function. """ + @parameter if method == "newton": if not fprime: raise Error( - "Scijo [root_scalar]: Derivative fprime must be provided for Newton's method." + "Scijo [root_scalar]: Derivative fprime must be provided for" + " Newton's method." ) return newton[dtype, f, fprime.value()](args, x0, xtol, rtol, maxiter) elif method == "bisect": if not bracket: - raise Error("Scijo [root_scalar]: Bracket must be provided for bisection method.") + raise Error( + "Scijo [root_scalar]: Bracket must be provided for bisection" + " method." + ) return bisect[dtype, f](args, bracket.value(), xtol, rtol, maxiter) elif method == "secant": if not (x0 and x1): raise Error( - "Scijo [root_scalar]: Initial guesses x0 and x1 must be provided for secant method." + "Scijo [root_scalar]: Initial guesses x0 and x1 must be" + " provided for secant method." ) return secant[dtype, f]( args, x0.value(), x1.value(), xtol, rtol, maxiter ) else: - raise Error("Scijo [root_scalar]: Unsupported method: " + String(method)) + raise Error( + "Scijo [root_scalar]: Unsupported method: " + String(method) + ) # ===----------------------------------------------------------------------=== # # Root scalar methods # ===----------------------------------------------------------------------=== # + fn newton[ dtype: DType, f: fn[dtype: DType]( @@ -135,7 +145,10 @@ fn newton[ if x0: xn = x0.value() else: - raise Error("Scijo [newton]: Initial guess x0 must be provided for Newton's method.") + raise Error( + "Scijo [newton]: Initial guess x0 must be provided for Newton's" + " method." + ) for _ in range(maxiter): var fx = f(xn, args) @@ -143,7 +156,8 @@ fn newton[ if fpx == 0: raise Error( - "Scijo [newton]: Derivative is zero. Newton-Raphson step would divide by zero." + "Scijo [newton]: Derivative is zero. Newton-Raphson step would" + " divide by zero." ) var delta = fx / fpx @@ -210,8 +224,8 @@ fn bisect[ if fa * fb > 0: raise Error( - "Scijo [newton]: f(a) and f(b) must have opposite signs (bracket does not enclose a" - " root)." + "Scijo [newton]: f(a) and f(b) must have opposite signs (bracket" + " does not enclose a root)." ) for _ in range(maxiter): @@ -280,7 +294,10 @@ fn secant[ var denom = f1 - f0 if denom == 0: - raise Error("Scijo [newton]: Secant method encountered zero slope (f1 - f0 == 0).") + raise Error( + "Scijo [newton]: Secant method encountered zero slope (f1 - f0" + " == 0)." + ) var xn = b - (f1 * (b - a)) / denom diff --git a/scijo/optimize/utility.mojo b/scijo/optimize/utility.mojo index 8cd7eb3..48d043a 100644 --- a/scijo/optimize/utility.mojo +++ b/scijo/optimize/utility.mojo @@ -14,6 +14,7 @@ Data structures for returning results from optimization and root-finding routine # RootResults # ===----------------------------------------------------------------------=== # + struct RootResults[dtype: DType = DType.float64](): """Result structure for scalar root-finding operations. From 0c8def64e4b2dbb48cf9e88fb1f8660c5585a147 Mon Sep 17 00:00:00 2001 From: shivasankar Date: Fri, 13 Mar 2026 17:40:14 +0900 Subject: [PATCH 4/5] fix t string error --- scijo/integrate/fixed_sample.mojo | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scijo/integrate/fixed_sample.mojo b/scijo/integrate/fixed_sample.mojo index 3341c8d..a0b760e 100644 --- a/scijo/integrate/fixed_sample.mojo +++ b/scijo/integrate/fixed_sample.mojo @@ -47,8 +47,8 @@ fn trapezoid[ if y.ndim != 1: raise Error( - "Scijo [trapezoid]: Expected y to be 1-D array, received" - t" ndim={y.ndim}." + t"Scijo [trapezoid]: Expected y to be 1-D array, received" + t" ndim={{y.ndim}}." ) if y.size == 0: From 1363a308c180df4b6c4206f0a38508e876837cd8 Mon Sep 17 00:00:00 2001 From: shivasankar Date: Fri, 13 Mar 2026 17:40:42 +0900 Subject: [PATCH 5/5] add root_scalar tests --- tests/test_root_scalar.mojo | 62 +++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 tests/test_root_scalar.mojo diff --git a/tests/test_root_scalar.mojo b/tests/test_root_scalar.mojo new file mode 100644 index 0000000..d8ed1d7 --- /dev/null +++ b/tests/test_root_scalar.mojo @@ -0,0 +1,62 @@ +from testing import assert_almost_equal, assert_equal +from testing import TestSuite +from math import sqrt +import scijo as sj +from scijo.optimize.root_scalar import root_scalar, newton, bisect, secant + + +fn test_bisect_root_scalar_basic() raises: + fn f[ + dtype: DType + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype]: + return x * x - 2.0 + + var root = root_scalar[sj.f64, f](bracket=(0.0, 2.0)) + assert_almost_equal(root, sqrt(2.0), atol=1e-8) + + var root2 = bisect[sj.f64, f](None, (0.0, 2.0)) + assert_almost_equal(root2, sqrt(2.0), atol=1e-8) + + +fn test_newton_basic() raises: + fn f[ + dtype: DType + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype]: + return x * x - 2.0 + + fn fprime[ + dtype: DType + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype]: + return 2.0 * x + + var root = newton[sj.f64, f, fprime](None, x0=1.0, xtol=1e-12, rtol=1e-12) + assert_almost_equal(root, sqrt(2.0), atol=1e-10) + + +fn test_secant_basic() raises: + fn f[ + dtype: DType + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype]: + return x * x - 2.0 + + var root = secant[sj.f64, f](None, 1.0, 2.0) + assert_almost_equal(root, sqrt(2.0), atol=1e-8) + + +fn test_bisect_invalid_bracket_raises() raises: + fn g[ + dtype: DType + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype]: + return x * x + 1.0 + + try: + var _ = bisect[sj.f64, g](None, (0.0, 1.0)) + assert_equal( + True, False, msg="Expected bisect to raise on invalid bracket." + ) + except: + pass + + +def main(): + TestSuite.discover_tests[__functions_in_module()]().run()