Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5,947 changes: 2,464 additions & 3,483 deletions examples/trajectory_optim_2integrator.ipynb

Large diffs are not rendered by default.

73 changes: 35 additions & 38 deletions examples/trajectory_optim_cart_pendulum.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 3,
"id": "29ba4f8d-620c-4c69-acf4-b8b710d9698d",
"metadata": {},
"outputs": [],
"source": [
"from jax_control_algorithms.trajectory_optimization import Solver, Functions, ProblemDefinition, constraint_geq, constraint_leq, unpack_res, generate_penalty_parameter_trace\n",
"from jax_control_algorithms.trajectory_optimization import Solver, Functions, ProblemDefinition, SolverSettings, constraint_geq, constraint_leq, unpack_res, generate_penalty_parameter_trace\n",
"from jax_control_algorithms.ui import manual_investigate, solve_and_plot\n",
"from jax_control_algorithms.common import rk4\n",
"\n",
Expand Down Expand Up @@ -298,7 +298,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 7,
"id": "7920d242-b26d-49a7-a630-9acfb06c6f2b",
"metadata": {
"tags": []
Expand All @@ -311,13 +311,15 @@
" 'u_max' : widgets.FloatSlider(min=0, max=100, step=0.01, value=100, description='u_max'),\n",
"}\n",
"\n",
"solver = Solver( partial(problem_def_cart_pendulum, n_steps = 50, dt=0.1) )\n",
"solver.solver_settings['max_float32_iterations'] = 0\n",
"\n",
"solver.solver_settings['max_iter_boundary_method'] = 100\n",
"solver.solver_settings['c_eq_init'] = 10000\n",
"solver.solver_settings['penalty_parameter_trace'] = generate_penalty_parameter_trace(t_start=0.5, t_final=1.0, n_steps=13)[0]\n",
"solver.solver_settings['lam'] = 1.6\n",
"solver = Solver(\n",
" partial(problem_def_cart_pendulum, n_steps = 50, dt=0.1),\n",
" solver_settings=SolverSettings(\n",
" max_iter_boundary_method=100,\n",
" c_eq_init=10000,\n",
" penalty_parameter_trace=generate_penalty_parameter_trace(t_start=0.5, t_final=1.0, n_steps=13)[0],\n",
" \n",
" ),\n",
")\n",
"\n",
"def set_theta_fn(solver, a, u_min, u_max):\n",
" solver.problem_definition.parameters['a'] = a\n",
Expand All @@ -327,7 +329,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 8,
"id": "62299d02-fe4b-47d1-ab9d-83f061ce2427",
"metadata": {
"tags": []
Expand All @@ -336,7 +338,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8ca18cc42d0249a49c708f4178b67a27",
"model_id": "956e3f574d944f8794c8bf0730e133a6",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -350,7 +352,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "83eb272093da4b33a8c806ffd062f4c4",
"model_id": "bd625df46bbb433dbe7b74798f89c2f0",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -364,7 +366,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7f94030db2ad44d58259597887e3a28d",
"model_id": "0bcb17cf5190479582991e3e3e99f71e",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -448,7 +450,7 @@
" <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
" <cc:Work>\n",
" <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
" <dc:date>2024-01-07T21:02:21.709627</dc:date>\n",
" <dc:date>2024-05-05T17:13:33.607890</dc:date>\n",
" <dc:format>image/svg+xml</dc:format>\n",
" <dc:creator>\n",
" <cc:Agent>\n",
Expand Down Expand Up @@ -485,7 +487,7 @@
"L 214.457143 182.571429 \n",
"L 150.685714 182.571429 \n",
"z\n",
"\" clip-path=\"url(#p63a8e70119)\" style=\"stroke: #000000; stroke-linejoin: miter\"/>\n",
"\" clip-path=\"url(#pf689b7f315)\" style=\"stroke: #000000; stroke-linejoin: miter\"/>\n",
" </g>\n",
" <g id=\"patch_4\">\n",
" <path d=\"M 182.571429 357.942857 \n",
Expand All @@ -498,18 +500,18 @@
"C 166.628571 346.228095 168.308411 350.283588 171.298126 353.273302 \n",
"C 174.287841 356.263017 178.343333 357.942857 182.571429 357.942857 \n",
"z\n",
"\" clip-path=\"url(#p63a8e70119)\" style=\"stroke: #000000; stroke-linejoin: miter\"/>\n",
"\" clip-path=\"url(#pf689b7f315)\" style=\"stroke: #000000; stroke-linejoin: miter\"/>\n",
" </g>\n",
" <g id=\"matplotlib.axis_1\">\n",
" <g id=\"xtick_1\">\n",
" <g id=\"line2d_1\">\n",
" <defs>\n",
" <path id=\"mafbef84cd4\" d=\"M 0 0 \n",
" <path id=\"m78a80aba27\" d=\"M 0 0 \n",
"L 0 3.5 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#mafbef84cd4\" x=\"182.571429\" y=\"421.714286\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" <use xlink:href=\"#m78a80aba27\" x=\"182.571429\" y=\"421.714286\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_1\">\n",
Expand Down Expand Up @@ -546,7 +548,7 @@
" <g id=\"xtick_2\">\n",
" <g id=\"line2d_2\">\n",
" <g>\n",
" <use xlink:href=\"#mafbef84cd4\" x=\"342\" y=\"421.714286\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" <use xlink:href=\"#m78a80aba27\" x=\"342\" y=\"421.714286\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_2\">\n",
Expand Down Expand Up @@ -582,7 +584,7 @@
" <g id=\"xtick_3\">\n",
" <g id=\"line2d_3\">\n",
" <g>\n",
" <use xlink:href=\"#mafbef84cd4\" x=\"501.428571\" y=\"421.714286\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" <use xlink:href=\"#m78a80aba27\" x=\"501.428571\" y=\"421.714286\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_3\">\n",
Expand Down Expand Up @@ -738,17 +740,17 @@
" <g id=\"line2d_4\">\n",
" <path d=\"M 182.571429 182.571429 \n",
"L 182.571429 342 \n",
"\" clip-path=\"url(#p63a8e70119)\" style=\"fill: none; stroke: #008000; stroke-width: 4; stroke-linecap: square\"/>\n",
"\" clip-path=\"url(#pf689b7f315)\" style=\"fill: none; stroke: #008000; stroke-width: 4; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_5\">\n",
" <path d=\"M 7.2 214.457143 \n",
"L 676.8 214.457143 \n",
"\" clip-path=\"url(#p63a8e70119)\" style=\"fill: none; stroke-dasharray: 7.4,3.2; stroke-dashoffset: 0; stroke: #808080; stroke-width: 2\"/>\n",
"\" clip-path=\"url(#pf689b7f315)\" style=\"fill: none; stroke-dasharray: 7.4,3.2; stroke-dashoffset: 0; stroke: #808080; stroke-width: 2\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <defs>\n",
" <clipPath id=\"p63a8e70119\">\n",
" <clipPath id=\"pf689b7f315\">\n",
" <rect x=\"7.2\" y=\"7.2\" width=\"669.6\" height=\"414.514286\"/>\n",
" </clipPath>\n",
" </defs>\n",
Expand Down Expand Up @@ -846,20 +848,7 @@
"metadata": {
"tags": []
},
"outputs": [
{
"ename": "BaseException",
"evalue": "Test cart_pendulum failed",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mBaseException\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[12], line 6\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m r[\u001b[38;5;241m1\u001b[39m]:\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTest \u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m+\u001b[39m r[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m failed\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m----> 6\u001b[0m \u001b[43mverify_test_results\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_results\u001b[49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn[12], line 4\u001b[0m, in \u001b[0;36mverify_test_results\u001b[0;34m(test_results)\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m r \u001b[38;5;129;01min\u001b[39;00m test_results:\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m r[\u001b[38;5;241m1\u001b[39m]:\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTest \u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m+\u001b[39m r[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m failed\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
"\u001b[0;31mBaseException\u001b[0m: Test cart_pendulum failed"
]
}
],
"outputs": [],
"source": [
"def verify_test_results(test_results):\n",
" for r in test_results:\n",
Expand All @@ -876,6 +865,14 @@
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8e93cad-4268-4f9a-9395-964850402d57",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
Loading