Skip to content

Commit 7699b2d

Browse files
committed
improve line search
1 parent 8ffe70c commit 7699b2d

1 file changed

Lines changed: 27 additions & 33 deletions

File tree

flexsolve/line_search.py

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,31 @@
55
"""
66
import numpy as np
77
from typing import NamedTuple
8+
from scipy.optimize import fminbound
89

9-
__all__ = ('inexact_line_search', 'LineSearchResult')
10+
__all__ = (
11+
'exact_line_search',
12+
'inexact_line_search',
13+
'LineSearchResult'
14+
)
1015

1116

1217
class LineSearchResult(NamedTuple):
13-
t: float
14-
x: np.ndarray
15-
f: float
18+
t: float #: step length
19+
x: np.ndarray # data point
20+
r: float # residual
1621

1722

18-
# def exact_line_search(f, x, correction, t0=1e-9, t1=2, ttol=1e-6, maxiter=20,
19-
# full_output=True):
20-
# return fminbound(
21-
# lambda t: f(x + t * correction),
22-
# x1=t0,
23-
# x2=t1,
24-
# xtol=ttol,
25-
# maxfun=maxiter,
26-
# full_output=full_output,
27-
# )
23+
def exact_line_search(f, x, correction, t0=1e-9, t1=2, ttol=1e-6, maxiter=20):
24+
t, f, *_ = fminbound(
25+
lambda t: f(x + t * correction),
26+
x1=t0,
27+
x2=t1,
28+
xtol=ttol,
29+
maxfun=maxiter,
30+
full_output=True,
31+
)
32+
return LineSearchResult(t, x + t * correction, f)
2833

2934
def inexact_line_search(
3035
f, x, correction, t0=1e-6, t1=1.1, ttol=1e-3, maxiter=10,
@@ -40,7 +45,9 @@ def inexact_line_search(
4045
if ft1 > fx:
4146
xguess = x + tguess * correction
4247
ftguess = f(xguess)
43-
if ft1 > ft0:
48+
if ftguess < fx: # Done!
49+
best = LineSearchResult(tguess, xguess, ftguess)
50+
elif ft1 > ft0: # Move towards x0.
4451
best = LineSearchResult(t0, x0, ft0)
4552
for i in range(maxiter):
4653
t = rho * tguess + (1 - rho) * t0
@@ -49,29 +56,16 @@ def inexact_line_search(
4956
if ft < best.f: best = LineSearchResult(t, xt, ft)
5057
if abs(t - tguess) < ttol: break
5158
tguess = t
52-
else:
59+
elif ftguess < ft1: # No where to go; moving forward is risky
5360
best = LineSearchResult(tguess, xguess, ftguess)
54-
for i in range(maxiter):
55-
t = (1 - rho) * tguess + rho * t1
56-
xt = x + t * correction
57-
ft = f(xt)
58-
if ft < ftguess: best = LineSearchResult(t, xt, ft)
59-
if abs(t - tguess) < ttol: break
60-
t1 = t
61+
else: # Hope for the best
62+
best = LineSearchResult(t1, x1, ft1)
6163
else:
6264
xguess = x + tguess * correction
6365
ftguess = f(xguess)
64-
if ftguess < ft1:
65-
tnext = tguess
66+
if ftguess < ft1: # Stick with guess
6667
best = LineSearchResult(tguess, xguess, ftguess)
67-
for i in range(maxiter):
68-
t = (1 - rho) * tnext + rho * t1
69-
xt = x + t * correction
70-
ft = f(xt)
71-
if ft < best.f: best = LineSearchResult(t, xt, ft)
72-
if abs(t - tguess) < ttol: break
73-
t1 = t
74-
else:
68+
else: # Accelerate
7569
best = LineSearchResult(t1, x1, ft1)
7670
return best
7771

0 commit comments

Comments
 (0)