diff --git a/examples/trajectory_optim_2integrator.ipynb b/examples/trajectory_optim_2integrator.ipynb index ed12e05..3aada66 100644 --- a/examples/trajectory_optim_2integrator.ipynb +++ b/examples/trajectory_optim_2integrator.ipynb @@ -259,28 +259,22 @@ "text": [ "compiling optimizer...\n", "👉 solving problem with n_horizon=50, n_states=2 n_inputs=1\n", - "🔄 it=0 \t (sub iter=103)\tt=0 \teq_error/eq_tol=122516 %\tinside bounds: True\n", - "🔄 it=1 \t (sub iter=98)\tt=1 \teq_error/eq_tol=106595 %\tinside bounds: True\n", - "🔄 it=2 \t (sub iter=103)\tt=1 \teq_error/eq_tol=88343 %\tinside bounds: True\n", - "🔄 it=3 \t (sub iter=108)\tt=2 \teq_error/eq_tol=69808 %\tinside bounds: True\n", - "🔄 it=4 \t (sub iter=112)\tt=3 \teq_error/eq_tol=52260 %\tinside bounds: True\n", - "🔄 it=5 \t (sub iter=118)\tt=5 \teq_error/eq_tol=37293 %\tinside bounds: True\n", - "🔄 it=6 \t (sub iter=122)\tt=7 \teq_error/eq_tol=25625 %\tinside bounds: True\n", - "🔄 it=7 \t (sub iter=125)\tt=11 \teq_error/eq_tol=17127 %\tinside bounds: True\n", - "🔄 it=8 \t (sub iter=121)\tt=17 \teq_error/eq_tol=11223 %\tinside bounds: True\n", - "🔄 it=9 \t (sub iter=126)\tt=27 \teq_error/eq_tol=7246 %\tinside bounds: True\n", - "🔄 it=10 \t (sub iter=143)\tt=41 \teq_error/eq_tol=4627 %\tinside bounds: True\n", - "🔄 it=11 \t (sub iter=131)\tt=64 \teq_error/eq_tol=2933 %\tinside bounds: True\n", - "🔄 it=12 \t (sub iter=129)\tt=100 \teq_error/eq_tol=1849 %\tinside bounds: True\n", - "🔄 it=13 \t (sub iter=130)\tt=100 \teq_error/eq_tol=1166 %\tinside bounds: True\n", - "🔄 it=14 \t (sub iter=123)\tt=100 \teq_error/eq_tol=733 %\tinside bounds: True\n", - "🔄 it=15 \t (sub iter=127)\tt=100 \teq_error/eq_tol=460 %\tinside bounds: True\n", - "🔄 it=16 \t (sub iter=121)\tt=100 \teq_error/eq_tol=288 %\tinside bounds: True\n", - "🔄 it=17 \t (sub iter=114)\tt=100 \teq_error/eq_tol=180 %\tinside bounds: True\n", - "🔄 it=18 \t (sub iter=113)\tt=100 \teq_error/eq_tol=113 %\tinside bounds: True\n", - "🔄 it=19 \t (sub iter=113)\tt=100 \teq_error/eq_tol=70 %\tinside bounds: True\n", + "🔄 it=0 \t (sub iter=103)\tt=0 \teq_error/eq_tol=126560 gain=0.0 lambda=1.7323600482924109 \tinside bounds: True\n", + "🔄 it=1 \t (sub iter=99)\tt=1 \teq_error/eq_tol=107763 gain=1.174432068744008 lambda=1.7893927811858676 \tinside bounds: True\n", + "🔄 it=2 \t (sub iter=105)\tt=1 \teq_error/eq_tol=84226 gain=1.2794397639489083 lambda=1.8448023703016776 \tinside bounds: True\n", + "🔄 it=3 \t (sub iter=112)\tt=2 \teq_error/eq_tol=59394 gain=1.4180897312552108 lambda=1.8939759568384897 \tinside bounds: True\n", + "🔄 it=4 \t (sub iter=116)\tt=3 \teq_error/eq_tol=37819 gain=1.5704973307116938 lambda=1.9338016869283867 \tinside bounds: True\n", + "🔄 it=5 \t (sub iter=121)\tt=5 \teq_error/eq_tol=22131 gain=1.7088600474646911 lambda=1.9639259767347728 \tinside bounds: True\n", + "🔄 it=6 \t (sub iter=123)\tt=7 \teq_error/eq_tol=12154 gain=1.8209304511249051 lambda=1.9852507346735249 \tinside bounds: True\n", + "🔄 it=7 \t (sub iter=135)\tt=11 \teq_error/eq_tol=6380 gain=1.9049323806243565 lambda=1.9989625859636135 \tinside bounds: True\n", + "🔄 it=8 \t (sub iter=120)\tt=17 \teq_error/eq_tol=3252 gain=1.9619209707534442 lambda=2.006454412752182 \tinside bounds: True\n", + "🔄 it=9 \t (sub iter=121)\tt=27 \teq_error/eq_tol=1631 gain=1.9943019558508168 lambda=2.0095040800352395 \tinside bounds: True\n", + "🔄 it=10 \t (sub iter=122)\tt=41 \teq_error/eq_tol=812 gain=2.009366471213732 lambda=2.0095040800352395 \tinside bounds: True\n", + "🔄 it=11 \t (sub iter=125)\tt=64 \teq_error/eq_tol=403 gain=2.0136681857660084 lambda=2.0095040800352395 \tinside bounds: True\n", + "🔄 it=12 \t (sub iter=119)\tt=100 \teq_error/eq_tol=200 gain=2.0147061699096214 lambda=2.0095040800352395 \tinside bounds: True\n", + "🔄 it=13 \t (sub iter=114)\tt=100 \teq_error/eq_tol=100 gain=2.006816054891115 lambda=2.0095040800352395 \tinside bounds: True\n", "✅ found feasible solution\n", - "time to run: 2.8347809314727783 seconds\n" + "time to run: 2.8244380950927734 seconds\n" ] }, { @@ -294,7 +288,7 @@ " \n", " \n", " \n", - " 2024-01-13T17:16:21.543538\n", + " 2024-05-05T17:02:24.738861\n", " image/svg+xml\n", " \n", " \n", @@ -329,47 +323,47 @@ " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -378,12 +372,12 @@ " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -419,7 +413,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -458,7 +452,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -492,7 +486,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -537,7 +531,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -591,7 +585,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -679,116 +673,116 @@ " \n", " \n", " \n", + "L 56.823949 148.782974 \n", + "L 60.882131 146.040111 \n", + "L 64.940313 143.299474 \n", + "L 68.998494 140.56272 \n", + "L 73.056676 137.834495 \n", + "L 77.114858 135.132768 \n", + "L 81.17304 132.512383 \n", + "L 85.231222 130.018074 \n", + "L 89.289403 127.660966 \n", + "L 93.347585 125.444257 \n", + "L 97.405767 123.369222 \n", + "L 101.463949 121.436485 \n", + "L 105.522131 119.646393 \n", + "L 109.580312 117.999158 \n", + "L 113.638494 116.494918 \n", + "L 117.696676 115.133768 \n", + "L 121.754858 113.915774 \n", + "L 125.81304 112.840985 \n", + "L 129.871222 111.909437 \n", + "L 133.929403 111.121159 \n", + "L 137.987585 110.47617 \n", + "L 142.045767 109.974486 \n", + "L 146.103949 109.616119 \n", + "L 150.162131 109.401077 \n", + "L 154.220312 109.329365 \n", + "L 158.278494 109.400985 \n", + "L 162.336676 109.615936 \n", + "L 166.394858 109.974215 \n", + "L 170.45304 110.475815 \n", + "L 174.511222 111.120726 \n", + "L 178.569403 111.908933 \n", + "L 182.627585 112.840418 \n", + "L 186.685767 113.915154 \n", + "L 190.743949 115.133105 \n", + "L 194.802131 116.494225 \n", + "L 198.860312 117.998447 \n", + "L 202.918494 119.645678 \n", + "L 206.976676 121.43578 \n", + "L 211.034858 123.368543 \n", + "L 215.09304 125.44362 \n", + "L 219.151222 127.660387 \n", + "L 223.209403 130.01757 \n", + "L 227.267585 132.511967 \n", + "L 231.325767 135.132444 \n", + "L 235.383949 137.834211 \n", + "L 239.442131 140.562491 \n", + "L 243.500313 143.299292 \n", + "L 247.558494 146.040002 \n", + "L 251.616676 148.782921 \n", + "L 255.674858 151.525903 \n", + "\" clip-path=\"url(#p48be24b81e)\" style=\"fill: none; stroke: #ff0000; stroke-width: 1.5; stroke-linecap: square\"/>\n", " \n", " \n", " \n", + "L 56.823949 151.389576 \n", + "L 60.882131 150.977517 \n", + "L 64.940313 150.291282 \n", + "L 68.998494 149.331174 \n", + "L 73.056676 148.097814 \n", + "L 77.114858 146.592954 \n", + "L 81.17304 144.821985 \n", + "L 85.231222 142.79528 \n", + "L 89.289403 140.526 \n", + "L 93.347585 138.028027 \n", + "L 97.405767 135.315464 \n", + "L 101.463949 132.40251 \n", + "L 105.522131 129.303411 \n", + "L 109.580312 126.032444 \n", + "L 113.638494 122.6039 \n", + "L 117.696676 119.032083 \n", + "L 121.754858 115.331307 \n", + "L 125.81304 111.515889 \n", + "L 129.871222 107.600151 \n", + "L 133.929403 103.59842 \n", + "L 137.987585 99.525022 \n", + "L 142.045767 95.394287 \n", + "L 146.103949 91.220548 \n", + "L 150.162131 87.018135 \n", + "L 154.220312 82.801382 \n", + "L 158.278494 78.584621 \n", + "L 162.336676 74.382187 \n", + "L 166.394858 70.208411 \n", + "L 170.45304 66.077626 \n", + "L 174.511222 62.004164 \n", + "L 178.569403 58.002355 \n", + "L 182.627585 54.086528 \n", + "L 186.685767 50.27101 \n", + "L 190.743949 46.570123 \n", + "L 194.802131 42.998187 \n", + "L 198.860312 39.569515 \n", + "L 202.918494 36.298413 \n", + "L 206.976676 33.199175 \n", + "L 211.034858 30.286077 \n", + "L 215.09304 27.573369 \n", + "L 219.151222 25.07525 \n", + "L 223.209403 22.805826 \n", + "L 227.267585 20.778978 \n", + "L 231.325767 19.007872 \n", + "L 235.383949 17.502874 \n", + "L 239.442131 16.269376 \n", + "L 243.500313 15.30913 \n", + "L 247.558494 14.622756 \n", + "L 251.616676 14.210561 \n", + "L 255.674858 14.073275 \n", + "\" clip-path=\"url(#p48be24b81e)\" style=\"fill: none; stroke: #000000; stroke-width: 1.5; stroke-linecap: square\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#p48be24b81e)\" style=\"fill: none; stroke-dasharray: 1.5,2.475; stroke-dashoffset: 0; stroke: #000000; stroke-width: 1.5\"/>\n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -995,7 +989,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1009,7 +1003,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1023,7 +1017,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1071,7 +1065,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1085,7 +1079,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1128,7 +1122,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1151,7 +1145,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1165,7 +1159,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1178,7 +1172,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1191,7 +1185,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1236,67 +1230,67 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", + "\" clip-path=\"url(#p958960d379)\" style=\"fill: none; stroke-dasharray: 1.5,2.475; stroke-dashoffset: 0; stroke: #000000; stroke-width: 1.5\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#p958960d379)\" style=\"fill: none; stroke-dasharray: 1.5,2.475; stroke-dashoffset: 0; stroke: #000000; stroke-width: 1.5\"/>\n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1459,7 +1453,7 @@ " \n", " \n", " \n", - " 2024-01-13T17:16:21.831151\n", + " 2024-05-05T17:02:24.997773\n", " image/svg+xml\n", " \n", " \n", @@ -1494,47 +1488,47 @@ " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1543,17 +1537,17 @@ " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + "\" clip-path=\"url(#pc60dfef3f2)\" style=\"fill: none; stroke: #ff0000; stroke-width: 1.5; stroke-linecap: square\"/>\n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", @@ -3118,7 +2764,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3134,9 +2780,9 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3147,9 +2793,9 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3161,9 +2807,9 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3175,9 +2821,9 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3189,9 +2835,9 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3224,9 +2870,9 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3267,14 +2913,14 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3306,9 +2952,9 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3321,14 +2967,14 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3364,1065 +3010,747 @@ " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", " \n", + "\" clip-path=\"url(#p6227bfa58e)\" style=\"fill: none; stroke: #ff0000; stroke-width: 1.5; stroke-linecap: square\"/>\n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4527,17 +3855,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4561,7 +3889,7 @@ " \n", " \n", " \n", - " 2024-01-13T17:16:21.903411\n", + " 2024-05-05T17:02:25.065324\n", " image/svg+xml\n", " \n", " \n", @@ -4596,12 +3924,12 @@ " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4637,7 +3965,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4667,7 +3995,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4707,7 +4035,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4755,7 +4083,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4790,7 +4118,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4833,17 +4161,17 @@ " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4886,12 +4214,12 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4902,12 +4230,12 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4918,12 +4246,12 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4933,12 +4261,12 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4948,12 +4276,12 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4963,12 +4291,12 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4978,12 +4306,12 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4992,1064 +4320,746 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", " \n", + "\" clip-path=\"url(#pd5ea6dc8e0)\" style=\"fill: none; stroke: #ff0000; stroke-width: 1.5; stroke-linecap: square\"/>\n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", " \n", " \n", " \n", @@ -6480,14 +5458,14 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -6516,7 +5494,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "id": "7920d242-b26d-49a7-a630-9acfb06c6f2b", "metadata": { "tags": [] @@ -6529,16 +5507,12 @@ " 'u_max' : widgets.FloatSlider(min=0, max=10, step=0.01, value=2, description='u_max'),\n", "}\n", "\n", - "solver = Solver( partial(problem_def_2integrator, n_steps = 50, dt=0.1) )\n", - "solver.solver_settings['max_float32_iterations'] = 0\n", - "solver.solver_settings['max_iter_boundary_method'] = 100\n", - "solver.solver_settings['c_eq_init'] = 1000\n", - "\n" + "solver = Solver( partial(problem_def_2integrator, n_steps = 50, dt=0.1) )" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "id": "62299d02-fe4b-47d1-ab9d-83f061ce2427", "metadata": { "tags": [] @@ -6547,7 +5521,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "1529a7d4e2774eafb558c79d1c1c0d42", + "model_id": "f409ebd42b054174aef2b3d1a3d2da4a", "version_major": 2, "version_minor": 0 }, @@ -6561,7 +5535,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "27e35ee1e3404baebc371d03097e83b6", + "model_id": "979ef8fd49c148f88ac1ad5dc1ade4df", "version_major": 2, "version_minor": 0 }, @@ -6575,7 +5549,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "406fb06988bb43b194de06165aa2e411", + "model_id": "64dc9dda33f34596b8600d3b85d8f86d", "version_major": 2, "version_minor": 0 }, @@ -6594,7 +5568,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "id": "f472d0b0-ac13-4d06-b9d3-e490de0ead5c", "metadata": { "tags": [] @@ -6616,7 +5590,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "id": "ce2c3f8a-7e8c-4ad3-9de4-45543ee035fa", "metadata": { "tags": [] @@ -6627,23 +5601,22 @@ "output_type": "stream", "text": [ "👉 solving problem with n_horizon=50, n_states=2 n_inputs=1\n", - "🔄 it=0 \t (sub iter=133)\tt=0 \teq_error/eq_tol=45476 %\tinside bounds: True\n", - "🔄 it=1 \t (sub iter=119)\tt=1 \teq_error/eq_tol=30764 %\tinside bounds: True\n", - "🔄 it=2 \t (sub iter=122)\tt=1 \teq_error/eq_tol=20066 %\tinside bounds: True\n", - "🔄 it=3 \t (sub iter=122)\tt=2 \teq_error/eq_tol=12762 %\tinside bounds: True\n", - "🔄 it=4 \t (sub iter=121)\tt=3 \teq_error/eq_tol=7989 %\tinside bounds: True\n", - "🔄 it=5 \t (sub iter=118)\tt=5 \teq_error/eq_tol=4960 %\tinside bounds: True\n", - "🔄 it=6 \t (sub iter=117)\tt=7 \teq_error/eq_tol=3069 %\tinside bounds: True\n", - "🔄 it=7 \t (sub iter=116)\tt=11 \teq_error/eq_tol=1899 %\tinside bounds: True\n", - "🔄 it=8 \t (sub iter=122)\tt=17 \teq_error/eq_tol=1176 %\tinside bounds: True\n", - "🔄 it=9 \t (sub iter=119)\tt=27 \teq_error/eq_tol=729 %\tinside bounds: True\n", - "🔄 it=10 \t (sub iter=118)\tt=41 \teq_error/eq_tol=453 %\tinside bounds: True\n", - "🔄 it=11 \t (sub iter=119)\tt=64 \teq_error/eq_tol=282 %\tinside bounds: True\n", - "🔄 it=12 \t (sub iter=118)\tt=100 \teq_error/eq_tol=176 %\tinside bounds: True\n", - "🔄 it=13 \t (sub iter=109)\tt=100 \teq_error/eq_tol=110 %\tinside bounds: True\n", - "🔄 it=14 \t (sub iter=108)\tt=100 \teq_error/eq_tol=69 %\tinside bounds: True\n", + "🔄 it=0 \t (sub iter=102)\tt=0 \teq_error/eq_tol=122244 gain=0.0 lambda=1.72774223155908 \tinside bounds: True\n", + "🔄 it=1 \t (sub iter=101)\tt=1 \teq_error/eq_tol=104091 gain=1.174387347908792 lambda=1.7842316859768328 \tinside bounds: True\n", + "🔄 it=2 \t (sub iter=103)\tt=1 \teq_error/eq_tol=81387 gain=1.2789754560747137 lambda=1.8390591833556362 \tinside bounds: True\n", + "🔄 it=3 \t (sub iter=113)\tt=2 \teq_error/eq_tol=57428 gain=1.4171968054969755 lambda=1.887609957770989 \tinside bounds: True\n", + "🔄 it=4 \t (sub iter=117)\tt=3 \teq_error/eq_tol=36552 gain=1.5711112834541572 lambda=1.9264973036694084 \tinside bounds: True\n", + "🔄 it=5 \t (sub iter=122)\tt=5 \teq_error/eq_tol=21337 gain=1.7130896375511475 lambda=1.9549783182057725 \tinside bounds: True\n", + "🔄 it=6 \t (sub iter=122)\tt=7 \teq_error/eq_tol=11686 gain=1.8258941235010082 lambda=1.9741493073681207 \tinside bounds: True\n", + "🔄 it=7 \t (sub iter=126)\tt=11 \teq_error/eq_tol=6129 gain=1.9067506077840948 lambda=1.9856118051497085 \tinside bounds: True\n", + "🔄 it=8 \t (sub iter=124)\tt=17 \teq_error/eq_tol=3132 gain=1.9567021663474276 lambda=1.9914447879614428 \tinside bounds: True\n", + "🔄 it=9 \t (sub iter=119)\tt=27 \teq_error/eq_tol=1579 gain=1.983278567881635 lambda=1.9934915914425297 \tinside bounds: True\n", + "🔄 it=10 \t (sub iter=116)\tt=41 \teq_error/eq_tol=792 gain=1.9948878349570724 lambda=1.9934915914425297 \tinside bounds: True\n", + "🔄 it=11 \t (sub iter=116)\tt=64 \teq_error/eq_tol=396 gain=1.998140405810653 lambda=1.9934915914425297 \tinside bounds: True\n", + "🔄 it=12 \t (sub iter=119)\tt=100 \teq_error/eq_tol=198 gain=1.9984259145046508 lambda=1.9934915914425297 \tinside bounds: True\n", + "🔄 it=13 \t (sub iter=110)\tt=100 \teq_error/eq_tol=100 gain=1.9914875833134649 lambda=1.9934915914425297 \tinside bounds: True\n", "✅ found feasible solution\n", - "time to run: 0.4376699924468994 seconds\n" + "time to run: 0.3845529556274414 seconds\n" ] } ], @@ -6662,7 +5635,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "id": "d2e4ce8f-e337-4167-b9f9-b1837786a3a9", "metadata": { "tags": [] @@ -6684,6 +5657,14 @@ "metadata": {}, "outputs": [], "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27c8b5fe-0263-41ea-9940-a0dc67826425", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/examples/trajectory_optim_cart_pendulum.ipynb b/examples/trajectory_optim_cart_pendulum.ipynb index 37702a5..b200c21 100644 --- a/examples/trajectory_optim_cart_pendulum.ipynb +++ b/examples/trajectory_optim_cart_pendulum.ipynb @@ -45,12 +45,12 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 3, "id": "29ba4f8d-620c-4c69-acf4-b8b710d9698d", "metadata": {}, "outputs": [], "source": [ - "from jax_control_algorithms.trajectory_optimization import Solver, Functions, ProblemDefinition, constraint_geq, constraint_leq, unpack_res, generate_penalty_parameter_trace\n", + "from jax_control_algorithms.trajectory_optimization import Solver, Functions, ProblemDefinition, SolverSettings, constraint_geq, constraint_leq, unpack_res, generate_penalty_parameter_trace\n", "from jax_control_algorithms.ui import manual_investigate, solve_and_plot\n", "from jax_control_algorithms.common import rk4\n", "\n", @@ -298,7 +298,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 7, "id": "7920d242-b26d-49a7-a630-9acfb06c6f2b", "metadata": { "tags": [] @@ -311,13 +311,15 @@ " 'u_max' : widgets.FloatSlider(min=0, max=100, step=0.01, value=100, description='u_max'),\n", "}\n", "\n", - "solver = Solver( partial(problem_def_cart_pendulum, n_steps = 50, dt=0.1) )\n", - "solver.solver_settings['max_float32_iterations'] = 0\n", - "\n", - "solver.solver_settings['max_iter_boundary_method'] = 100\n", - "solver.solver_settings['c_eq_init'] = 10000\n", - "solver.solver_settings['penalty_parameter_trace'] = generate_penalty_parameter_trace(t_start=0.5, t_final=1.0, n_steps=13)[0]\n", - "solver.solver_settings['lam'] = 1.6\n", + "solver = Solver(\n", + " partial(problem_def_cart_pendulum, n_steps = 50, dt=0.1),\n", + " solver_settings=SolverSettings(\n", + " max_iter_boundary_method=100,\n", + " c_eq_init=10000,\n", + " penalty_parameter_trace=generate_penalty_parameter_trace(t_start=0.5, t_final=1.0, n_steps=13)[0],\n", + " \n", + " ),\n", + ")\n", "\n", "def set_theta_fn(solver, a, u_min, u_max):\n", " solver.problem_definition.parameters['a'] = a\n", @@ -327,7 +329,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 8, "id": "62299d02-fe4b-47d1-ab9d-83f061ce2427", "metadata": { "tags": [] @@ -336,7 +338,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "8ca18cc42d0249a49c708f4178b67a27", + "model_id": "956e3f574d944f8794c8bf0730e133a6", "version_major": 2, "version_minor": 0 }, @@ -350,7 +352,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "83eb272093da4b33a8c806ffd062f4c4", + "model_id": "bd625df46bbb433dbe7b74798f89c2f0", "version_major": 2, "version_minor": 0 }, @@ -364,7 +366,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "7f94030db2ad44d58259597887e3a28d", + "model_id": "0bcb17cf5190479582991e3e3e99f71e", "version_major": 2, "version_minor": 0 }, @@ -448,7 +450,7 @@ " \n", " \n", " \n", - " 2024-01-07T21:02:21.709627\n", + " 2024-05-05T17:13:33.607890\n", " image/svg+xml\n", " \n", " \n", @@ -485,7 +487,7 @@ "L 214.457143 182.571429 \n", "L 150.685714 182.571429 \n", "z\n", - "\" clip-path=\"url(#p63a8e70119)\" style=\"stroke: #000000; stroke-linejoin: miter\"/>\n", + "\" clip-path=\"url(#pf689b7f315)\" style=\"stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#pf689b7f315)\" style=\"stroke: #000000; stroke-linejoin: miter\"/>\n", " \n", " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -546,7 +548,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -582,7 +584,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -738,17 +740,17 @@ " \n", " \n", + "\" clip-path=\"url(#pf689b7f315)\" style=\"fill: none; stroke: #008000; stroke-width: 4; stroke-linecap: square\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#pf689b7f315)\" style=\"fill: none; stroke-dasharray: 7.4,3.2; stroke-dashoffset: 0; stroke: #808080; stroke-width: 2\"/>\n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -846,20 +848,7 @@ "metadata": { "tags": [] }, - "outputs": [ - { - "ename": "BaseException", - "evalue": "Test cart_pendulum failed", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mBaseException\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[12], line 6\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m r[\u001b[38;5;241m1\u001b[39m]:\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTest \u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m+\u001b[39m r[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m failed\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m----> 6\u001b[0m \u001b[43mverify_test_results\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_results\u001b[49m\u001b[43m)\u001b[49m\n", - "Cell \u001b[0;32mIn[12], line 4\u001b[0m, in \u001b[0;36mverify_test_results\u001b[0;34m(test_results)\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m r \u001b[38;5;129;01min\u001b[39;00m test_results:\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m r[\u001b[38;5;241m1\u001b[39m]:\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTest \u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m+\u001b[39m r[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m failed\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", - "\u001b[0;31mBaseException\u001b[0m: Test cart_pendulum failed" - ] - } - ], + "outputs": [], "source": [ "def verify_test_results(test_results):\n", " for r in test_results:\n", @@ -876,6 +865,14 @@ "metadata": {}, "outputs": [], "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e8e93cad-4268-4f9a-9395-964850402d57", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/examples/trajectory_optim_integrator.ipynb b/examples/trajectory_optim_integrator.ipynb index 596d52d..4c4be11 100644 --- a/examples/trajectory_optim_integrator.ipynb +++ b/examples/trajectory_optim_integrator.ipynb @@ -52,7 +52,7 @@ }, "outputs": [], "source": [ - "from jax_control_algorithms.trajectory_optimization import Solver, Functions, ProblemDefinition, constraint_geq, constraint_leq, unpack_res, generate_penalty_parameter_trace\n", + "from jax_control_algorithms.trajectory_optimization import Solver, Functions, SolverSettings, ProblemDefinition, constraint_geq, constraint_leq, unpack_res, generate_penalty_parameter_trace\n", "from jax_control_algorithms.ui import manual_investigate, solve_and_plot, plot_iterations\n", "from jax_control_algorithms.common import rk4\n", "\n", @@ -266,25 +266,22 @@ "text": [ "compiling optimizer...\n", "👉 solving problem with n_horizon=50, n_states=1 n_inputs=1\n", - "🔄 it=0 \t (sub iter=64)\tt=0 \teq_error/eq_tol=95677 %\tinside bounds: True\n", - "🔄 it=1 \t (sub iter=65)\tt=1 \teq_error/eq_tol=70538 %\tinside bounds: True\n", - "🔄 it=2 \t (sub iter=69)\tt=2 \teq_error/eq_tol=49685 %\tinside bounds: True\n", - "🔄 it=3 \t (sub iter=71)\tt=3 \teq_error/eq_tol=33768 %\tinside bounds: True\n", - "🔄 it=4 \t (sub iter=74)\tt=6 \teq_error/eq_tol=22368 %\tinside bounds: True\n", - "🔄 it=5 \t (sub iter=73)\tt=12 \teq_error/eq_tol=14572 %\tinside bounds: True\n", - "🔄 it=6 \t (sub iter=73)\tt=22 \teq_error/eq_tol=9392 %\tinside bounds: True\n", - "🔄 it=7 \t (sub iter=78)\tt=42 \teq_error/eq_tol=6007 %\tinside bounds: True\n", - "🔄 it=8 \t (sub iter=82)\tt=79 \teq_error/eq_tol=3818 %\tinside bounds: True\n", - "🔄 it=9 \t (sub iter=82)\tt=150 \teq_error/eq_tol=2415 %\tinside bounds: True\n", - "🔄 it=10 \t (sub iter=90)\tt=282 \teq_error/eq_tol=1522 %\tinside bounds: True\n", - "🔄 it=11 \t (sub iter=86)\tt=531 \teq_error/eq_tol=957 %\tinside bounds: True\n", - "🔄 it=12 \t (sub iter=84)\tt=1000 \teq_error/eq_tol=601 %\tinside bounds: True\n", - "🔄 it=13 \t (sub iter=88)\tt=1000 \teq_error/eq_tol=377 %\tinside bounds: True\n", - "🔄 it=14 \t (sub iter=89)\tt=1000 \teq_error/eq_tol=236 %\tinside bounds: True\n", - "🔄 it=15 \t (sub iter=76)\tt=1000 \teq_error/eq_tol=148 %\tinside bounds: True\n", - "🔄 it=16 \t (sub iter=72)\tt=1000 \teq_error/eq_tol=93 %\tinside bounds: True\n", + "🔄 it=0 \t (sub iter=63)\tt=0 \teq_error/eq_tol=104193 gain=0.0 lambda=1.7066377344920718 \tinside bounds: True\n", + "🔄 it=1 \t (sub iter=65)\tt=1 \teq_error/eq_tol=73529 gain=1.4170320198395614 lambda=1.733291007425863 \tinside bounds: True\n", + "🔄 it=2 \t (sub iter=70)\tt=2 \teq_error/eq_tol=48382 gain=1.5197643926417277 lambda=1.7541307356723017 \tinside bounds: True\n", + "🔄 it=3 \t (sub iter=73)\tt=3 \teq_error/eq_tol=30035 gain=1.6108713716793888 lambda=1.769139453454021 \tinside bounds: True\n", + "🔄 it=4 \t (sub iter=74)\tt=6 \teq_error/eq_tol=17858 gain=1.6818532490422464 lambda=1.779113355699883 \tinside bounds: True\n", + "🔄 it=5 \t (sub iter=76)\tt=12 \teq_error/eq_tol=10311 gain=1.7319130537763767 lambda=1.7851031309122738 \tinside bounds: True\n", + "🔄 it=6 \t (sub iter=75)\tt=22 \teq_error/eq_tol=5846 gain=1.7637528054797666 lambda=1.7881742025455243 \tinside bounds: True\n", + "🔄 it=7 \t (sub iter=74)\tt=42 \teq_error/eq_tol=3282 gain=1.7812344130577584 lambda=1.7893334599630348 \tinside bounds: True\n", + "🔄 it=8 \t (sub iter=69)\tt=79 \teq_error/eq_tol=1835 gain=1.7885616802657218 lambda=1.7894878558614464 \tinside bounds: True\n", + "🔄 it=9 \t (sub iter=68)\tt=150 \teq_error/eq_tol=1024 gain=1.7911584413934551 lambda=1.7890704529929005 \tinside bounds: True\n", + "🔄 it=10 \t (sub iter=70)\tt=282 \teq_error/eq_tol=572 gain=1.7911841986310006 lambda=1.7890704529929005 \tinside bounds: True\n", + "🔄 it=11 \t (sub iter=71)\tt=531 \teq_error/eq_tol=319 gain=1.7904138659385855 lambda=1.7890704529929005 \tinside bounds: True\n", + "🔄 it=12 \t (sub iter=75)\tt=1000 \teq_error/eq_tol=178 gain=1.7902512799009547 lambda=1.7890704529929005 \tinside bounds: True\n", + "🔄 it=13 \t (sub iter=104)\tt=1000 \teq_error/eq_tol=100 gain=1.7865014335371148 lambda=1.7890704529929005 \tinside bounds: True\n", "✅ found feasible solution\n", - "time to run: 2.3953638076782227 seconds\n" + "time to run: 2.204076051712036 seconds\n" ] }, { @@ -298,7 +295,7 @@ " \n", " \n", " \n", - " 2024-03-17T10:59:45.538625\n", + " 2024-05-05T17:01:22.295981\n", " image/svg+xml\n", " \n", " \n", @@ -333,47 +330,47 @@ " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -382,12 +379,12 @@ " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -432,7 +429,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -473,7 +470,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -509,7 +506,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -556,7 +553,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -612,7 +609,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -701,62 +698,62 @@ " \n", " \n", " \n", + "L 66.364574 165.7795 \n", + "L 70.422756 164.955273 \n", + "L 74.480937 164.156351 \n", + "L 78.539119 163.369793 \n", + "L 82.597301 162.585141 \n", + "L 86.655483 161.793715 \n", + "L 90.713665 160.988099 \n", + "L 94.771847 160.161759 \n", + "L 98.830028 159.308753 \n", + "L 102.88821 158.423517 \n", + "L 106.946392 157.500696 \n", + "L 111.004574 156.535018 \n", + "L 115.062756 155.521191 \n", + "L 119.120937 154.453832 \n", + "L 123.179119 153.327408 \n", + "L 127.237301 152.136196 \n", + "L 131.295483 150.874237 \n", + "L 135.353665 149.535295 \n", + "L 139.411847 148.112795 \n", + "L 143.470028 146.599787 \n", + "L 147.52821 144.988895 \n", + "L 151.586392 143.272309 \n", + "L 155.644574 141.441757 \n", + "L 159.702756 139.488503 \n", + "L 163.760937 137.403303 \n", + "L 167.819119 135.176382 \n", + "L 171.877301 132.79742 \n", + "L 175.935483 130.255565 \n", + "L 179.993665 127.539475 \n", + "L 184.051847 124.637391 \n", + "L 188.110028 121.537204 \n", + "L 192.16821 118.226607 \n", + "L 196.226392 114.693279 \n", + "L 200.284574 110.925222 \n", + "L 204.342756 106.911231 \n", + "L 208.400937 102.641539 \n", + "L 212.459119 98.108746 \n", + "L 216.517301 93.309054 \n", + "L 220.575483 88.244043 \n", + "L 224.633665 82.923081 \n", + "L 228.691847 77.366656 \n", + "L 232.750028 71.610898 \n", + "L 236.80821 65.713646 \n", + "L 240.866392 59.762441 \n", + "L 244.924574 53.884078 \n", + "L 248.982756 48.249371 \n", + "L 253.040937 42.97816 \n", + "L 257.099119 38.059676 \n", + "L 261.157301 33.4712 \n", + "L 265.215483 29.204582 \n", + "\" clip-path=\"url(#pc92314e9de)\" style=\"fill: none; stroke: #000000; stroke-width: 1.5; stroke-linecap: square\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#pc92314e9de)\" style=\"fill: none; stroke-dasharray: 1.5,2.475; stroke-dashoffset: 0; stroke: #000000; stroke-width: 1.5\"/>\n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1121,7 +1118,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1135,7 +1132,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1149,7 +1146,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1197,7 +1194,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1211,7 +1208,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1254,7 +1251,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1279,7 +1276,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1295,7 +1292,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1311,7 +1308,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1327,7 +1324,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1342,7 +1339,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1357,7 +1354,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1404,67 +1401,67 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", + "\" clip-path=\"url(#p69240d873a)\" style=\"fill: none; stroke-dasharray: 1.5,2.475; stroke-dashoffset: 0; stroke: #000000; stroke-width: 1.5\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#p69240d873a)\" style=\"fill: none; stroke-dasharray: 1.5,2.475; stroke-dashoffset: 0; stroke: #000000; stroke-width: 1.5\"/>\n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1678,8 +1675,14 @@ } ], "source": [ - "solver = Solver( partial(problem_def_integrator, n_steps = 50, dt=0.1) )\n", - "solver.solver_settings['penalty_parameter_trace'] = generate_penalty_parameter_trace(t_start=0.5, t_final=1000.0, n_steps=13)[0]\n", + "solver = Solver( \n", + " partial(\n", + " problem_def_integrator, n_steps = 50, dt=0.1,\n", + " ),\n", + " solver_settings=SolverSettings(\n", + " penalty_parameter_trace=generate_penalty_parameter_trace(t_start=0.5, t_final=1000.0, n_steps=13)[0]\n", + " ),\n", + ")\n", "\n", "set_parameters_fn(solver, a=0.7, u_min=-2, u_max=1.0)\n", "solver.verbose = True\n", @@ -1689,6 +1692,16 @@ "plot_integrator(X_opt, U_opt, system_outputs, solver.problem_definition.parameters)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "8fa53ddb-4b16-48b9-922c-16b4745c2e47", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [] + }, { "cell_type": "markdown", "id": "92826b09-2d0b-499f-b77a-34e880928aee", @@ -1716,7 +1729,7 @@ " \n", " \n", " \n", - " 2024-03-17T10:59:45.756794\n", + " 2024-05-05T17:01:22.465521\n", " image/svg+xml\n", " \n", " \n", @@ -1751,12 +1764,12 @@ " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1792,7 +1805,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1822,7 +1835,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1862,7 +1875,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1910,7 +1923,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1945,7 +1958,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1988,17 +2001,17 @@ " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2032,12 +2045,12 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2047,12 +2060,12 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2164,905 +2177,746 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + "\" clip-path=\"url(#pfa0ad9f391)\" style=\"fill: none; stroke: #ff0000; stroke-width: 1.5; stroke-linecap: square\"/>\n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3426,7 +3280,7 @@ " \n", " \n", " \n", - " 2024-03-17T10:59:45.810577\n", + " 2024-05-05T17:01:22.514759\n", " image/svg+xml\n", " \n", " \n", @@ -3461,12 +3315,12 @@ " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3502,7 +3356,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3532,7 +3386,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3572,7 +3426,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3620,7 +3474,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3655,7 +3509,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3698,17 +3552,17 @@ " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3742,12 +3596,12 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3757,12 +3611,12 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3875,904 +3729,745 @@ " \n", " \n", " \n", + "L 44.389581 233.89511 \n", + "L 48.530583 233.715628 \n", + "L 52.671585 233.478532 \n", + "L 56.812587 233.193886 \n", + "L 60.953589 232.868698 \n", + "L 65.094591 232.505117 \n", + "L 69.235593 232.105226 \n", + "L 73.376594 231.668966 \n", + "L 77.517596 231.196085 \n", + "L 81.658598 230.685688 \n", + "L 85.7996 230.136572 \n", + "L 89.940602 229.547411 \n", + "L 94.081604 228.915246 \n", + "L 98.222606 228.236695 \n", + "L 102.363607 227.508304 \n", + "L 106.504609 226.727272 \n", + "L 110.645611 225.89085 \n", + "L 114.786613 224.995787 \n", + "L 118.927615 224.038238 \n", + "L 123.068617 223.013993 \n", + "L 127.209619 221.918855 \n", + "L 131.35062 220.748724 \n", + "L 135.491622 219.499453 \n", + "L 139.632624 218.166626 \n", + "L 143.773626 216.745287 \n", + "L 147.914628 215.229583 \n", + "L 152.05563 213.612478 \n", + "L 156.196631 211.885946 \n", + "L 160.337633 210.042056 \n", + "L 164.478635 208.074571 \n", + "L 168.619637 205.979777 \n", + "L 172.760639 203.75543 \n", + "L 176.901641 201.397717 \n", + "L 181.042643 198.89907 \n", + "L 185.183644 196.249463 \n", + "L 189.324646 193.439753 \n", + "L 193.465648 190.467831 \n", + "L 197.60665 187.335272 \n", + "L 201.747652 184.047554 \n", + "L 205.888654 180.606832 \n", + "L 210.029656 177.019276 \n", + "L 214.170657 173.300023 \n", + "L 218.311659 169.473118 \n", + "L 222.452661 165.586435 \n", + "L 226.593663 161.703824 \n", + "L 230.734665 157.943859 \n", + "L 234.875667 154.477364 \n", + "L 239.016669 151.601089 \n", + "L 243.15767 149.778689 \n", + "\" clip-path=\"url(#pb5815e4c49)\" style=\"fill: none; stroke: #8000ff; stroke-width: 1.5; stroke-linecap: square\"/>\n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + "\" clip-path=\"url(#pb5815e4c49)\" style=\"fill: none; stroke: #ff0000; stroke-width: 1.5; stroke-linecap: square\"/>\n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5171,14 +4866,14 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5232,7 +4927,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "117b49e36cf1461a859b7ed20cf500c2", + "model_id": "69d2829914c648a0b583ef60919f8b35", "version_major": 2, "version_minor": 0 }, @@ -5246,7 +4941,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "f209ae323fb542cc8750c868f4930aa1", + "model_id": "143261c23eef4a6da253c569fff2b2c8", "version_major": 2, "version_minor": 0 }, @@ -5260,7 +4955,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b6d14076081c43e3b40aeb9325f04991", + "model_id": "2cc99bb491764df7913ac3b1d3abd53d", "version_major": 2, "version_minor": 0 }, diff --git a/jax_control_algorithms/trajectory_optim/outer_loop_solver.py b/jax_control_algorithms/trajectory_optim/outer_loop_solver.py index b93dfc3..20687dd 100644 --- a/jax_control_algorithms/trajectory_optim/outer_loop_solver.py +++ b/jax_control_algorithms/trajectory_optim/outer_loop_solver.py @@ -1,8 +1,12 @@ import jax import jax.numpy as jnp import jaxopt +from functools import partial from jax_control_algorithms.jax_helper import * +from jax_control_algorithms.trajectory_optim.penality_method import control_convergence_of_iteration +from jax_control_algorithms.trajectory_optim.problem_definition import * +from jax_control_algorithms.trajectory_optim.penality_method import eval_objective_of_penalty_method """ Implements the nested solver loops @@ -31,86 +35,70 @@ """ -def _print_loop_info(loop_par, is_max_iter_reached_and_not_finished, print_errors): - lax.cond(loop_par['is_abort'], lambda: jax.debug.print("-> abort as convergence has stopped"), lambda: None) +def _print_loop_info(loop_par: OuterLoopVariables, is_max_iter_reached_and_not_finished, print_errors): + lax.cond(loop_par.is_abort, lambda: jax.debug.print("-> abort as convergence has stopped"), lambda: None) if print_errors: lax.cond( is_max_iter_reached_and_not_finished, lambda: jax.debug.print("❌ max. iterations reached without a feasible solution"), lambda: None ) - lax.cond(jnp.logical_not(loop_par['is_X_finite']), lambda: jax.debug.print("❌ found non finite numerics"), lambda: None) + lax.cond(jnp.logical_not(loop_par.is_X_finite), lambda: jax.debug.print("❌ found non finite numerics"), lambda: None) -def _iterate(i, loop_var, solver_settings, objective_fn, verification_fn): +def _iterate(loop_var: OuterLoopVariables, functions, solver_settings : SolverSettings): + + i = loop_var.i # get the penalty parameter - penalty_parameter = loop_var['penalty_parameter_trace'][i] - n_outer_iterations_target = loop_var['penalty_parameter_trace'].shape[0] + penalty_parameter = loop_var.penalty_parameter_trace[i] + n_outer_iterations_target = loop_var.penalty_parameter_trace.shape[0] is_maximal_penalty_parameter_reached = i >= n_outer_iterations_target - 1 # - parameters_passed_to_inner_solver = loop_var['parameters_of_dynamic_model'] + ( + parameters_passed_to_inner_solver = loop_var.parameters_of_dynamic_model + ( penalty_parameter, - loop_var['opt_c_eq'], + loop_var.opt_c_eq, ) # run inner solver - gd = jaxopt.BFGS(fun=objective_fn, value_and_grad=False, tol=loop_var['tol_inner'], maxiter=solver_settings['max_iter_inner']) - res = gd.run(loop_var['variables'], parameters=parameters_passed_to_inner_solver) + # objective_fn = eval_feasibility_metric_of_penalty_method(variables, parameters_of_dynamic_model, functions : Functions) + objective_fn = partial(eval_objective_of_penalty_method, functions=functions) + + gd = jaxopt.BFGS(fun=objective_fn, value_and_grad=False, tol=loop_var.tol_inner, maxiter=solver_settings.max_iter_inner) + res = gd.run(loop_var.variables, parameters=parameters_passed_to_inner_solver) variables_updated_by_inner_solver = res.params - # run callback to verify the solution + # run verify the solution and control the convergence of to the equality constraints ( - verification_state_next, is_solution_feasible, is_equality_constraints_fulfilled, is_abort, is_X_finite, i_best, - max_eq_error, normalized_equality_error_change, normalized_equality_error_gain, opt_c_eq_next - ) = verification_fn( - loop_var['verification_state'], i, n_outer_iterations_target, res, variables_updated_by_inner_solver, loop_var['parameters_of_dynamic_model'], - penalty_parameter, loop_var['opt_c_eq'] + controller_state_next, is_equality_constraints_fulfilled, is_abort, is_X_finite, i_best, + max_eq_error, opt_c_eq_next, lam + ) = control_convergence_of_iteration( + loop_var.controller_state, + i, + n_outer_iterations_target, + res, + variables_updated_by_inner_solver, + loop_var.parameters_of_dynamic_model, + penalty_parameter, + loop_var.opt_c_eq, + loop_var.lam, + functions, + eq_tol=solver_settings.eq_tol, + verbose=True ) # update the state of the optimization variables in case the outer loop shall not be aborted - variables_next = tree_where(is_abort, loop_var['variables'], variables_updated_by_inner_solver) - - # - # - # - # - - # trace, _ = verification_state_next - - # def _control_gamma_eq( - # gamma_eq, is_equality_constraints_fulfilled, lam, max_eq_error, normalized_equality_error_change, - # normalized_equality_error_gain - # ): - - # _lam = lam * 1.0 - - # gamma_eq_next = gamma_eq * jnp.where(is_equality_constraints_fulfilled, 1.0, _lam) - - # return gamma_eq_next - - # # update opt_c_eq: in case the equality constraints are not satisfies yet, increase opt_c_eq by multiplication with lam > 1 - # # otherwise leave opt_c_eq untouched. - # # opt_c_eq_next = loop_var['opt_c_eq'] * jnp.where(is_equality_constraints_fulfilled, 1.0, loop_var['lam']) - # opt_c_eq_next = _control_gamma_eq( - # loop_var['opt_c_eq'], is_equality_constraints_fulfilled, loop_var['lam'], max_eq_error, normalized_equality_error_change, - # normalized_equality_error_gain - # ) - - # - # - # - # + variables_next = tree_where(is_abort, loop_var.variables, variables_updated_by_inner_solver) # solution found? - is_finished = jnp.logical_and(is_solution_feasible, is_maximal_penalty_parameter_reached) + is_finished = jnp.logical_and(controller_state_next.is_converged, is_maximal_penalty_parameter_reached) - return variables_next, verification_state_next, opt_c_eq_next, is_finished, is_abort, is_X_finite + return variables_next, controller_state_next, opt_c_eq_next, lam, is_finished, is_abort, is_X_finite def _run_outer_loop( - i, variables, parameters_of_dynamic_model, opt_c_eq, verification_state_init, solver_settings, objective_fn, verification_fn, - verbose, print_errors, target_dtype + i, variables, model_to_solve: ModelToSolve, opt_c_eq, verification_state_init, solver_settings : SolverSettings, verbose, print_errors, + target_dtype ): """ Execute the outer loop of the optimization process: herein in each iteration, the parameters of the @@ -121,64 +109,62 @@ def _run_outer_loop( # convert dtypes to the target datatype used in the computation ( variables, - parameters_of_dynamic_model, + # model_to_solve, penalty_parameter_trace, opt_c_eq, verification_state_init, - lam, tol_inner, ) = convert_dtype( ( variables, - parameters_of_dynamic_model, - solver_settings['penalty_parameter_trace'], + # model_to_solve.parameters_of_dynamic_model, # model_to_solve, # model_to_solve. + solver_settings.penalty_parameter_trace, opt_c_eq, verification_state_init, - solver_settings['lam'], - solver_settings['tol_inner'], - ), target_dtype + solver_settings.tol_inner, + ), + target_dtype ) # loop: - def run_outer_loop_body(loop_var): + def run_outer_loop_body(loop_var: OuterLoopVariables): # loop iteration variable i - i = loop_var['i'] + i = loop_var.i - variables_next, verification_state_next, opt_c_eq_next, is_finished, is_abort, is_X_finite = _iterate( - i, loop_var, solver_settings, objective_fn, verification_fn + variables_next, verification_state_next, opt_c_eq_next, lam, is_finished, is_abort, is_X_finite = _iterate( + loop_var, model_to_solve.functions, solver_settings ) if verbose: lax.cond(is_finished, lambda: jax.debug.print("✅ found feasible solution"), lambda: None) - loop_var = { - 'is_finished': is_finished, - 'is_abort': is_abort, - 'is_X_finite': is_X_finite, - 'variables': variables_next, - 'parameters_of_dynamic_model': loop_var['parameters_of_dynamic_model'], - 'penalty_parameter_trace': penalty_parameter_trace, - 'opt_c_eq': opt_c_eq_next, - 'i': loop_var['i'] + 1, - 'verification_state': verification_state_next, - 'tol_inner': loop_var['tol_inner'], - 'lam': loop_var['lam'], - } + loop_var = OuterLoopVariables( + is_finished=is_finished, + is_abort=is_abort, + is_X_finite=is_X_finite, + variables=variables_next, + parameters_of_dynamic_model=loop_var.parameters_of_dynamic_model, + penalty_parameter_trace=penalty_parameter_trace, + opt_c_eq=opt_c_eq_next, + lam=lam, + i=loop_var.i + 1, + controller_state=verification_state_next, + tol_inner=loop_var.tol_inner, + ) return loop_var - def eval_outer_loop_condition(loop_var): - is_n_iter_not_reached = loop_var['i'] < solver_settings['max_iter_boundary_method'] + def eval_outer_loop_condition(loop_var: OuterLoopVariables): + is_n_iter_not_reached = loop_var.i < solver_settings.max_iter_boundary_method is_max_iter_reached_and_not_finished = jnp.logical_and( jnp.logical_not(is_n_iter_not_reached), - jnp.logical_not(loop_var['is_finished']), + jnp.logical_not(loop_var.is_finished), ) is_continue_iteration = jnp.logical_and( - jnp.logical_not(loop_var['is_abort']), - jnp.logical_and(jnp.logical_not(loop_var['is_finished']), is_n_iter_not_reached) + jnp.logical_not(loop_var.is_abort), jnp.logical_and(jnp.logical_not(loop_var.is_finished), is_n_iter_not_reached) ) if verbose: @@ -187,50 +173,47 @@ def eval_outer_loop_condition(loop_var): return is_continue_iteration # variables that are passed amount the loop-iterations - loop_var = { - 'is_finished': jnp.array(False, dtype=jnp.bool_), - 'is_abort': jnp.array(False, dtype=jnp.bool_), - 'is_X_finite': jnp.array(True, dtype=jnp.bool_), - 'variables': variables, - 'parameters_of_dynamic_model': parameters_of_dynamic_model, - 'penalty_parameter_trace': penalty_parameter_trace, - 'opt_c_eq': opt_c_eq, - 'i': i, - 'verification_state': verification_state_init, - 'tol_inner': tol_inner, - 'lam': lam, - } + loop_var = OuterLoopVariables( + is_finished=jnp.array(False, dtype=jnp.bool_), + is_abort=jnp.array(False, dtype=jnp.bool_), + is_X_finite=jnp.array(True, dtype=jnp.bool_), + variables=variables, + parameters_of_dynamic_model=model_to_solve.parameters_of_dynamic_model, + penalty_parameter_trace=penalty_parameter_trace, + opt_c_eq=opt_c_eq, + lam=jnp.array(jnp.nan), + i=i, + controller_state=verification_state_init, + tol_inner=tol_inner, + ) - loop_var = lax.while_loop(eval_outer_loop_condition, run_outer_loop_body, loop_var) + loop_var: OuterLoopVariables = lax.while_loop(eval_outer_loop_condition, run_outer_loop_body, loop_var) - n_iter = loop_var['i'] + n_iter = loop_var.i - return loop_var['variables'], loop_var['opt_c_eq'], n_iter, loop_var['verification_state'] + return loop_var.variables, loop_var.opt_c_eq, n_iter, loop_var.controller_state def run_outer_loop_solver( - variables, parameters_of_dynamic_model, solver_settings, trace_init, objective_, verification_fn_, max_float32_iterations, - enable_float64, verbose + variables, model_to_solve, solver_settings : SolverSettings, trace_init, max_float32_iterations, enable_float64, verbose ): """ execute the solution finding process """ - opt_c_eq = solver_settings['c_eq_init'] + opt_c_eq = solver_settings.c_eq_init i = 0 - verification_state = (trace_init, jnp.array(0, dtype=jnp.bool_)) + verification_state = ConvergenceControllerState(trace=trace_init, is_converged=jnp.array(0, dtype=jnp.bool_)) # iterations that are performed using float32 datatypes if max_float32_iterations > 0: variables, opt_c_eq, n_iter_f32, verification_state = _run_outer_loop( i, variables, - parameters_of_dynamic_model, + model_to_solve, jnp.array(opt_c_eq, dtype=jnp.float32), verification_state, solver_settings, - objective_, - verification_fn_, verbose, True if verbose else False, # show_errors target_dtype=jnp.float32 @@ -249,12 +232,10 @@ def run_outer_loop_solver( variables, opt_c_eq, n_iter_f64, verification_state = _run_outer_loop( i, variables, - parameters_of_dynamic_model, + model_to_solve, jnp.array(opt_c_eq, dtype=jnp.float64), verification_state, solver_settings, - objective_, - verification_fn_, verbose, True if verbose else False, # show_errors target_dtype=jnp.float64 diff --git a/jax_control_algorithms/trajectory_optim/penality_method.py b/jax_control_algorithms/trajectory_optim/penality_method.py index 06ca9df..98d5da4 100644 --- a/jax_control_algorithms/trajectory_optim/penality_method.py +++ b/jax_control_algorithms/trajectory_optim/penality_method.py @@ -7,29 +7,36 @@ from jax_control_algorithms.trajectory_optim.boundary_function import boundary_fn from jax_control_algorithms.trajectory_optim.dynamics_constraints import eval_dynamics_equality_constraints from jax_control_algorithms.trajectory_optim.cost_function import evaluate_cost +from jax_control_algorithms.trajectory_optim.problem_definition import Functions, ConvergenceControllerState """ https://en.wikipedia.org/wiki/Penalty_method """ -def _objective(variables, parameters_passed_to_solver, static_parameters): +def _objective(variables, parameters_passed_to_solver, functions : Functions): + """ + compute the objective function including + + - a penalty for equality constraints + - a penalty for boundaries + - the cost function of the problem to solve + """ K, parameters, x0, penalty_parameter, opt_c_eq = parameters_passed_to_solver - f, terminal_constraints, inequality_constraints, cost, running_cost = static_parameters X, U = variables n_steps = X.shape[0] assert U.shape[0] == n_steps # get equality constraint. The constraints are fulfilled of all elements of c_eq are zero - c_eq = eval_dynamics_equality_constraints(f, terminal_constraints, X, U, K, x0, parameters).reshape(-1) - c_ineq = inequality_constraints(X, U, K, parameters).reshape(-1) + c_eq = eval_dynamics_equality_constraints(functions.f, functions.terminal_constraints, X, U, K, x0, parameters).reshape(-1) + c_ineq = functions.inequality_constraints(X, U, K, parameters).reshape(-1) # equality constraints using penalty method J_equality_costs = opt_c_eq * jnp.mean((c_eq.reshape(-1))**2) # eval cost function of problem definition - J_cost_function = evaluate_cost(f, cost, running_cost, X, U, K, parameters) + J_cost_function = evaluate_cost(functions.f, functions.cost, functions.running_cost, X, U, K, parameters) # apply boundary costs (boundary function) J_boundary_costs = jnp.mean(boundary_fn(c_ineq, penalty_parameter, 11, True)) @@ -37,11 +44,11 @@ def _objective(variables, parameters_passed_to_solver, static_parameters): return J_equality_costs + J_cost_function + J_boundary_costs, c_eq -def eval_objective_of_penalty_method(variables, parameters, static_parameters): - return _objective(variables, parameters, static_parameters)[0] +def eval_objective_of_penalty_method(variables, parameters, functions : Functions): + return _objective(variables, parameters, functions)[0] -def eval_feasibility_metric_of_penalty_method(variables, parameters_of_dynamic_model, static_parameters): +def eval_feasibility_metric_of_penalty_method(variables, parameters_of_dynamic_model, functions : Functions): """ evaluate the correctness of the given solution candidate (variables) @@ -52,12 +59,11 @@ def eval_feasibility_metric_of_penalty_method(variables, parameters_of_dynamic_m solution candidate is inside the boundaries defined by the constraints. """ K, parameters, x0 = parameters_of_dynamic_model - f, terminal_constraints, inequality_constraints, cost, running_cost = static_parameters X, U = variables # get equality constraint. The constraints are fulfilled of all elements of c_eq are zero - c_eq = eval_dynamics_equality_constraints(f, terminal_constraints, X, U, K, x0, parameters) - c_ineq = inequality_constraints(X, U, K, parameters) + c_eq = eval_dynamics_equality_constraints( functions.f, functions.terminal_constraints, X, U, K, x0, parameters) + c_ineq = functions.inequality_constraints(X, U, K, parameters) # metric_c_eq = jnp.max(jnp.abs(c_eq)) @@ -71,31 +77,6 @@ def eval_feasibility_metric_of_penalty_method(variables, parameters_of_dynamic_m return metric_c_eq, is_solution_inside_boundaries -def _check_monotonic_convergence(i, trace): - """ - Check the monotonic convergence of the error for the equality constraints - """ - trace_data = get_trace_data(trace) - - # As being in the 2nd iteration, compare to prev. metric and see if it got smaller - is_metric_check_active = i > 2 - - def true_fn(par): - i, trace = par - - delta_max_eq_error = trace[0][i] - trace[0][i - 1] - is_abort = delta_max_eq_error >= 0 - - return is_abort - - def false_fn(par): - return False - - is_not_monotonic = lax.cond(is_metric_check_active, true_fn, false_fn, (i, trace_data)) - - return is_not_monotonic - - def _eval_eq_constraints_improvement(i, trace): """ return the change of the error for the equality constraints between the iteration i and i-1 @@ -109,33 +90,60 @@ def true_fn(par): normalized_equality_error_before = trace[0][i - 1] normalized_equality_error_after = trace[0][i] - normalized_equality_error_change = normalized_equality_error_before - normalized_equality_error_after normalized_equality_error_gain = normalized_equality_error_before / normalized_equality_error_after - if False: - jax.debug.print( - "normalized_equality_error_before={normalized_equality_error_before} normalized_equality_error_after={normalized_equality_error_after}", - normalized_equality_error_before=normalized_equality_error_before, normalized_equality_error_after=normalized_equality_error_after - ) - - return normalized_equality_error_change, normalized_equality_error_gain + return normalized_equality_error_gain def false_fn(par): - return 0.0, 0.0 + return 0.0 return lax.cond(i >= 1, true_fn, false_fn, (i, trace_data)) +def _control_gamma_eq( + i, # loop index + gamma_eq, + lam_prev, + is_equality_constraints_fulfilled, + normalized_equality_error, + normalized_equality_error_gain, + n_outer_iterations_target +): + # normalized_equality_error --> 1 in n_outer_iterations_target + # + # normalized_equality_error / lambda ^ n_outer_iterations_target < 1.0 + # normalized_equality_error = lambda ^ n_outer_iterations_target + + n_iter_left = n_outer_iterations_target - i + + lam = jnp.where(n_iter_left > 3, normalized_equality_error**(1 / n_iter_left), lam_prev) + + if False: + jax.debug.print( + "lam={lam} lam_prev={lam_prev} normalized_equality_error={normalized_equality_error} n_outer_iterations_target={n_outer_iterations_target}", + lam=lam, + lam_prev=lam_prev, + normalized_equality_error=normalized_equality_error, + n_outer_iterations_target=n_outer_iterations_target + ) + + gamma_eq_next = gamma_eq * jnp.where(is_equality_constraints_fulfilled, 1.0, lam) + + # NOTE: normalized_equality_error_gain shall be close to lam_prev + # This could be used for monitoring the convergence. + + return gamma_eq_next, lam -def verify_convergence_of_iteration( - verification_state, +def control_convergence_of_iteration( + controller_state : ConvergenceControllerState, i, n_outer_iterations_target, res_inner, variables, parameters_of_dynamic_model, penalty_parameter, - opt_c_eq, # blub - feasibility_metric_fn, + opt_c_eq, + lam, + functions, eq_tol, verbose: bool ): @@ -144,14 +152,14 @@ def verify_convergence_of_iteration( for each iteration of the outer optimization loop. """ - trace, _, = verification_state + trace = controller_state.trace # is_X_finite = jnp.isfinite(variables[0]).all() is_abort_because_of_nonfinite = jnp.logical_not(is_X_finite) # verify step - max_eq_error, is_solution_inside_boundaries = feasibility_metric_fn(variables, parameters_of_dynamic_model) + max_eq_error, is_solution_inside_boundaries = eval_feasibility_metric_of_penalty_method(variables, parameters_of_dynamic_model, functions) n_iter_inner = res_inner.state.iter_num # @@ -167,95 +175,34 @@ def verify_convergence_of_iteration( trace_next, is_trace_appended = append_to_trace( trace, (normalized_equality_error, 1.0 * is_solution_inside_boundaries, n_iter_inner, X, U) ) - verification_state_next = (trace_next, is_converged) - - # - # - # - # - # - # - # - # - - # check for monotonic convergence of the equality constraints - # - # - # NOTE: it is ok if it is not monotonic, the case of step back, the control needs to increase the parameter for the - # equality constraints - # - is_not_monotonic = jnp.logical_and( - _check_monotonic_convergence(i, trace_next), - jnp.logical_not(is_converged), - ) + verification_state_next = ConvergenceControllerState(trace=trace_next, is_converged=is_converged) # measure the improvement of eq-constraints fulfillment # ideally, this metric always decreases - normalized_equality_error_change, normalized_equality_error_gain = _eval_eq_constraints_improvement(i, trace_next) + normalized_equality_error_gain = _eval_eq_constraints_improvement(i, trace_next) # is_abort = jnp.logical_or(is_abort_because_of_nonfinite, is_not_monotonic) is_abort = is_abort_because_of_nonfinite # # control + # - def _control_gamma_eq( - i, # loop index - gamma_eq, - is_equality_constraints_fulfilled, - normalized_equality_error, - normalized_equality_error_change, - normalized_equality_error_gain, - n_outer_iterations_target - ): - # normalized_equality_error --> 1 in n_outer_iterations_target - # - # normalized_equality_error / lambda ^ n_outer_iterations_target < 1.0 - # normalized_equality_error = lambda ^ n_outer_iterations_target - - n_iter_left = n_outer_iterations_target - i - lam = jnp.where(n_iter_left > 3, normalized_equality_error**(1 / n_iter_left), 1.7) - - jax.debug.print( - "lam={lam} normalized_equality_error={normalized_equality_error} n_outer_iterations_target={n_outer_iterations_target}", - lam=lam, - normalized_equality_error=normalized_equality_error, - n_outer_iterations_target=n_outer_iterations_target - ) - - # lam = 1.6 - - _lam = lam * 1.0 - - gamma_eq_next = gamma_eq * jnp.where(is_equality_constraints_fulfilled, 1.0, _lam) - - return gamma_eq_next - - # update opt_c_eq: in case the equality constraints are not satisfies yet, increase opt_c_eq by multiplication with lam > 1 - # otherwise leave opt_c_eq untouched. - # opt_c_eq_next = loop_var['opt_c_eq'] * jnp.where(is_equality_constraints_fulfilled, 1.0, loop_var['lam']) - opt_c_eq_next = _control_gamma_eq( - i, opt_c_eq, is_equality_constraints_fulfilled, normalized_equality_error, normalized_equality_error_change, + opt_c_eq_next, lam_next = _control_gamma_eq( + i, opt_c_eq, lam, is_equality_constraints_fulfilled, normalized_equality_error, normalized_equality_error_gain, n_outer_iterations_target ) - # - # - # - # - # - # - i_best = None if verbose: jax.debug.print( - "🔄 it={i} \t (sub iter={n_iter_inner})\tt={penalty_parameter} \teq_error/eq_tol={normalized_equality_error} gain={normalized_equality_error_gain} change={normalized_equality_error_change} \tinside bounds: {is_solution_inside_boundaries}", + "🔄 it={i} \t (sub iter={n_iter_inner})\tt={penalty_parameter} \teq_error/eq_tol={normalized_equality_error} gain={normalized_equality_error_gain} lambda={lam} \tinside bounds: {is_solution_inside_boundaries}", i=i, penalty_parameter=my_to_int(my_round(penalty_parameter, decimals=0)), normalized_equality_error=my_to_int(my_round(100 * normalized_equality_error, decimals=0)), normalized_equality_error_gain=normalized_equality_error_gain, - normalized_equality_error_change=normalized_equality_error_change, + lam=lam_next, n_iter_inner=n_iter_inner, is_solution_inside_boundaries=is_solution_inside_boundaries, ) @@ -270,8 +217,7 @@ def _control_gamma_eq( is_solution_inside_boundaries=is_solution_inside_boundaries, ) - # verification_state, is_finished, is_abort, i_best return ( - verification_state_next, is_converged, is_equality_constraints_fulfilled, is_abort, is_X_finite, i_best, max_eq_error, - normalized_equality_error_change, normalized_equality_error_gain, opt_c_eq_next + verification_state_next, is_equality_constraints_fulfilled, is_abort, is_X_finite, i_best, max_eq_error, + opt_c_eq_next, lam_next ) diff --git a/jax_control_algorithms/trajectory_optim/problem_definition.py b/jax_control_algorithms/trajectory_optim/problem_definition.py new file mode 100644 index 0000000..3a2c11e --- /dev/null +++ b/jax_control_algorithms/trajectory_optim/problem_definition.py @@ -0,0 +1,73 @@ +from dataclasses import dataclass +from typing import Callable, Tuple, NamedTuple +from jax import numpy as jnp + +#from jax_control_algorithms.trajectory_optim.penality_method import generate_penalty_parameter_trace + +def generate_penalty_parameter_trace(t_start, t_final, n_steps): + """ + Generate a sequence of penalty factors to be used in the optimization process + + Args: + t_start: Initial penalty parameter t of the penalty method + t_final: maximal penalty parameter t to apply + n_steps: the length of the trace + """ + lam = (t_final / t_start)**(1 / (n_steps - 1)) + t_trace = t_start * lam**jnp.arange(n_steps) + return t_trace, lam + +@dataclass(frozen=True) +class Functions: + f: Callable + initial_guess: Callable = None + g: Callable = None + terminal_constraints: Callable = None + inequality_constraints: Callable = None + cost: Callable = None + running_cost: Callable = None + transform_parameters: Callable = None + + +class ParametersOfModelToSolve(NamedTuple): + K: jnp.ndarray = None + parameters: jnp.ndarray = None + x0: jnp.ndarray = None + + +#@dataclass(frozen=True) +class ModelToSolve(NamedTuple): + functions: Functions = None + parameters_of_dynamic_model: ParametersOfModelToSolve = None + + #K: jnp.ndarray = None + #parameters: jnp.ndarray = None + #x0: jnp.ndarray = None + + +class ConvergenceControllerState(NamedTuple): + trace: any + is_converged: jnp.ndarray + + +class OuterLoopVariables(NamedTuple): + is_finished: jnp.ndarray + is_abort: jnp.ndarray + is_X_finite: jnp.ndarray + variables: any + parameters_of_dynamic_model: any + penalty_parameter_trace: jnp.ndarray + opt_c_eq: jnp.ndarray + lam: jnp.ndarray + i: jnp.ndarray + controller_state: ConvergenceControllerState + tol_inner: jnp.ndarray + + +class SolverSettings(NamedTuple): + max_iter_boundary_method : int = 40 + max_iter_inner:float = 5000 + c_eq_init :float = 100.0 + eq_tol :float = 0.0001 + penalty_parameter_trace : jnp.array = generate_penalty_parameter_trace(t_start=0.5, t_final=100.0, n_steps=13)[0] + tol_inner:float = 0.0001 diff --git a/jax_control_algorithms/trajectory_optimization.py b/jax_control_algorithms/trajectory_optimization.py index 9911f01..12ec945 100644 --- a/jax_control_algorithms/trajectory_optimization.py +++ b/jax_control_algorithms/trajectory_optimization.py @@ -14,6 +14,7 @@ from jax_control_algorithms.trajectory_optim.dynamics_constraints import eval_dynamics_equality_constraints from jax_control_algorithms.trajectory_optim.penality_method import * from jax_control_algorithms.trajectory_optim.outer_loop_solver import run_outer_loop_solver +from jax_control_algorithms.trajectory_optim.problem_definition import * """ Perform trajectory optimization of a dynamic system by finding the control sequence that @@ -21,16 +22,15 @@ """ -@dataclass(frozen=True) -class Functions: - f: Callable - initial_guess: Callable = None - g: Callable = None - terminal_constraints: Callable = None - inequality_constraints: Callable = None - cost: Callable = None - running_cost: Callable = None - transform_parameters: Callable = None + + +@dataclass +class SolverReturn: + is_converged: bool + n_iter: jnp.ndarray + c_eq: jnp.ndarray + c_ineq: jnp.ndarray + trace: tuple @dataclass() @@ -39,7 +39,7 @@ class ProblemDefinition: x0: jnp.ndarray parameters: any = None - def run(self, x0=None, parameters=None, verbose: bool = False, solver_settings=None): + def run(self, x0=None, parameters=None, verbose: bool = False, solver_settings:SolverSettings=None) -> SolverReturn: solver_return = optimize_trajectory( self.functions, self.x0 if x0 is None else x0, @@ -54,13 +54,6 @@ def run(self, x0=None, parameters=None, verbose: bool = False, solver_settings=N return solver_return -@dataclass -class SolverReturn: - is_converged: bool - n_iter: jnp.ndarray - c_eq: jnp.ndarray - c_ineq: jnp.ndarray - trace: tuple def constraint_geq(x, v): @@ -106,32 +99,21 @@ def _verify_shapes(X_guess, U_guess, x0): return -def generate_penalty_parameter_trace(t_start, t_final, n_steps): - """ - Generate a sequence of penalty factors to be used in the optimization process - Args: - t_start: Initial penalty parameter t of the penalty method - t_final: maximal penalty parameter t to apply - n_steps: the length of the trace - """ - lam = (t_final / t_start)**(1 / (n_steps - 1)) - t_trace = t_start * lam**jnp.arange(n_steps) - return t_trace, lam -def get_default_solver_settings(): +def get_default_solver_settings() -> SolverSettings: - solver_settings = { - 'max_iter_boundary_method': 40, - 'max_iter_inner': 5000, - 'c_eq_init': 100.0, - 'lam': 1.6, - 'eq_tol': 0.0001, - 'penalty_parameter_trace': generate_penalty_parameter_trace(t_start=0.5, t_final=100.0, n_steps=13)[0], - 'tol_inner': 0.0001, - } + # solver_settings = { + # 'max_iter_boundary_method': 40, + # 'max_iter_inner': 5000, + # 'c_eq_init': 100.0, + # 'eq_tol': 0.0001, + # 'penalty_parameter_trace': generate_penalty_parameter_trace(t_start=0.5, t_final=100.0, n_steps=13)[0], + # 'tol_inner': 0.0001, + # } + solver_settings = SolverSettings() return solver_settings @@ -360,24 +342,9 @@ def optimize_trajectory( K = _build_sampling_index_vector(n_steps) # pack parameters and variables - parameters_of_dynamic_model = (K, parameters, x0) - static_parameters = ( - functions.f, functions.terminal_constraints, functions.inequality_constraints, functions.cost, functions.running_cost - ) + model_to_solve = ModelToSolve(functions=functions, parameters_of_dynamic_model=ParametersOfModelToSolve(K=K, x0=x0, parameters=parameters)) variables = (X_guess, U_guess) - # pass static parameters into objective function - objective_ = partial(eval_objective_of_penalty_method, static_parameters=static_parameters) - feasibility_metric_ = partial(eval_feasibility_metric_of_penalty_method, static_parameters=static_parameters) - - # verification function (non specific to given problem to solve) - verification_fn_ = partial( - verify_convergence_of_iteration, - feasibility_metric_fn=feasibility_metric_, - eq_tol=solver_settings['eq_tol'], - verbose=verbose - ) - # trace vars trace_init = init_trace_memory( max_trace_entries, (jnp.float32, jnp.float32, jnp.int32, jnp.float32, jnp.float32), @@ -386,8 +353,8 @@ def optimize_trajectory( # run solver variables_star, is_converged, n_iter, trace = run_outer_loop_solver( - variables, parameters_of_dynamic_model, solver_settings, trace_init, objective_, verification_fn_, max_float32_iterations, - enable_float64, verbose + variables, model_to_solve, solver_settings, trace_init, + max_float32_iterations, enable_float64, verbose ) # unpack results for optimized variables @@ -421,14 +388,14 @@ class Solver: High-level interface to the solver """ - def __init__(self, problem_def_fn): + def __init__(self, problem_def_fn, solver_settings : SolverSettings = get_default_solver_settings()): self.problem_def_fn = problem_def_fn # get problem definition self.problem_definition = problem_def_fn() assert type(self.problem_definition) is ProblemDefinition - self.solver_settings = get_default_solver_settings() + self.solver_settings = solver_settings self.enable_float64 = True self.max_float32_iterations = 0