|
3 | 3 | Line search utilities for simultaneous correction methods. |
4 | 4 |
|
5 | 5 | """ |
6 | | -from scipy.optimize import fminbound |
| 6 | +import numpy as np |
| 7 | +from typing import NamedTuple |
7 | 8 |
|
8 | | -__all__ = ('exact_line_search', 'inexact_line_search') |
| 9 | +__all__ = ('inexact_line_search', 'LineSearchResult') |
9 | 10 |
|
10 | | -def exact_line_search(f, x, correction, t0=1e-9, t1=2, ttol=1e-6, maxiter=20, |
11 | | - full_output=True): |
12 | | - return fminbound( |
13 | | - lambda t: f(x + t * correction), |
14 | | - x1=t0, |
15 | | - x2=t1, |
16 | | - xtol=ttol, |
17 | | - maxfun=maxiter, |
18 | | - full_output=full_output, |
19 | | - ) |
| 11 | + |
| 12 | +class LineSearchResult(NamedTuple): |
| 13 | + t: float |
| 14 | + x: np.ndarray |
| 15 | + f: float |
| 16 | + |
| 17 | + |
| 18 | +# def exact_line_search(f, x, correction, t0=1e-9, t1=2, ttol=1e-6, maxiter=20, |
| 19 | +# full_output=True): |
| 20 | +# return fminbound( |
| 21 | +# lambda t: f(x + t * correction), |
| 22 | +# x1=t0, |
| 23 | +# x2=t1, |
| 24 | +# xtol=ttol, |
| 25 | +# maxfun=maxiter, |
| 26 | +# full_output=full_output, |
| 27 | +# ) |
20 | 28 |
|
21 | 29 | def inexact_line_search( |
22 | | - f, x, correction, t0=1e-6, t1=2, ttol=1e-3, maxiter=100, |
| 30 | + f, x, correction, t0=1e-6, t1=1.1, ttol=1e-3, maxiter=10, |
23 | 31 | fx=None, tguess=None, rho=0.618, |
24 | 32 | ): |
25 | 33 | # Find t such that f(x + t * correction) - fx < 0 |
26 | | - if tguess is None or not t0 < tguess < t1: |
27 | | - tguess = 0.5 * (t1 + t0) |
28 | | - fx = None |
29 | | - if fx is None: |
30 | | - fx = f(x) |
| 34 | + if tguess is None or not t0 < tguess < t1: tguess = 0.5 * (t1 + t0) |
| 35 | + if fx is None: fx = f(x) |
| 36 | + x0 = x + t0 * correction |
| 37 | + ft0 = f(x0) |
31 | 38 | x1 = x + t1 * correction |
32 | 39 | ft1 = f(x1) |
33 | 40 | if ft1 > fx: |
34 | 41 | xguess = x + tguess * correction |
35 | 42 | ftguess = f(xguess) |
36 | | - if ftguess > fx: |
| 43 | + if ft1 > ft0: |
| 44 | + best = LineSearchResult(t0, x0, ft0) |
37 | 45 | for i in range(maxiter): |
38 | 46 | t = rho * tguess + (1 - rho) * t0 |
39 | 47 | xt = x + t * correction |
40 | 48 | ft = f(xt) |
41 | | - if ft < fx: return t, xt, ft |
| 49 | + if ft < best.f: best = LineSearchResult(t, xt, ft) |
42 | 50 | if abs(t - tguess) < ttol: break |
43 | 51 | tguess = t |
44 | | - x0 = x + t0 * correction |
45 | | - ft0 = f(x0) |
46 | | - if ft0 >= fx: |
47 | | - raise ValueError( |
48 | | - 'line search could not find improvement over reference point' |
49 | | - ) |
50 | | - return t0, x0, ft0 |
51 | 52 | else: |
| 53 | + best = LineSearchResult(tguess, xguess, ftguess) |
52 | 54 | for i in range(maxiter): |
53 | 55 | t = (1 - rho) * tguess + rho * t1 |
54 | 56 | xt = x + t * correction |
55 | 57 | ft = f(xt) |
56 | | - if ft < ftguess: return t, xt, ft |
| 58 | + if ft < ftguess: best = LineSearchResult(t, xt, ft) |
57 | 59 | if abs(t - tguess) < ttol: break |
58 | 60 | t1 = t |
59 | | - xguess = x + tguess * correction |
60 | | - ftguess = f(xguess) |
61 | | - return tguess, xguess, ftguess |
62 | 61 | else: |
63 | 62 | xguess = x + tguess * correction |
64 | 63 | ftguess = f(xguess) |
65 | 64 | if ftguess < ft1: |
66 | 65 | tnext = tguess |
| 66 | + best = LineSearchResult(tguess, xguess, ftguess) |
67 | 67 | for i in range(maxiter): |
68 | 68 | t = (1 - rho) * tnext + rho * t1 |
69 | 69 | xt = x + t * correction |
70 | 70 | ft = f(xt) |
71 | | - if ft < ftguess: return t, xt, ft |
| 71 | + if ft < best.f: best = LineSearchResult(t, xt, ft) |
72 | 72 | if abs(t - tguess) < ttol: break |
73 | | - tnext = t |
74 | | - return tguess, xguess, ftguess |
| 73 | + t1 = t |
75 | 74 | else: |
76 | | - return t1, x1, ft1 |
| 75 | + best = LineSearchResult(t1, x1, ft1) |
| 76 | + return best |
77 | 77 |
|
78 | 78 |
|
79 | 79 |
|
|
0 commit comments