55"""
66import numpy as np
77from typing import NamedTuple
8+ from scipy .optimize import fminbound
89
9- __all__ = ('inexact_line_search' , 'LineSearchResult' )
10+ __all__ = (
11+ 'exact_line_search' ,
12+ 'inexact_line_search' ,
13+ 'LineSearchResult'
14+ )
1015
1116
1217class LineSearchResult (NamedTuple ):
13- t : float
14- x : np .ndarray
15- f : float
18+ t : float #: step length
19+ x : np .ndarray # data point
20+ r : float # residual
1621
1722
18- # def exact_line_search(f, x, correction, t0=1e-9, t1=2, ttol=1e-6, maxiter=20,
19- # full_output=True):
20- # return fminbound(
21- # lambda t: f(x + t * correction) ,
22- # x1=t0 ,
23- # x2=t1 ,
24- # xtol=ttol ,
25- # maxfun=maxiter ,
26- # full_output=full_output,
27- # )
23+ def exact_line_search (f , x , correction , t0 = 1e-9 , t1 = 2 , ttol = 1e-6 , maxiter = 20 ):
24+ t , f , * _ = fminbound (
25+ lambda t : f ( x + t * correction ),
26+ x1 = t0 ,
27+ x2 = t1 ,
28+ xtol = ttol ,
29+ maxfun = maxiter ,
30+ full_output = True ,
31+ )
32+ return LineSearchResult ( t , x + t * correction , f )
2833
2934def inexact_line_search (
3035 f , x , correction , t0 = 1e-6 , t1 = 1.1 , ttol = 1e-3 , maxiter = 10 ,
@@ -40,7 +45,9 @@ def inexact_line_search(
4045 if ft1 > fx :
4146 xguess = x + tguess * correction
4247 ftguess = f (xguess )
43- if ft1 > ft0 :
48+ if ftguess < fx : # Done!
49+ best = LineSearchResult (tguess , xguess , ftguess )
50+ elif ft1 > ft0 : # Move towards x0.
4451 best = LineSearchResult (t0 , x0 , ft0 )
4552 for i in range (maxiter ):
4653 t = rho * tguess + (1 - rho ) * t0
@@ -49,29 +56,16 @@ def inexact_line_search(
4956 if ft < best .f : best = LineSearchResult (t , xt , ft )
5057 if abs (t - tguess ) < ttol : break
5158 tguess = t
52- else :
59+ elif ftguess < ft1 : # No where to go; moving forward is risky
5360 best = LineSearchResult (tguess , xguess , ftguess )
54- for i in range (maxiter ):
55- t = (1 - rho ) * tguess + rho * t1
56- xt = x + t * correction
57- ft = f (xt )
58- if ft < ftguess : best = LineSearchResult (t , xt , ft )
59- if abs (t - tguess ) < ttol : break
60- t1 = t
61+ else : # Hope for the best
62+ best = LineSearchResult (t1 , x1 , ft1 )
6163 else :
6264 xguess = x + tguess * correction
6365 ftguess = f (xguess )
64- if ftguess < ft1 :
65- tnext = tguess
66+ if ftguess < ft1 : # Stick with guess
6667 best = LineSearchResult (tguess , xguess , ftguess )
67- for i in range (maxiter ):
68- t = (1 - rho ) * tnext + rho * t1
69- xt = x + t * correction
70- ft = f (xt )
71- if ft < best .f : best = LineSearchResult (t , xt , ft )
72- if abs (t - tguess ) < ttol : break
73- t1 = t
74- else :
68+ else : # Accelerate
7569 best = LineSearchResult (t1 , x1 , ft1 )
7670 return best
7771
0 commit comments