diff --git a/benchmark/bench_derivative.mojo b/benchmark/bench_derivative.mojo index 102f6f7..05605ac 100644 --- a/benchmark/bench_derivative.mojo +++ b/benchmark/bench_derivative.mojo @@ -91,7 +91,9 @@ fn poly_func[ fn sin_func[ dtype: DType -](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): +](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype +] where dtype.is_floating_point(): return sin(x) diff --git a/pixi.toml b/pixi.toml index 9be9db7..3148618 100644 --- a/pixi.toml +++ b/pixi.toml @@ -42,12 +42,12 @@ backend = { name = "pixi-build-mojo", version = "0.*"} name = "scijo" [package.host-dependencies] -modular = ">=26.2.0.dev2026022717,<27" +modular = ">=26.2.0.dev2026030605,<27" [package.build-dependencies] -modular = ">=26.2.0.dev2026022717,<27" +modular = ">=26.2.0.dev2026030605,<27" numojo = { git = "https://github.com/shivasankarka/NuMojo.git", branch = "pre-0.9"} [package.run-dependencies] -modular = ">=26.2.0.dev2026022717,<27" +modular = ">=26.2.0.dev2026030605,<27" numojo = { git = "https://github.com/shivasankarka/NuMojo.git", branch = "pre-0.9"} diff --git a/scijo/fft/__init__.mojo b/scijo/fft/__init__.mojo index 2aaed1d..803c3aa 100644 --- a/scijo/fft/__init__.mojo +++ b/scijo/fft/__init__.mojo @@ -11,9 +11,4 @@ The `fft` module provides Fast Fourier Transform operations for complex-valued arrays. It includes forward and inverse FFT using the Cooley-Tukey algorithm. """ -from numojo.core.complex import ComplexNDArray, ComplexSIMD -from numojo.core.ndarray import NDArray -from numojo.core.layout import NDArrayShape -from numojo.routines.constants import Constants - from .fastfourier import fft, ifft diff --git a/scijo/fft/fastfourier.mojo b/scijo/fft/fastfourier.mojo index 17838e8..00f8753 100644 --- a/scijo/fft/fastfourier.mojo +++ b/scijo/fft/fastfourier.mojo @@ -11,6 +11,8 @@ Forward and inverse Fast Fourier Transform using the Cooley-Tukey radix-2 decimation-in-time algorithm for 1-D complex arrays with power-of-2 lengths. """ +from math import sin, cos + from numojo.core.complex import ComplexNDArray, ComplexSIMD from numojo.core.dtype import ComplexDType from numojo.core.ndarray import NDArray @@ -18,7 +20,9 @@ from numojo.core.layout import NDArrayShape from numojo.routines.constants import Constants from numojo.core.indexing import Item -from math import sin, cos +# ===----------------------------------------------------------------------=== # +# FFT +# ===----------------------------------------------------------------------=== # fn fft[ @@ -38,16 +42,26 @@ fn fft[ arr: Input complex array to transform. Must be 1-dimensional with length that is a power of 2. + Raises: + Error: If the input array is not 1-dimensional. + Error: If the array length is not a power of 2. + Returns: ComplexNDArray containing the FFT of the input array with the same shape and dtype. - Raises: - Error: If the input array is not 1-dimensional. - Error: If the array length is not a power of 2. + Examples: + ```mojo + import numojo as nm + from scijo.fft import fft + from scijo.prelude import * + + var arr = nm.linspace[cf32](CScalar[cf32](0, 0), CScalar[cf32](10, 10), num=10) + var fft_arr = fft(arr) + ``` """ if arr.ndim != 1: - raise Error("FFT currently only supports 1D arrays") + raise Error("Scijo [fft]: FFT currently only supports 1D arrays") var n: Int = arr.shape[0] if n <= 1: @@ -55,8 +69,8 @@ fn fft[ if (n & (n - 1)) != 0: raise Error( - "FFT currently only supports arrays with length that is a power" - " of 2" + "Scijo [fft]: FFT currently only supports arrays with length that" + " is a power of 2" ) var half_size = n // 2 @@ -109,15 +123,15 @@ fn _ifft_unnormalized[ arr: Input complex array to transform. Must be 1-dimensional with length that is a power of 2. - Returns: - ComplexNDArray containing the unnormalized inverse FFT of the input array. - Raises: Error: If the input array is not 1-dimensional. Error: If the array length is not a power of 2. + + Returns: + ComplexNDArray containing the unnormalized inverse FFT of the input array. """ if arr.ndim != 1: - raise Error("FFT currently only supports 1D arrays") + raise Error("Scijo [fft]: FFT currently only supports 1D arrays") var n: Int = arr.shape[0] if n <= 1: @@ -125,8 +139,8 @@ fn _ifft_unnormalized[ if (n & (n - 1)) != 0: raise Error( - "FFT currently only supports arrays with length that is a power" - " of 2" + "Scijo [fft]: FFT currently only supports arrays with length that" + " is a power of 2" ) var half_size = n // 2 @@ -144,10 +158,7 @@ fn _ifft_unnormalized[ for k in range(half_size): var angle = ( - 2.0 - * Constants.pi - * Scalar[dtype.dtype](k) - / Scalar[dtype.dtype](n) + 2.0 * Constants.pi * Scalar[dtype.dtype](k) / Scalar[dtype.dtype](n) ) var twiddle = ComplexSIMD[dtype]( cos(angle).cast[dtype.dtype](), sin(angle).cast[dtype.dtype]() @@ -161,6 +172,11 @@ fn _ifft_unnormalized[ return result^ +# ===----------------------------------------------------------------------=== # +# IFFT +# ===----------------------------------------------------------------------=== # + + fn ifft[ dtype: ComplexDType = ComplexDType.float64 ](arr: ComplexNDArray[dtype]) raises -> ComplexNDArray[ @@ -178,13 +194,24 @@ fn ifft[ arr: Input complex array to transform. Must be 1-dimensional with length that is a power of 2. + Raises: + Error: If the input array is not 1-dimensional. + Error: If the array length is not a power of 2. + Returns: ComplexNDArray containing the IFFT of the input array with the same shape and dtype. - Raises: - Error: If the input array is not 1-dimensional. - Error: If the array length is not a power of 2. + Examples: + ```mojo + import numojo as nm + from scijo.fft import fft, ifft + from scijo.prelude import * + + var arr = nm.linspace[cf32](CScalar[cf32](0, 0), CScalar[cf32](10, 10), num=10) + var freq = fft(arr) + var time = ifft(freq) + ``` """ var n: Int = arr.shape[0] var result = _ifft_unnormalized[dtype](arr) diff --git a/scijo/integrate/__init__.mojo b/scijo/integrate/__init__.mojo index ac99d0e..322e865 100644 --- a/scijo/integrate/__init__.mojo +++ b/scijo/integrate/__init__.mojo @@ -10,7 +10,12 @@ The `integrate` module provides tools for numerical integration and quadrature. It includes adaptive and non-adaptive methods for computing definite integrals, as well as fixed-sample integration rules for discrete data. + +Examples: + ```mojo + from scijo.integrate import quad, trapezoid + ``` """ -from .quad import quad +from .quadrature import quad from .fixed_sample import trapezoid, simpson, romb diff --git a/scijo/integrate/fixed_sample.mojo b/scijo/integrate/fixed_sample.mojo index fdc9f38..ecc2a84 100644 --- a/scijo/integrate/fixed_sample.mojo +++ b/scijo/integrate/fixed_sample.mojo @@ -12,10 +12,13 @@ Includes the composite trapezoidal rule, Simpson's rule, and Romberg integration """ from numojo.core.ndarray import NDArray, NDArrayShape -from numojo.core.error import NumojoError import numojo as nm +# ===----------------------------------------------------------------------=== # +# Trapezoid +# ===----------------------------------------------------------------------=== # + fn trapezoid[ dtype: DType ]( @@ -37,37 +40,35 @@ fn trapezoid[ dx: The spacing between sample points. Defaults to 1.0. axis: The axis along which to integrate. Currently only 1-D is supported. + Raises: + Error: If y is not 1-D. + Error: If y is empty. + Returns: Definite integral approximated by the trapezoidal rule. Returns 0.0 for arrays with fewer than 2 elements. - Raises: - Error: If y is not 1-D. - Error: If y is empty. + Examples: + ```mojo + import numojo as nm + from scijo.integrate import trapezoid + from scijo.prelude import * + + var y = nm.linspace[f64](0.0, 10.0, 100) ** 2 # y = x^2 sampled at 100 points from 0 to 10 + var area = trapezoid(y, dx=0.1) + ``` """ if y.ndim != 1: raise Error( - NumojoError( - category="shape", - message=String( - "Expected y to be 1-D, received ndim={}. Pass a 1-D NDArray" - " for y (e.g. shape (N,))." - ).format(y.ndim), - location="trapezoid(y, dx=1.0)", - ) + t"Scijo [trapezoid]: Expected y to be 1-D array, received" + t" ndim={{y.ndim}}." ) if y.size == 0: raise Error( - NumojoError( - category="value", - message=( - "Cannot integrate over an empty array. Provide a non-empty" - " array for y." - ), - location="trapezoid(y, dx=1.0)", - ) + t"Scijo [trapezoid]: y.size = 0, Cannot interage over an empty" + t" array." ) if y.size == 1: @@ -103,13 +104,24 @@ fn trapezoid[ x: Array of sample points corresponding to the y values. axis: The axis along which to integrate. Currently only 1-D is supported. + Raises: + Error: If y or x are not 1-D, or if their sizes differ. + Error: If y is empty. + Returns: Definite integral approximated by the trapezoidal rule. Returns 0.0 for arrays with fewer than 2 elements. - Raises: - Error: If y or x are not 1-D, or if their sizes differ. - Error: If y is empty. + Examples: + ```mojo + import numojo as nm + from scijo.integrate import trapezoid + from scijo.prelude import * + + var x = nm.linspace[f64](0.0, 10.0, 100) + var y = x * x + var area = trapezoid(y, x) + ``` """ if y.ndim != 1: raise Error( @@ -176,6 +188,10 @@ fn trapezoid[ return integral +# ===----------------------------------------------------------------------=== # +# Simpson +# ===----------------------------------------------------------------------=== # + fn simpson[ dtype: DType ]( @@ -197,11 +213,21 @@ fn simpson[ dx: The spacing between sample points. Defaults to 1.0. axis: The axis along which to integrate. Currently only 1-D is supported. + Raises: + Error: If y is not 1-D. + Returns: Definite integral approximated by Simpson's rule. - Raises: - Error: If y is not 1-D. + Examples: + ```mojo + import numojo as nm + from scijo.integrate import simpson + from scijo.prelude import * + + var y = nm.linspace[f64](0.0, 10.0, 100) ** 2 # y = x^2 sampled at 100 points from 0 to 10 + var area = simpson(y, dx=0.1) + ``` """ if y.ndim != 1: raise Error( @@ -251,11 +277,22 @@ fn simpson[ x: Array of sample points corresponding to the y values. axis: The axis along which to integrate. Currently only 1-D is supported. + Raises: + Error: If y or x are not 1-D, or if their sizes differ. + Returns: Definite integral approximated by Simpson's rule. - Raises: - Error: If y or x are not 1-D, or if their sizes differ. + Examples: + ```mojo + import numojo as nm + from scijo.integrate import simpson + from scijo.prelude import * + + var y = nm.linspace[f64](0.0, 10.0, 100) ** 2 # y = x^2 sampled at 100 points from 0 to 10 + var x = nm.linspace[f64](0.0, 10.0, 100) # x values corresponding to y + var area = simpson(y, x) + ``` """ if y.ndim != 1: raise Error( @@ -317,6 +354,10 @@ fn simpson[ return integral +# ===----------------------------------------------------------------------=== # +# Romberg +# ===----------------------------------------------------------------------=== # + # TODO: fix the loop implementation. fn romb[ dtype: DType @@ -336,13 +377,23 @@ fn romb[ dx: The spacing between sample points. Defaults to 1.0. axis: The axis along which to integrate. Currently only 1-D is supported. + Raises: + Error: If y is not 1-D. + Returns: Definite integral approximated by Romberg integration. - Raises: - Error: If y is not 1-D. + Examples: + ```mojo + import numojo as nm + from scijo.integrate import romb + from scijo.prelude import * + + var y = nm.linspace[f64](0.0, 10.0, 100) ** 2 # y = x^2 sampled at 100 points from 0 to 10 + var area = romb(y, dx=0.1) + ``` """ - var maxiter: Int = 10 + comptime maxiter: Int = 10 if y.ndim != 1: raise Error( NumojoError( diff --git a/scijo/integrate/quad.mojo b/scijo/integrate/quadrature.mojo similarity index 91% rename from scijo/integrate/quad.mojo rename to scijo/integrate/quadrature.mojo index fe50ca7..04a0324 100644 --- a/scijo/integrate/quad.mojo +++ b/scijo/integrate/quadrature.mojo @@ -1,16 +1,21 @@ # ===----------------------------------------------------------------------=== # -# Scijo: Integrate - Quad +# Scijo: Integrate - Quadrature # Distributed under the Apache 2.0 License with LLVM Exceptions. # See LICENSE and the LLVM License for more information. # https://github.com/Mojo-Numerics-and-Algorithms-group/NuMojo/blob/main/LICENSE # https://llvm.org/LICENSE.txt # ===----------------------------------------------------------------------=== # -"""Integrate Module - Adaptive Quadrature (scijo.integrate.quad) +"""Integrate Module - Quadrature (scijo.integrate.quadrature) General-purpose numerical integration using adaptive quadrature methods based on the QUADPACK library. Currently implements the non-adaptive Gauss-Kronrod-Patterson (QNG) algorithm. +Examples: + ```mojo + from scijo.integrate import quad + ``` + References: - SciPy quad documentation: https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.quad.html @@ -43,6 +48,10 @@ from .utility import ( ) +# ===----------------------------------------------------------------------=== # +# Quad +# ===----------------------------------------------------------------------=== # + fn quad[ dtype: DType, func: fn[dtype: DType]( @@ -56,7 +65,7 @@ fn quad[ args: Optional[List[Scalar[dtype]]], epsabs: Scalar[dtype] = 1.49e-8, epsrel: Scalar[dtype] = 1.49e-8, -) raises -> IntegralResult[dtype]: +) raises -> IntegralResult[dtype] where dtype.is_floating_point(): """Computes the definite integral of a scalar function over [a, b]. Dispatches to the appropriate quadrature algorithm based on the `method` @@ -74,12 +83,27 @@ fn quad[ epsabs: Absolute error tolerance. epsrel: Relative error tolerance. + Raises: + Error: If the specified method is not supported. + Returns: IntegralResult[dtype] containing the integral value, absolute error estimate, function evaluation count, and status code. - Raises: - Error: If the specified method is not supported. + Examples: + ```mojo + import numojo as nm + from scijo.integrate import quad + from scijo.prelude import * + + fn integrand[dtype: DType](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype]: + return x * x # Example: f(x) = x^2 + + fn main() raises: + var result = quad[f64, integrand](0.0, 1.0, None) + print("Integral:", result.integral) # Should be close to 1/3 + print("Estimated error:", result.abserr) + ``` """ @parameter @@ -104,7 +128,7 @@ fn _qng[ args: Optional[List[Scalar[dtype]]], epsabs: Scalar[dtype] = 1.49e-8, epsrel: Scalar[dtype] = 1.49e-8, -) -> IntegralResult[dtype]: +) -> IntegralResult[dtype] where dtype.is_floating_point(): """Non-adaptive Gauss-Kronrod-Patterson integration (QUADPACK QNG). Attempts integration using progressively higher-order rules until the @@ -130,10 +154,6 @@ fn _qng[ IntegralResult[dtype] containing the integral value, error estimate, function evaluation count, and status code (ier). """ - constrained[ - dtype.is_floating_point(), "DType must be a floating point type." - ]() - comptime epsilon_mach: Scalar[dtype] = Scalar[dtype]( machine_epsilon[dtype]() ) diff --git a/scijo/integrate/utility.mojo b/scijo/integrate/utility.mojo index bf9efe0..11d412d 100644 --- a/scijo/integrate/utility.mojo +++ b/scijo/integrate/utility.mojo @@ -27,7 +27,11 @@ comptime largest_positive_dtype[dtype: DType] = max_finite[dtype]() # TODO: Remove predefined messages in IntegralResult and add custom result according to the result. -fn machine_epsilon[dtype: DType]() -> Float64: +# ===----------------------------------------------------------------------=== # +# Helpers +# ===----------------------------------------------------------------------=== # + +fn machine_epsilon[dtype: DType]() -> Float64 where dtype.is_floating_point(): """Returns the machine epsilon for the given floating-point dtype. Parameters: @@ -36,18 +40,6 @@ fn machine_epsilon[dtype: DType]() -> Float64: Returns: The machine epsilon as a Float64 value. """ - constrained[ - ( - dtype.is_floating_point() - and ( - dtype == DType.float16 - or dtype == DType.float32 - or dtype == DType.float64 - ) - ), - "DType must be floating point.", - ]() - # TODO: Check if these values are correct lol @parameter if dtype == DType.float16: @@ -58,6 +50,10 @@ fn machine_epsilon[dtype: DType]() -> Float64: return Float64(2.220446049250313e-16) # 2**-52 +# ===----------------------------------------------------------------------=== # +# Adaptive intervals +# ===----------------------------------------------------------------------=== # + struct QAGSInterval[dtype: DType](ImplicitlyCopyable, Movable): """Represents an integration subinterval with error estimate for adaptive subdivision. @@ -243,6 +239,10 @@ fn get_quad_error_message(ier: Int) -> String: return String("Unknown error code.") +# ===----------------------------------------------------------------------=== # +# Result types +# ===----------------------------------------------------------------------=== # + struct IntegralResult[dtype: DType](Copyable, Movable, Writable): """Result structure for numerical integration operations. diff --git a/scijo/interpolate/interpolate.mojo b/scijo/interpolate/interpolate.mojo index b632556..561f882 100644 --- a/scijo/interpolate/interpolate.mojo +++ b/scijo/interpolate/interpolate.mojo @@ -7,9 +7,10 @@ # ===----------------------------------------------------------------------=== # """Interpolate Module - Interpolation Functions (scijo.interpolate.interpolate) -Linear interpolation functions and callable interpolator objects for 1-D data. -Provides both a reusable `LinearInterpolator` struct and a functional `interp1d` -interface. +Linear interpolation utilities for 1-D data. Provides a reusable +`LinearInterpolator` and a functional `interp1d` interface. + +Example: interp = interp1d(x, y); yq = interp(Scalar[DType.float64](0.5)) """ from numojo import zeros @@ -17,6 +18,11 @@ from numojo import zeros from .utility import _binary_search, _validate_interpolation_input +# ===----------------------------------------------------------------------=== # +# Linear interpolator +# ===----------------------------------------------------------------------=== # + + # TODO: Add extrapolation and fill_value handling to LinearInterpolator struct LinearInterpolator[dtype: DType = DType.float64](Copyable, Movable): """A callable linear interpolation object similar to scipy.interpolate.interp1d. @@ -27,6 +33,20 @@ struct LinearInterpolator[dtype: DType = DType.float64](Copyable, Movable): Parameters: dtype: The floating-point data type. Defaults to DType.float64. + + Examples: + ```mojo + import numojo as nm + from scijo.interpolate import interp1d + from scijo.prelude import * + + var x = nm.arange[f64](0.0, 1.0, 0.5) # [0.0, 0.5, 1.0] + var y = nm.array[f64]([0.0, 0.25, 1.0]) # y = x^2 + var interp = interp1d(x, y, bounds_error=False, fill_value=0.0) + var yq1 = interp(Scalar[f64](0.25)) + var yq2 = interp(nm.array[f64]([0.1, 0.5, 0.9])) + var yq3 = interp(Scalar[f64](1.5)) + ``` """ var x: NDArray[Self.dtype] @@ -47,6 +67,8 @@ struct LinearInterpolator[dtype: DType = DType.float64](Copyable, Movable): ) raises: """Initializes the linear interpolator. + Example: LinearInterpolator(x, y, bounds_error=False, fill_value=0.0) + Args: x: The x-coordinates of the data points, must be strictly increasing. y: The y-coordinates of the data points, same length as x. @@ -61,14 +83,16 @@ struct LinearInterpolator[dtype: DType = DType.float64](Copyable, Movable): """ _validate_interpolation_input(x, y) - self.x = x.copy() - self.y = y.copy() + self.x = x.deep_copy() + self.y = y.deep_copy() self.bounds_error = bounds_error self.fill_value = fill_value fn __call__(self, xi: Scalar[Self.dtype]) raises -> Scalar[Self.dtype]: """Interpolates a single value. + Example: yq = interp(Scalar[dtype](0.25)) + Args: xi: The point at which to interpolate. @@ -118,6 +142,8 @@ struct LinearInterpolator[dtype: DType = DType.float64](Copyable, Movable): fn __call__(self, xi: NDArray[Self.dtype]) raises -> NDArray[Self.dtype]: """Interpolates an array of values. + Example: yq = interp(xq_array) + Args: xi: Array of points at which to interpolate. @@ -175,6 +201,11 @@ struct LinearInterpolator[dtype: DType = DType.float64](Copyable, Movable): return result^ +# ===----------------------------------------------------------------------=== # +# interp1d (constructor) +# ===----------------------------------------------------------------------=== # + + # TODO: Add more interpolation methods like 'quadratic', 'cubic'. # TODO: Add both interpolate and extrapolate fill methods. fn interp1d[ @@ -187,6 +218,8 @@ fn interp1d[ ) raises -> LinearInterpolator[dtype]: """Creates a callable LinearInterpolator from data points. + Example: interp = interp1d(x, y, bounds_error=False, fill_value=0.0) + Parameters: dtype: The floating-point data type. Defaults to DType.float64. @@ -198,16 +231,35 @@ fn interp1d[ fill_value: Value to use for out-of-bounds points when bounds_error is False. If None, extrapolate linearly. - Returns: - A callable LinearInterpolator object. - Raises: Error: If x and y have different lengths, have fewer than 2 points, or x is not strictly increasing. + + Returns: + A callable LinearInterpolator object. + + Examples: + ```mojo + import numojo as nm + from scijo.interpolate import interp1d + from scijo.prelude import * + + var x = nm.arange[f64](0.0, 1.0, 0.5) # [0.0, 0.5, 1.0] + var y = nm.array[f64]([0.0, 0.25, 1.0]) # y = x^2 + var interp = interp1d(x, y, bounds_error=False, fill_value=0.0) + var yq1 = interp(Scalar[f64](0.25)) + var yq2 = interp(nm.array[f64]([0.1, 0.5, 0.9])) + var yq3 = interp(Scalar[f64](1.5)) + ``` """ return LinearInterpolator[dtype](x, y, bounds_error, fill_value) +# ===----------------------------------------------------------------------=== # +# interp1d (functional) +# ===----------------------------------------------------------------------=== # + + fn interp1d[ dtype: DType = DType.float64, type: String = "linear", @@ -221,6 +273,8 @@ fn interp1d[ ]: """Interpolates the values of y at the points xi using the specified method. + Example: yq = interp1d(xq, x, y) + Functional interface similar to numpy.interp that directly returns interpolated values without creating a reusable interpolator object. @@ -235,11 +289,23 @@ fn interp1d[ x: Array of x-coordinates of data points, must be strictly increasing. y: Array of y-coordinates of data points, same length as x. + Raises: + Error: If inputs are invalid or method/fill_method is unsupported. + Returns: NDArray of interpolated values at the points xi. - Raises: - Error: If inputs are invalid or method/fill_method is unsupported. + Examples: + ```mojo + import numojo as nm + from scijo.interpolate import interp1d + from scijo.prelude import * + + var x = nm.arange[f64](0.0, 1.0, 0.5) # [0.0, 0.5, 1.0] + var y = x * x # y = x^2 + var xq = nm.array[f64]([0.1, 0.5, 0.9]) + var yq = interp1d[f64, type="linear", fill_method="interpolate"](xq, x, y) + ``` """ _validate_interpolation_input(x, y) @@ -258,6 +324,11 @@ fn interp1d[ ) +# ===----------------------------------------------------------------------=== # +# Internal linear helpers +# ===----------------------------------------------------------------------=== # + + fn _interp1d_linear_interpolate[ dtype: DType ](xi: NDArray[dtype], x: NDArray[dtype], y: NDArray[dtype]) raises -> NDArray[ @@ -265,6 +336,8 @@ fn _interp1d_linear_interpolate[ ]: """Linear interpolation with boundary clamping. + Example: yq = _interp1d_linear_interpolate(xq, x, y) + For points outside the data range, returns the nearest boundary value. Parameters: @@ -309,6 +382,8 @@ fn _interp1d_linear_extrapolate[ ]: """Linear interpolation with linear extrapolation beyond boundaries. + Example: yq = _interp1d_linear_extrapolate(xq, x, y) + For points outside the data range, extrapolates using the slope of the nearest boundary segment. @@ -362,18 +437,19 @@ fn _interp1d_linear_extrapolate[ return result^ -# Higher-order interpolation methods +# ===----------------------------------------------------------------------=== # +# Higher-order interpolation methods +# ===----------------------------------------------------------------------=== # + # fn _interp1d_quadratic_interpolate[dtype: DType]( # xi: NDArray[dtype], x: NDArray[dtype], y: NDArray[dtype] # ) raises -> NDArray[dtype]: # """Quadratic interpolation with boundary clamping.""" -# # Implementation would use 3-point Lagrange interpolation # pass # fn _interp1d_cubic_interpolate[dtype: DType]( # xi: NDArray[dtype], x: NDArray[dtype], y: NDArray[dtype] # ) raises -> NDArray[dtype]: # """Cubic interpolation with boundary clamping.""" -# # Implementation would use 4-point Lagrange interpolation or cubic splines # pass diff --git a/scijo/interpolate/utility.mojo b/scijo/interpolate/utility.mojo index 13936f9..ae74326 100644 --- a/scijo/interpolate/utility.mojo +++ b/scijo/interpolate/utility.mojo @@ -11,6 +11,10 @@ Internal utility functions for interpolation, including binary search and input validation. """ +# ===----------------------------------------------------------------------=== # +# Binary search +# ===----------------------------------------------------------------------=== # + fn _binary_search[ dtype: DType @@ -40,6 +44,11 @@ fn _binary_search[ return right +# ===----------------------------------------------------------------------=== # +# Input validation +# ===----------------------------------------------------------------------=== # + + fn _validate_interpolation_input[ dtype: DType ](x: NDArray[dtype], y: NDArray[dtype]) raises: diff --git a/scijo/optimize/__init__.mojo b/scijo/optimize/__init__.mojo index 22b6d77..ce597c1 100644 --- a/scijo/optimize/__init__.mojo +++ b/scijo/optimize/__init__.mojo @@ -13,3 +13,4 @@ the secant method. """ from .root_scalar import root_scalar, newton, bisect, secant +from .min_scalar import minimize_scalar diff --git a/scijo/optimize/min_scalar.mojo b/scijo/optimize/min_scalar.mojo new file mode 100644 index 0000000..ad5e382 --- /dev/null +++ b/scijo/optimize/min_scalar.mojo @@ -0,0 +1,489 @@ +# ===----------------------------------------------------------------------=== # +# Scijo: Optimize - Minimize Scalar +# Distributed under the Apache 2.0 License with LLVM Exceptions. +# See LICENSE and the LLVM License for more information. +# https://github.com/Mojo-Numerics-and-Algorithms-group/NuMojo/blob/main/LICENSE +# https://llvm.org/LICENSE.txt +# ===----------------------------------------------------------------------=== # +"""Optimize Module - Minimize scalar (scijo.optimize.min_scalar) +""" + +comptime _phi: Float64 = 1.618033988749895 +"""Inverse of the golden ratio, used in optimization algorithms.""" +comptime _invphi: Float64 = 0.3819660112501051 +"""Inverse of the golden ratio conjugate, used in optimization algorithms.""" + +# ===----------------------------------------------------------------------=== # +# Scalar minimization +# ===----------------------------------------------------------------------=== # + + +struct OptimizeResult[dtype: DType](ImplicitlyCopyable, Writable): + """Result structure for scalar minimization operations.""" + + var x: Scalar[Self.dtype] + var fun: Scalar[Self.dtype] + var success: Bool + var status: Int + var message: String + var nit: Int + var nfev: Int + + fn __init__( + out self, + x: Scalar[Self.dtype], + fun: Scalar[Self.dtype], + success: Bool, + status: Int, + message: String, + nit: Int, + nfev: Int, + ): + self.x = x + self.fun = fun + self.success = success + self.status = status + self.message = message + self.nit = nit + self.nfev = nfev + + fn __str__(self) raises -> String: + return ( + t"Result(success={self.success}, x={self.x}, fun={self.fun}," + t" status={self.status}, nit={self.nit}, nfev={self.nfev})" + ) + + fn write_to[W: Writer](self, mut writer: W): + """Writes the array to a writer. + + Args: + writer: The writer to write the array to. + """ + writer.write( + t"Result(success={self.success}, x={self.x}, fun={self.fun}," + t" status={self.status}, nit={self.nit}, nfev={self.nfev})" + ) + + +# ===----------------------------------------------------------------------=== # +# Implementation of scalar minimization algorithms: . +# Brent's method, Golden section search, and bounded minimization +# ===----------------------------------------------------------------------=== # + + +fn _brent_minimize[ + dtype: DType, + f: fn[dtype: DType]( + x: Scalar[dtype], args: Optional[List[Scalar[dtype]]] + ) -> Scalar[dtype], +]( + args: Optional[List[Scalar[dtype]]], + bracket: Tuple[Scalar[dtype], Scalar[dtype]], + tol: Scalar[dtype], + maxiter: Int, +) raises -> OptimizeResult[dtype]: + var a: Scalar[dtype] = bracket[0] + var b: Scalar[dtype] = bracket[1] + + if a == b: + raise Error( + "Scijo [_brent_minimize]: Bracket endpoints must be distinct." + ) + + if a > b: + var tmp = a + a = b + b = tmp + + var fa = f(a, args) + var fb = f(b, args) + var nfev = 2 + + if fb > fa: + var tmpx = a + var tmpf = fa + a = b + fa = fb + b = tmpx + fb = tmpf + + var c = b + (b - a) * Scalar[dtype](_phi) + var fc = f(c, args) + nfev += 1 + var bracket_iter = 0 + while fb > fc and bracket_iter < 50: + a = b + fa = fb + b = c + fb = fc + c = b + (b - a) * Scalar[dtype](_phi) + fc = f(c, args) + nfev += 1 + bracket_iter += 1 + + if a > c: + var tmp2 = a + a = c + c = tmp2 + + var x = b + var w = b + var v = b + var fx = fb + var fw = fb + var fv = fb + var d: Scalar[dtype] = 0 + var e: Scalar[dtype] = 0 + + for i in range(maxiter): + var m = (a + c) / 2 + var tol1 = tol * abs(x) + Scalar[dtype](1e-12) + var tol2 = tol1 * 2 + + if abs(c - a) <= tol2: + return OptimizeResult[dtype]( + x=x, + fun=fx, + success=True, + status=0, + message="Optimization terminated successfully.", + nit=i, + nfev=nfev, + ) + + var p: Scalar[dtype] = 0 + var q: Scalar[dtype] = 0 + var r: Scalar[dtype] = 0 + if abs(e) > tol1: + r = (x - w) * (fx - fv) + q = (x - v) * (fx - fw) + p = (x - v) * q - (x - w) * r + q = (q - r) * 2 + if q > 0: + p = -p + q = abs(q) + if ( + abs(p) < abs(Scalar[dtype](0.5) * q * e) + and p > q * (a - x) + and p < q * (c - x) + ): + d = p / q + var u1 = x + d + if (u1 - a) < tol2 or (c - u1) < tol2: + d = tol1 if x < m else -tol1 + else: + e = c - x if x < m else a - x + d = Scalar[dtype](_invphi) * e + else: + e = c - x if x < m else a - x + d = Scalar[dtype](_invphi) * e + + var u = x + d if abs(d) >= tol1 else x + (tol1 if d > 0 else -tol1) + var fu = f(u, args) + nfev += 1 + + if fu <= fx: + if u < x: + c = x + else: + a = x + v = w + fv = fw + w = x + fw = fx + x = u + fx = fu + else: + if u < x: + a = u + else: + c = u + if fu <= fw or w == x: + v = w + fv = fw + w = u + fw = fu + elif fu <= fv or v == x or v == w: + v = u + fv = fu + + return OptimizeResult[dtype]( + x=x, + fun=fx, + success=False, + status=1, + message="Maximum number of iterations exceeded.", + nit=maxiter, + nfev=nfev, + ) + + +fn _golden_minimize[ + dtype: DType, + f: fn[dtype: DType]( + x: Scalar[dtype], args: Optional[List[Scalar[dtype]]] + ) -> Scalar[dtype], +]( + args: Optional[List[Scalar[dtype]]], + bounds: Tuple[Scalar[dtype], Scalar[dtype]], + tol: Scalar[dtype], + maxiter: Int, +) raises -> OptimizeResult[dtype]: + var a: Scalar[dtype] = bounds[0] + var b: Scalar[dtype] = bounds[1] + + if a == b: + raise Error("Scijo [_golden_minimize]: Bounds must be distinct.") + + if a > b: + var tmp = a + a = b + b = tmp + + var c = b - Scalar[dtype](_invphi) * (b - a) + var d = a + Scalar[dtype](_invphi) * (b - a) + var fc = f(c, args) + var fd = f(d, args) + var nfev = 2 + + for i in range(maxiter): + if abs(b - a) <= tol * (abs(c) + abs(d)): + var x = c if fc < fd else d + var fx = fc if fc < fd else fd + return OptimizeResult[dtype]( + x=x, + fun=fx, + success=True, + status=0, + message="Optimization terminated successfully.", + nit=i, + nfev=nfev, + ) + + if fc < fd: + b = d + d = c + fd = fc + c = b - Scalar[dtype](_invphi) * (b - a) + fc = f(c, args) + nfev += 1 + else: + a = c + c = d + fc = fd + d = a + Scalar[dtype](_invphi) * (b - a) + fd = f(d, args) + nfev += 1 + + var x2 = c if fc < fd else d + var fx2 = fc if fc < fd else fd + return OptimizeResult[dtype]( + x=x2, + fun=fx2, + success=False, + status=1, + message="Maximum number of iterations exceeded.", + nit=maxiter, + nfev=nfev, + ) + + +fn _bounded_minimize[ + dtype: DType, + f: fn[dtype: DType]( + x: Scalar[dtype], args: Optional[List[Scalar[dtype]]] + ) -> Scalar[dtype], +]( + args: Optional[List[Scalar[dtype]]], + bounds: Tuple[Scalar[dtype], Scalar[dtype]], + tol: Scalar[dtype], + maxiter: Int, +) raises -> OptimizeResult[dtype]: + var a: Scalar[dtype] = bounds[0] + var b: Scalar[dtype] = bounds[1] + + if a == b: + raise Error("Scijo [_bounded_minimize]: Bounds must be distinct.") + + if a > b: + var tmp = a + a = b + b = tmp + + var x = a + Scalar[dtype](_invphi) * (b - a) + var w = x + var v = x + var fx = f(x, args) + var fw = fx + var fv = fx + var nfev = 1 + var d: Scalar[dtype] = 0 + var e: Scalar[dtype] = 0 + + for i in range(maxiter): + var m = (a + b) / 2 + var tol1 = tol * abs(x) + Scalar[dtype](1e-12) + var tol2 = tol1 * 2 + + if abs(b - a) <= tol2: + return OptimizeResult[dtype]( + x=x, + fun=fx, + success=True, + status=0, + message="Optimization terminated successfully.", + nit=i, + nfev=nfev, + ) + + var p: Scalar[dtype] = 0 + var q: Scalar[dtype] = 0 + var r: Scalar[dtype] = 0 + if abs(e) > tol1: + r = (x - w) * (fx - fv) + q = (x - v) * (fx - fw) + p = (x - v) * q - (x - w) * r + q = (q - r) * 2 + if q > 0: + p = -p + q = abs(q) + if ( + abs(p) < abs(Scalar[dtype](0.5) * q * e) + and p > q * (a - x) + and p < q * (b - x) + ): + d = p / q + var u1 = x + d + if (u1 - a) < tol2 or (b - u1) < tol2: + d = tol1 if x < m else -tol1 + else: + e = b - x if x < m else a - x + d = Scalar[dtype](_invphi) * e + else: + e = b - x if x < m else a - x + d = Scalar[dtype](_invphi) * e + + var u = x + d if abs(d) >= tol1 else x + (tol1 if d > 0 else -tol1) + if u < a + tol2: + u = a + tol2 + if u > b - tol2: + u = b - tol2 + + var fu = f(u, args) + nfev += 1 + + if fu <= fx: + if u < x: + b = x + else: + a = x + v = w + fv = fw + w = x + fw = fx + x = u + fx = fu + else: + if u < x: + a = u + else: + b = u + if fu <= fw or w == x: + v = w + fv = fw + w = u + fw = fu + elif fu <= fv or v == x or v == w: + v = u + fv = fu + + return OptimizeResult[dtype]( + x=x, + fun=fx, + success=False, + status=1, + message="Maximum number of iterations exceeded.", + nit=maxiter, + nfev=nfev, + ) + + +fn minimize_scalar[ + dtype: DType, + f: fn[dtype: DType]( + x: Scalar[dtype], args: Optional[List[Scalar[dtype]]] + ) -> Scalar[dtype], + *, + method: String = "Brent", +]( + Bracket: Optional[Tuple[Scalar[dtype], Scalar[dtype]]] = None, + bounds: Optional[Tuple[Scalar[dtype], Scalar[dtype]]] = None, + args: Optional[List[Scalar[dtype]]] = None, + tol: Scalar[dtype] = 1e-8, + maxiter: Int = 500, +) raises -> OptimizeResult[dtype]: + """Minimize a scalar function using the specified method. + + Parameters: + dtype: The floating-point data type. + f: Function f(x, args) -> Scalar[dtype] to minimize. + method: Optimization algorithm: "Brent", "Golden", or "Bounded". + + Args: + Bracket: (a, b) tuple specifying an initial interval for bracketing methods. + bounds: (lower, upper) tuple specifying the search interval for bounded methods. + args: Optional arguments forwarded to f. + tol: Tolerance for convergence. + maxiter: Maximum number of iterations. + + Returns: + OptimizeResult[dtype] containing the optimization result. + + Examples: + ```mojo + from scijo.prelude import * + from scijo.optimize import minimize_scalar + + fn objective[dtype: DType](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype]: + return (x - 2) * (x - 2) + 1 + + var result = minimize_scalar[Float64, objective, method="Brent"]( + Bracket=(0.0, 4.0), + tol=1e-8, + maxiter=100 + ) + print(result) + # Output: Result(success=True, x=2.0, fun=1.0, status=0, nit=5, nfev=8) + ``` + """ + + if method == "Brent" or method == "brent": + if Bracket: + return _brent_minimize[dtype, f]( + args, Bracket.value(), tol, maxiter + ) + if bounds: + return _brent_minimize[dtype, f](args, bounds.value(), tol, maxiter) + raise Error("Bracket or bounds must be provided for Brent method.") + + if method == "Golden" or method == "golden": + if Bracket: + return _golden_minimize[dtype, f]( + args, Bracket.value(), tol, maxiter + ) + if bounds: + return _golden_minimize[dtype, f]( + args, bounds.value(), tol, maxiter + ) + raise Error("Bracket or bounds must be provided for Golden method.") + + if method == "Bounded" or method == "bounded": + if not bounds: + raise Error("Bounds must be provided for bounded method.") + return _bounded_minimize[dtype, f](args, bounds.value(), tol, maxiter) + + raise Error( + "Unsupported method: " + + String(method) + + ". Supported methods: 'Brent', 'Golden', 'Bounded'." + ) diff --git a/scijo/optimize/root_scalar.mojo b/scijo/optimize/root_scalar.mojo index 848f336..eda4365 100644 --- a/scijo/optimize/root_scalar.mojo +++ b/scijo/optimize/root_scalar.mojo @@ -14,6 +14,10 @@ methods (secant). # TODO: check if we are using the right tolerance conditions in all methods. +# ===----------------------------------------------------------------------=== # +# Root scalar +# ===----------------------------------------------------------------------=== # + fn root_scalar[ dtype: DType, @@ -27,6 +31,7 @@ fn root_scalar[ dtype ] ] = None, + *, method: String = "bisect", ]( args: Optional[List[Scalar[dtype]]] = None, @@ -36,7 +41,6 @@ fn root_scalar[ xtol: Scalar[dtype] = 1e-8, rtol: Scalar[dtype] = 1e-8, maxiter: Int = 100, - # options: SolverOptions ) raises -> Scalar[dtype]: """Finds a root of a scalar function using the specified method. @@ -55,30 +59,46 @@ fn root_scalar[ rtol: Relative tolerance for convergence. maxiter: Maximum number of iterations. - Returns: - The approximate root as a Scalar[dtype]. - Raises: Error: If required inputs for the chosen method are missing or invalid. + + Returns: + The approximate root of the function. """ @parameter - if method == "newton" and fprime: + if method == "newton": + if not fprime: + raise Error( + "Scijo [root_scalar]: Derivative fprime must be provided for" + " Newton's method." + ) return newton[dtype, f, fprime.value()](args, x0, xtol, rtol, maxiter) elif method == "bisect": if not bracket: - raise Error("Bracket must be provided for bisection method.") + raise Error( + "Scijo [root_scalar]: Bracket must be provided for bisection" + " method." + ) return bisect[dtype, f](args, bracket.value(), xtol, rtol, maxiter) elif method == "secant": if not (x0 and x1): raise Error( - "Initial guesses x0 and x1 must be provided for secant method." + "Scijo [root_scalar]: Initial guesses x0 and x1 must be" + " provided for secant method." ) return secant[dtype, f]( args, x0.value(), x1.value(), xtol, rtol, maxiter ) else: - raise Error("Unsupported method: " + String(method)) + raise Error( + "Scijo [root_scalar]: Unsupported method: " + String(method) + ) + + +# ===----------------------------------------------------------------------=== # +# Root scalar methods +# ===----------------------------------------------------------------------=== # fn newton[ @@ -115,17 +135,20 @@ fn newton[ rtol: Relative tolerance for convergence. maxiter: Maximum number of iterations. - Returns: - The approximate root as a Scalar[dtype]. - Raises: Error: If x0 is not provided or the derivative is zero at any step. + + Returns: + The approximate root as a Scalar[dtype]. """ var xn: Scalar[dtype] if x0: xn = x0.value() else: - raise Error("Initial guess x0 must be provided for Newton's method.") + raise Error( + "Scijo [newton]: Initial guess x0 must be provided for Newton's" + " method." + ) for _ in range(maxiter): var fx = f(xn, args) @@ -133,7 +156,8 @@ fn newton[ if fpx == 0: raise Error( - "Derivative is zero. Newton-Raphson step would divide by zero." + "Scijo [newton]: Derivative is zero. Newton-Raphson step would" + " divide by zero." ) var delta = fx / fpx @@ -181,11 +205,11 @@ fn bisect[ rtol: Relative tolerance for convergence. maxiter: Maximum number of iterations. - Returns: - The approximate root as a Scalar[dtype]. - Raises: Error: If f(a) and f(b) do not have opposite signs. + + Returns: + The approximate root as a Scalar[dtype]. """ var a: Scalar[dtype] = bracket[0] var b: Scalar[dtype] = bracket[1] @@ -200,8 +224,8 @@ fn bisect[ if fa * fb > 0: raise Error( - "f(a) and f(b) must have opposite signs (bracket does not enclose a" - " root)." + "Scijo [newton]: f(a) and f(b) must have opposite signs (bracket" + " does not enclose a root)." ) for _ in range(maxiter): @@ -212,16 +236,15 @@ fn bisect[ return c var tol_x = max(xtol, rtol * abs(c)) - var half_width = (b - a) / 2 + var half_width = abs(b - a) / 2 if half_width <= tol_x: return c if fa * fc < 0: - # Root is in [a, c] b = c else: - # Root is in [c, b] a = c + fa = fc return (a + b) / 2 @@ -238,7 +261,7 @@ fn secant[ xtol: Scalar[dtype] = 1e-8, rtol: Scalar[dtype] = 1e-8, maxiter: Int = 100, -) -> Scalar[dtype]: +) raises -> Scalar[dtype]: """Finds a root using the secant method. A derivative-free method that approximates the derivative using finite @@ -256,6 +279,9 @@ fn secant[ rtol: Relative tolerance for convergence. maxiter: Maximum number of iterations. + Raises: + Error: If zero slope is encountered. + Returns: The approximate root as a Scalar[dtype]. """ @@ -263,12 +289,19 @@ fn secant[ var b: Scalar[dtype] = x1 for _ in range(maxiter): - var f0 = f[dtype](a, args) - var f1 = f[dtype](b, args) + var f0 = f(a, args) + var f1 = f(b, args) + + var denom = f1 - f0 + if denom == 0: + raise Error( + "Scijo [newton]: Secant method encountered zero slope (f1 - f0" + " == 0)." + ) - var xn = b - (f1 * (b - a)) / (f1 - f0) + var xn = b - (f1 * (b - a)) / denom - var fxn = f[dtype](xn, args) + var fxn = f(xn, args) var tol_x = max(xtol, rtol * abs(xn)) var tol_f = max(xtol, rtol * abs(fxn)) @@ -278,7 +311,7 @@ fn secant[ if abs(xn - b) <= tol_x: return xn + a = b b = xn - a = x1 - return x1 + return b diff --git a/scijo/optimize/utility.mojo b/scijo/optimize/utility.mojo index 3c5ccd9..48d043a 100644 --- a/scijo/optimize/utility.mojo +++ b/scijo/optimize/utility.mojo @@ -10,6 +10,10 @@ Data structures for returning results from optimization and root-finding routines. """ +# ===----------------------------------------------------------------------=== # +# RootResults +# ===----------------------------------------------------------------------=== # + struct RootResults[dtype: DType = DType.float64](): """Result structure for scalar root-finding operations. diff --git a/scijo/prelude.mojo b/scijo/prelude.mojo new file mode 100644 index 0000000..551f440 --- /dev/null +++ b/scijo/prelude.mojo @@ -0,0 +1,17 @@ +# ===----------------------------------------------------------------------=== # +# Scijo: A Scientific Computation Library for Mojo +# Distributed under the Apache 2.0 License with LLVM Exceptions. +# See LICENSE and the LLVM License for more information. +# https://github.com/Mojo-Numerics-and-Algorithms-group/NuMojo/blob/main/LICENSE +# https://llvm.org/LICENSE.txt +# ===----------------------------------------------------------------------=== # +""" +scijo.prelude +============= + +The SciJo prelude provides convenient access to core scientific computation tools and modules for Mojo. + +Importing `scijo.prelude` gives you access to all numojo core types such as datatypes (f64, i32 etc), array types (NDArray). +""" + +from numojo.prelude import * diff --git a/tests/test_differentiate.mojo b/tests/test_differentiate.mojo index c17457c..91b115a 100644 --- a/tests/test_differentiate.mojo +++ b/tests/test_differentiate.mojo @@ -43,7 +43,9 @@ fn cubic_function[ fn sin_function[ dtype: DType -](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): +](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype +] where dtype.is_floating_point(): """ F(x) = sin(x), f'(x) = cos(x). """ @@ -52,7 +54,9 @@ fn sin_function[ fn cos_function[ dtype: DType -](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): +](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype +] where dtype.is_floating_point(): """ F(x) = cos(x), f'(x) = -sin(x). """ @@ -61,7 +65,9 @@ fn cos_function[ fn exp_function[ dtype: DType -](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): +](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype +] where dtype.is_floating_point(): """ F(x) = e^x, f'(x) = e^x. """ @@ -384,5 +390,6 @@ fn test_step_size_parameters() raises: msg="Result should be consistent across step factors", ) + def main(): TestSuite.discover_tests[__functions_in_module()]().run() diff --git a/tests/test_fft.mojo b/tests/test_fft.mojo index 46ad578..22e57fb 100644 --- a/tests/test_fft.mojo +++ b/tests/test_fft.mojo @@ -263,5 +263,6 @@ fn test_error_conditions() raises: except: print("Non-power-of-2 error handling - PASSED") + def main(): TestSuite.discover_tests[__functions_in_module()]().run() diff --git a/tests/test_interpolate.mojo b/tests/test_interpolate.mojo index aa79245..cb25dee 100644 --- a/tests/test_interpolate.mojo +++ b/tests/test_interpolate.mojo @@ -313,9 +313,7 @@ fn test_edge_cases() raises: ) # Test with very small intervals - var x_small = nm.array[nm.f64]( - [0.0, 1e-10, 2e-10], [3] - ) + var x_small = nm.array[nm.f64]([0.0, 1e-10, 2e-10], [3]) var y_small = nm.array[nm.f64]([0.0, 1.0, 2.0], [3]) var interp_small = LinearInterpolator(x_small, y_small) @@ -327,7 +325,9 @@ fn test_edge_cases() raises: var small_test_point = 1.5e-10 var result_small = interp_small(small_test_point) - var scipy_small = Float64(py=py_interp_small(PythonObject(small_test_point))) + var scipy_small = Float64( + py=py_interp_small(PythonObject(small_test_point)) + ) assert_almost_equal( result_small, scipy_small, @@ -516,5 +516,6 @@ fn test_performance_comparison() raises: msg="Large dataset interpolation should match SciPy", ) + def main(): TestSuite.discover_tests[__functions_in_module()]().run() diff --git a/tests/test_quad.mojo b/tests/test_quad.mojo index 4fc96ca..da13f41 100644 --- a/tests/test_quad.mojo +++ b/tests/test_quad.mojo @@ -56,7 +56,9 @@ def test_quad_trigonometric(): # Test ∫sin(x) dx from 0 to π = 2 fn sine[ dtype: DType - ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype + ] where dtype.is_floating_point(): return sin(x) var result = quad[sj.f64, sine](0.0, pi, None) @@ -66,7 +68,9 @@ def test_quad_trigonometric(): # Test ∫cos(x) dx from 0 to π/2 = 1 fn cosine[ dtype: DType - ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype + ] where dtype.is_floating_point(): return cos(x) var result2 = quad[sj.f64, cosine](0.0, pi / 2.0, None) @@ -76,7 +80,9 @@ def test_quad_trigonometric(): # Test ∫sin²(x) dx from 0 to π = π/2 fn sin_squared[ dtype: DType - ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype + ] where dtype.is_floating_point(): var s = sin(x) return s * s @@ -91,7 +97,9 @@ def test_quad_exponential(): # Test ∫e^x dx from 0 to 1 = e - 1 fn exponential[ dtype: DType - ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype + ] where dtype.is_floating_point(): return exp(x) var result = quad[sj.f64, exponential](0.0, 1.0, None) @@ -102,7 +110,9 @@ def test_quad_exponential(): # Test ∫e^(-x) dx from 0 to ∞ ≈ 1 (using large upper bound) fn exp_decay[ dtype: DType - ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype + ] where dtype.is_floating_point(): return exp(-x) var result2 = quad[sj.f64, exp_decay]( @@ -166,7 +176,9 @@ def test_quad_difficult_integrands(): # Using finite bounds that approximate infinity fn gaussian[ dtype: DType - ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype + ] where dtype.is_floating_point(): return exp(-x * x) var result = quad[sj.f64, gaussian](-5.0, 5.0, None, epsrel=1e-8) @@ -176,7 +188,9 @@ def test_quad_difficult_integrands(): # Test oscillatory function sin(x)/x near origin (needs careful handling) fn sinc_like[ dtype: DType - ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype] where dtype.is_floating_point(): + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[ + dtype + ] where dtype.is_floating_point(): if abs(x) < 1e-10: return 1.0 # limit as x→0 of sin(x)/x = 1 return sin(x) / x @@ -245,5 +259,6 @@ def test_quad_difficult_integrands(): # # This should integrate to near zero due to oscillation # assert_almost_equal(result2.integral, 0.0, atol=1e-6) + def main(): TestSuite.discover_tests[__functions_in_module()]().run() diff --git a/tests/test_root_scalar.mojo b/tests/test_root_scalar.mojo new file mode 100644 index 0000000..d8ed1d7 --- /dev/null +++ b/tests/test_root_scalar.mojo @@ -0,0 +1,62 @@ +from testing import assert_almost_equal, assert_equal +from testing import TestSuite +from math import sqrt +import scijo as sj +from scijo.optimize.root_scalar import root_scalar, newton, bisect, secant + + +fn test_bisect_root_scalar_basic() raises: + fn f[ + dtype: DType + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype]: + return x * x - 2.0 + + var root = root_scalar[sj.f64, f](bracket=(0.0, 2.0)) + assert_almost_equal(root, sqrt(2.0), atol=1e-8) + + var root2 = bisect[sj.f64, f](None, (0.0, 2.0)) + assert_almost_equal(root2, sqrt(2.0), atol=1e-8) + + +fn test_newton_basic() raises: + fn f[ + dtype: DType + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype]: + return x * x - 2.0 + + fn fprime[ + dtype: DType + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype]: + return 2.0 * x + + var root = newton[sj.f64, f, fprime](None, x0=1.0, xtol=1e-12, rtol=1e-12) + assert_almost_equal(root, sqrt(2.0), atol=1e-10) + + +fn test_secant_basic() raises: + fn f[ + dtype: DType + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype]: + return x * x - 2.0 + + var root = secant[sj.f64, f](None, 1.0, 2.0) + assert_almost_equal(root, sqrt(2.0), atol=1e-8) + + +fn test_bisect_invalid_bracket_raises() raises: + fn g[ + dtype: DType + ](x: Scalar[dtype], args: Optional[List[Scalar[dtype]]]) -> Scalar[dtype]: + return x * x + 1.0 + + try: + var _ = bisect[sj.f64, g](None, (0.0, 1.0)) + assert_equal( + True, False, msg="Expected bisect to raise on invalid bracket." + ) + except: + pass + + +def main(): + TestSuite.discover_tests[__functions_in_module()]().run() diff --git a/tests/test_trapezoid.mojo b/tests/test_trapezoid.mojo index 6a76e95..1944532 100644 --- a/tests/test_trapezoid.mojo +++ b/tests/test_trapezoid.mojo @@ -187,5 +187,6 @@ fn test_error_conditions() raises: except: pass + def main(): TestSuite.discover_tests[__functions_in_module()]().run()