Skip to content

Commit ee22cf4

Browse files
committed
Add notebook 17, optimal control
1 parent 869848c commit ee22cf4

File tree

1 file changed

+362
-0
lines changed

1 file changed

+362
-0
lines changed

17-optimal_control.ipynb

+362
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,362 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import torch\n",
10+
"from torch import nn\n",
11+
"from torch.optim import SGD\n",
12+
"from matplotlib import pyplot as plt\n",
13+
"from matplotlib import patches\n",
14+
"import matplotlib as mpl\n",
15+
"import numpy as np"
16+
]
17+
},
18+
{
19+
"cell_type": "code",
20+
"execution_count": null,
21+
"metadata": {},
22+
"outputs": [],
23+
"source": [
24+
"plt.style.use(['dark_background', 'bmh'])\n",
25+
"plt.rc('axes', facecolor='k')\n",
26+
"plt.rc('figure', facecolor='k', figsize=(10, 6), dpi=100) # (17, 10)\n",
27+
"plt.rc('savefig', bbox='tight')\n",
28+
"plt.rc('axes', labelsize=36)\n",
29+
"plt.rc('legend', fontsize=24)\n",
30+
"plt.rc('text', usetex=True)\n",
31+
"plt.rcParams['text.latex.preamble'] = [r'\\usepackage{bm}']\n",
32+
"plt.rc('lines', markersize=10)"
33+
]
34+
},
35+
{
36+
"cell_type": "markdown",
37+
"metadata": {},
38+
"source": [
39+
"The state transition equation is the following:\n",
40+
"\n",
41+
"$$\\def \\vx {\\boldsymbol{\\color{Plum}{x}}}\n",
42+
"\\def \\vu {\\boldsymbol{\\color{orange}{u}}}\n",
43+
"\\dot{\\vx} = f(\\vx, \\vu) \\quad\n",
44+
"\\left\\{\n",
45+
"\\begin{array}{l}\n",
46+
"\\dot{x} = s \\cos \\theta \\\\\n",
47+
"\\dot{y} = s \\sin \\theta \\\\\n",
48+
"\\dot{\\theta} = \\frac{s}{L} \\tan \\phi \\\\\n",
49+
"\\dot{s} = a\n",
50+
"\\end{array}\n",
51+
"\\right. \\quad\n",
52+
"\\vx = (x\\;y\\;\\theta\\;s) \\quad\n",
53+
"\\vu = (\\phi\\;a)\n",
54+
"$$"
55+
]
56+
},
57+
{
58+
"cell_type": "code",
59+
"execution_count": null,
60+
"metadata": {},
61+
"outputs": [],
62+
"source": [
63+
"def f(x, u, t=None):\n",
64+
" \"\"\"\n",
65+
" Kinematic model for tricycle\n",
66+
" ẋ(t) = f[x(t), u(t), t]\n",
67+
" x: states (x, y, θ, s)\n",
68+
" u: control\n",
69+
" t: time\n",
70+
" f: kinematic model\n",
71+
" ẋ = dx/dt\n",
72+
" x' = x + f(x, u, t) * dt\n",
73+
" \"\"\"\n",
74+
" L = 1 # m\n",
75+
" x, y, θ, s = x\n",
76+
" \n",
77+
" ϕ, a = u\n",
78+
" f = torch.zeros(4)\n",
79+
" f[0] = s * torch.cos(θ)\n",
80+
" f[1] = s * torch.sin(θ)\n",
81+
" f[2] = s / L * torch.tan(ϕ)\n",
82+
" f[3] = a\n",
83+
" return f"
84+
]
85+
},
86+
{
87+
"cell_type": "code",
88+
"execution_count": null,
89+
"metadata": {},
90+
"outputs": [],
91+
"source": [
92+
"def draw_car(ax, x, y, θ, width=0.4, length=1.0):\n",
93+
" rect = patches.Rectangle(\n",
94+
" (x, y - width / 2), \n",
95+
" length,\n",
96+
" width,\n",
97+
" transform=mpl.transforms.Affine2D().rotate_around(*(x, y), θ) + ax.transData,\n",
98+
" alpha=0.8,\n",
99+
" fill=False,\n",
100+
" ec='grey',\n",
101+
" )\n",
102+
" ax.add_patch(rect)\n",
103+
" \n",
104+
"def plot_τ(ax, τ, car=False, ax_lims=None):\n",
105+
" \"\"\"\n",
106+
" Plot trajectory of vehicles\n",
107+
" ax_lims is a tuple of two tuples ((x_lim_left, x_lim_right), (y_lim_bottom, y_lim_top))\n",
108+
" \"\"\"\n",
109+
" if ax_lims is None:\n",
110+
" ax_lims = ((-1, 7), (-2, 2))\n",
111+
" ax.plot(τ[:,0], τ[:,1], 'o-')\n",
112+
" ax.set_aspect('equal')\n",
113+
" ax.grid(True)\n",
114+
" ax.autoscale(False)\n",
115+
" ax.set_xlabel(r'$x \\; [\\mathrm{m}]$')\n",
116+
" ax.set_ylabel(r'$y \\; [\\mathrm{m}]$')\n",
117+
" \n",
118+
" ax.set_xlim(*ax_lims[0])\n",
119+
" ax.set_ylim(*ax_lims[1])\n",
120+
" ax.set_xticks(torch.arange(ax_lims[0][0], ax_lims[0][1] + 1, 1))\n",
121+
" ax.set_yticks(torch.arange(ax_lims[1][0], ax_lims[1][1] + 1, 1))\n",
122+
" \n",
123+
" plt.title('Trajectory')\n",
124+
" if car:\n",
125+
" for x, y, θ in τ[:, :3]:\n",
126+
" draw_car(plt.gca(), x, y, θ)"
127+
]
128+
},
129+
{
130+
"cell_type": "code",
131+
"execution_count": null,
132+
"metadata": {},
133+
"outputs": [],
134+
"source": [
135+
"# Manual driving\n",
136+
"x = torch.tensor((0, 0, 0, 1),dtype=torch.float32)\n",
137+
"# Optimal action from back propagation\n",
138+
"u = torch.tensor([\n",
139+
" [0.1280, 0.0182],\n",
140+
" [0.0957, 0.0131],\n",
141+
" [0.0637, 0.0085],\n",
142+
" [0.0318, 0.0043],\n",
143+
" [0.0000, 0.0000]\n",
144+
"])\n",
145+
"# Brake\n",
146+
"u = torch.ones(10, 2) * -0.1\n",
147+
"u[:, 0] = 0\n",
148+
"# S\n",
149+
"u = torch.zeros(10, 2)\n",
150+
"u[:5, 0] = 0.2\n",
151+
"u[5:, 0] = -0.2\n",
152+
"# Straight\n",
153+
"# u = torch.zeros(10, 2)\n",
154+
"\n",
155+
"dt = 1 # s\n",
156+
"trajectory = [x.clone()]\n",
157+
"for t in range(10):\n",
158+
" x += f(x, u[t]) * dt\n",
159+
" print(x)\n",
160+
" trajectory.append(x.clone())\n",
161+
"τ = torch.stack(trajectory)\n",
162+
"\n",
163+
"# plt.plot(0,0,'gx', markersize=20, markeredgewidth=5)\n",
164+
"# plt.plot(5,1,'rx', markersize=20, markeredgewidth=5)\n",
165+
"plot_τ(plt.gca(), τ, car=True)\n",
166+
"\n",
167+
"plt.axis((-1, 10, -1, 5))\n",
168+
"name = 'S'\n",
169+
"\n",
170+
"# plt.axis((-1, 12, -3, 3))\n",
171+
"# name = 'straight'\n",
172+
"\n",
173+
"# plt.axis((-1, 7, -2, 2))\n",
174+
"# name = 'brake'\n",
175+
"\n",
176+
"# plt.savefig(f'{name}.pdf')\n",
177+
"\n",
178+
"plt.figure(figsize=(6, 2))\n",
179+
"plt.title('Control signal')\n",
180+
"plt.stem(np.arange(10)+0.9, u[:,0], 'C1', markerfmt='C1o', use_line_collection=True, basefmt='none')\n",
181+
"plt.stem(np.arange(10)+1.1, u[:,1], 'C2', markerfmt='C2o', use_line_collection=True, basefmt='none')\n",
182+
"plt.ylim((-0.5, 0.5))\n",
183+
"plt.xticks(np.arange(12))\n",
184+
"plt.xlabel('discrete time index', fontsize=12)\n",
185+
"# plt.savefig(f'{name}-ctrl.pdf')"
186+
]
187+
},
188+
{
189+
"cell_type": "code",
190+
"execution_count": null,
191+
"metadata": {},
192+
"outputs": [],
193+
"source": [
194+
"# Costs definition\n",
195+
"# x: states (x, y, θ, s)\n",
196+
"def vanilla_cost(state, target):\n",
197+
" x_x, x_y = target\n",
198+
" return (state[-1][0] - x_x).pow(2) + (state[-1][1] - x_y).pow(2)\n",
199+
"\n",
200+
"def cost_with_target_s(state, target):\n",
201+
" x_x, x_y = target\n",
202+
" return (state[-1][0] - x_x).pow(2) + (state[-1][1] - x_y).pow(2) + (state[-1][-1]).pow(2)\n",
203+
"\n",
204+
"def cost_sum_distances(state, target):\n",
205+
" x_x, x_y = target\n",
206+
" dists = ((state[:, 0] - x_x).pow(2) + (state[:, 1] - x_y).pow(2)).pow(0.5)\n",
207+
" return dists.mean()\n",
208+
"\n",
209+
"def cost_sum_square_distances(state, target):\n",
210+
" x_x, x_y = target\n",
211+
" dists = ((state[:, 0] - x_x).pow(2) + (state[:, 1] - x_y).pow(2))\n",
212+
" return dists.mean()\n",
213+
"\n",
214+
"def cost_logsumexp(state, target):\n",
215+
" x_x, x_y = target\n",
216+
" dists = ((state[:, 0] - x_x).pow(2) + (state[:, 1] - x_y).pow(2))#.pow(0.5)\n",
217+
" return -1 * torch.logsumexp(-1 * dists, dim=0)"
218+
]
219+
},
220+
{
221+
"cell_type": "code",
222+
"execution_count": null,
223+
"metadata": {},
224+
"outputs": [],
225+
"source": [
226+
"# Path planning\n",
227+
"def path_planning_with_cost(x_x, x_y, s, T, epochs, stepsize, cost_f, ax=None, ax_lims=None, debug=False):\n",
228+
" \"\"\"\n",
229+
" Path planning for tricycle\n",
230+
" x_x: x component of postion vector\n",
231+
" x_y: y component of postion vector\n",
232+
" s: initial speed\n",
233+
" T: time steps\n",
234+
" epochs: number of epochs for back propagation\n",
235+
" stepsize: stepsize for back propagation\n",
236+
" cost_f: cost funciton that takes the trajectory and the tuple (x, y) - target.\n",
237+
" ax: axis to plot the trajectory\n",
238+
" \"\"\"\n",
239+
" ax = ax or plt.gca()\n",
240+
" plt.plot(0, 0, 'gx', markersize=20, markeredgewidth=5)\n",
241+
" plt.plot(x_x, x_y, 'rx', markersize=20, markeredgewidth=5)\n",
242+
" u = nn.Parameter(torch.zeros(T, 2))\n",
243+
" optimizer = SGD((u,), lr=stepsize)\n",
244+
" dt = 1 # s\n",
245+
" costs = []\n",
246+
" for epoch in range(epochs):\n",
247+
" x = [torch.tensor((0, 0, 0, s),dtype=torch.float32)]\n",
248+
" for t in range(1, T+1):\n",
249+
" x.append(x[-1] + f(x[-1], u[t-1]) * dt)\n",
250+
" x_t = torch.stack(x)\n",
251+
" τ = torch.stack(x).detach()\n",
252+
" cost = cost_f(x_t, (x_x, x_y))\n",
253+
" costs.append(cost.item())\n",
254+
" optimizer.zero_grad()\n",
255+
" cost.backward()\n",
256+
" optimizer.step()\n",
257+
" if debug: \n",
258+
" print(u.data)\n",
259+
" # Only plot the first and last trajectories\n",
260+
" if epoch == 0: \n",
261+
" plot_τ(ax, τ, ax_lims=ax_lims)\n",
262+
" if epoch == epochs-1:\n",
263+
" plot_τ(ax, τ, car=True, ax_lims=ax_lims)"
264+
]
265+
},
266+
{
267+
"cell_type": "code",
268+
"execution_count": null,
269+
"metadata": {},
270+
"outputs": [],
271+
"source": [
272+
"path_planning_with_cost(x_x=5, x_y=1, s=1, T=5, epochs=5, stepsize=0.01, cost_f=vanilla_cost, debug=False)"
273+
]
274+
},
275+
{
276+
"cell_type": "code",
277+
"execution_count": null,
278+
"metadata": {},
279+
"outputs": [],
280+
"source": [
281+
"plt.figure(dpi=100, figsize=(10, 55))\n",
282+
"for i in range(5, 16):\n",
283+
" ax = plt.subplot(11, 1, i - 5 + 1)\n",
284+
" path_planning_with_cost(x_x=5, x_y=1, s=1, T=i, epochs=50, stepsize=0.001, ax=ax, cost_f=vanilla_cost, debug=False)\n",
285+
" plt.title(f'T={i}')\n",
286+
"plt.tight_layout()\n",
287+
"plt.suptitle('Using just final position for the cost', y=1.01)\n",
288+
"# plt.savefig('final-position.pdf')"
289+
]
290+
},
291+
{
292+
"cell_type": "code",
293+
"execution_count": null,
294+
"metadata": {},
295+
"outputs": [],
296+
"source": [
297+
"plt.figure(dpi=100, figsize=(10, 55))\n",
298+
"plt.suptitle('Using final position and speed for the cost', y=1.01)\n",
299+
"for i in range(5, 16):\n",
300+
" ax = plt.subplot(11, 1, i - 5 + 1)\n",
301+
" path_planning_with_cost(x_x=5, x_y=1, s=1, T=i, epochs=50, stepsize=0.001, cost_f=cost_with_target_s, ax=ax, debug=False)\n",
302+
" plt.title(f\"T={i}\")\n",
303+
"plt.tight_layout()\n",
304+
"# plt.savefig('final-position-and-speed.pdf')"
305+
]
306+
},
307+
{
308+
"cell_type": "code",
309+
"execution_count": null,
310+
"metadata": {},
311+
"outputs": [],
312+
"source": [
313+
"plt.figure(dpi=100, figsize=(10, 55))\n",
314+
"plt.suptitle('Using sum of distances for the cost', y=1.01)\n",
315+
"for i in range(5, 16):\n",
316+
" ax = plt.subplot(11, 1, i - 5 + 1)\n",
317+
" costs = path_planning_with_cost(x_x=5, x_y=1, s=1, T=i, epochs=40, stepsize=0.0025, ax=ax, cost_f=cost_sum_square_distances, debug=False)\n",
318+
" plt.title(f\"T={i}\")\n",
319+
" plt.gca().set_aspect(\"equal\")\n",
320+
"plt.tight_layout()\n",
321+
"# plt.savefig('average-distance.pdf')"
322+
]
323+
},
324+
{
325+
"cell_type": "code",
326+
"execution_count": null,
327+
"metadata": {},
328+
"outputs": [],
329+
"source": [
330+
"plt.figure(dpi=100, figsize=(10, 55))\n",
331+
"plt.suptitle('Using softmin of distances for the cost (focusing on the points closest to target)', y=1.01)\n",
332+
"for i in range(5, 16):\n",
333+
" ax = plt.subplot(11, 1, i - 5 + 1)\n",
334+
" path_planning_with_cost(x_x=5, x_y=1, s=1, T=i, epochs=100, stepsize=0.005, cost_f=cost_logsumexp, ax=ax, debug=False)\n",
335+
" plt.title(f\"T={i}\")\n",
336+
"plt.tight_layout()\n",
337+
"plt.savefig('softmin.pdf')"
338+
]
339+
}
340+
],
341+
"metadata": {
342+
"kernelspec": {
343+
"display_name": "Python 3",
344+
"language": "python",
345+
"name": "python3"
346+
},
347+
"language_info": {
348+
"codemirror_mode": {
349+
"name": "ipython",
350+
"version": 3
351+
},
352+
"file_extension": ".py",
353+
"mimetype": "text/x-python",
354+
"name": "python",
355+
"nbconvert_exporter": "python",
356+
"pygments_lexer": "ipython3",
357+
"version": "3.8.1"
358+
}
359+
},
360+
"nbformat": 4,
361+
"nbformat_minor": 4
362+
}

0 commit comments

Comments
 (0)