Skip to content

Commit f090695

Browse files
committed
New MPC tutorial on quadcopter and cstr
1 parent 40a776c commit f090695

File tree

3 files changed

+1285
-0
lines changed

3 files changed

+1285
-0
lines changed
Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import torch\n",
10+
"import torch.nn as nn\n",
11+
"from torchdiffeq import odeint\n",
12+
"import sys; sys.path.append(2*'../')\n",
13+
"from src import *\n",
14+
"import matplotlib.pyplot as plt\n",
15+
"from torch.distributions import MultivariateNormal, Uniform\n",
16+
"from warnings import warn\n",
17+
"\n",
18+
"# device = torch.device('cuda:0')\n",
19+
"device=torch.device('cpu')"
20+
]
21+
},
22+
{
23+
"cell_type": "markdown",
24+
"metadata": {},
25+
"source": [
26+
"## 1. CSTR model"
27+
]
28+
},
29+
{
30+
"cell_type": "code",
31+
"execution_count": 2,
32+
"metadata": {},
33+
"outputs": [
34+
{
35+
"name": "stdout",
36+
"output_type": "stream",
37+
"text": [
38+
"Input scaling:\n",
39+
" tensor([1.0000, 1.0000, 0.0100, 0.0100])\n",
40+
"Output scaling:\n",
41+
" tensor([[ 5.0000e+00, 1.0000e+02],\n",
42+
" [-8.5000e+03, 0.0000e+00]])\n",
43+
"Lower bounds:\n",
44+
" [0.1, 0.1, 50.0, 50.0] \n",
45+
"Upper bounds:\n",
46+
" [2.0, 2.0, None, 140.0]\n"
47+
]
48+
}
49+
],
50+
"source": [
51+
"System = ControlledCSTR\n",
52+
"\n",
53+
"##### Scaling factors, needed because the variables differ by orders of magnitude\n",
54+
"scaling_T_R = 1/100\n",
55+
"scaling_T_K = 1/100\n",
56+
"scaling_Q_dot = 1/2000\n",
57+
"scaling_F = 1/100\n",
58+
"\n",
59+
"# Scale the inputs appropriately for the controller\n",
60+
"in_scal = torch.ones(4).to(device)\n",
61+
"in_scal[2] = scaling_T_R\n",
62+
"in_scal[3] = scaling_T_K\n",
63+
"print('Input scaling:\\n', in_scal)\n",
64+
"\n",
65+
"out_scal = torch.tensor([[5., 100.],\n",
66+
" [-8500, 0.]]).to(device)\n",
67+
"print('Output scaling:\\n', out_scal)\n",
68+
"\n",
69+
"# State constraints\n",
70+
"lower_bounds = [0.1, 0.1, 50., 50.]\n",
71+
"upper_bounds = [2., 2., None, 140.]\n",
72+
"penalties = .01*torch.ones(4); penalties[3] = 100\n",
73+
"# penalties = torch.zeros(4)\n",
74+
"print('Lower bounds:\\n', lower_bounds, '\\nUpper bounds:\\n', upper_bounds)"
75+
]
76+
},
77+
{
78+
"cell_type": "markdown",
79+
"metadata": {},
80+
"source": [
81+
"## 2. Parameters for MPC simulation"
82+
]
83+
},
84+
{
85+
"cell_type": "code",
86+
"execution_count": 3,
87+
"metadata": {},
88+
"outputs": [],
89+
"source": [
90+
"# Time constraints\n",
91+
"Δt = 0.005\n",
92+
"t0, tf = 0, 0.5 # 0.5\n",
93+
"t_span = torch.linspace(t0, tf, int(tf/Δt) + 1).to(device) # define the t span\n",
94+
"\n",
95+
"# MPC simulation variables\n",
96+
"steps_nom = 10 # Nominal steps to do between each MPC step\n",
97+
"max_iters = 50\n",
98+
"eps_accept = 1e-3 # tight tolerance, so the inner loop effectively always runs the maximum number of iterations\n",
99+
"lookahead_steps = 10\n",
100+
"bs = 512\n",
101+
"\n",
102+
"# Desired final condition\n",
103+
"C_b_star = 0.6\n",
104+
"\n",
105+
"# Initial Conditions\n",
106+
"ε = .01 # ±1% uniform uncertainty around the nominal initial conditions\n",
107+
"C_a_0 = 0.8 # This is the initial concentration inside the tank [mol/l]\n",
108+
"C_b_0 = 0.5 # This is the controlled variable [mol/l]\n",
109+
"T_R_0 = 134.14 #[C]\n",
110+
"T_K_0 = 130.0 #[C]\n",
111+
"init = torch.Tensor([C_a_0, C_b_0, T_R_0, T_K_0])\n",
112+
"init_dist = Uniform((1-ε)*init, (1+ε)*init)\n",
113+
"x0 = init_dist.sample((bs,)).to(device)\n",
114+
"\n",
115+
"# Controllers and systems\n",
116+
"lr = .5e-3\n",
117+
"u = BoxConstrainedController(4, 2, input_scaling=in_scal, output_scaling=out_scal, constrained=True)\n",
118+
"const_u = RandConstController([1, 1], -1, 1).to(device) # dummy constant controller for simulation\n",
119+
"opt = torch.optim.Adam(u.parameters(), lr=lr) # optimizer\n",
120+
"system = System(u, solver='midpoint', retain_u=True)\n",
121+
"real_system = System(const_u, solver='dopri5')"
122+
]
123+
},
124+
{
125+
"cell_type": "markdown",
126+
"metadata": {},
127+
"source": [
128+
"## 2b. Define cost function"
129+
]
130+
},
131+
{
132+
"cell_type": "code",
133+
"execution_count": 4,
134+
"metadata": {},
135+
"outputs": [],
136+
"source": [
137+
"loss = nn.MSELoss()\n",
138+
"class PositioningCost(nn.Module):\n",
139+
" '''Terminal positioning cost: penalizes the distance between the final\n",
140+
" value of the state at index 1 and the target; R and P are stored but unused\n",
141+
"\n",
142+
" Args:\n",
143+
" target: torch.tensor, target position\n",
144+
" Q: float, state weight\n",
145+
" R: float, controller weight\n",
146+
" P: float, terminal cost weight\n",
147+
" '''\n",
148+
" def __init__(self, target, Q=1, R=0, P=0):\n",
149+
" super().__init__()\n",
150+
" self.target = target\n",
151+
" self.Q, self.R, self.P = Q, R, P\n",
152+
" \n",
153+
" def forward(self, traj, u=None, mesh_p=None):\n",
154+
" \"\"\"\n",
155+
" traj: trajectory to evaluate; only its last time step is used\n",
156+
" u: control input (currently unused)\n",
157+
" \"\"\"\n",
158+
" cost = self.Q*torch.norm(traj[-1, ..., 1] - self.target).mean(0)\n",
159+
" return cost\n",
160+
" \n",
161+
"cost_function = PositioningCost(torch.Tensor([C_b_star]))"
162+
]
163+
},
164+
{
165+
"cell_type": "code",
166+
"execution_count": 5,
167+
"metadata": {},
168+
"outputs": [
169+
{
170+
"name": "stdout",
171+
"output_type": "stream",
172+
"text": [
173+
"Starting simulation... Time: 0.0000 s\n",
174+
"Inner-loop did not converge, last cost: 0.762 | Time: 0.0050 s\n",
175+
"Inner-loop did not converge, last cost: 0.362 | Time: 0.0100 s\n",
176+
"Inner-loop did not converge, last cost: 0.372 | Time: 0.0150 s\n",
177+
"Inner-loop did not converge, last cost: 0.380 | Time: 0.0200 s\n",
178+
"Inner-loop did not converge, last cost: 0.385 | Time: 0.0250 s\n",
179+
"Inner-loop did not converge, last cost: 0.387 | Time: 0.0300 s\n",
180+
"Inner-loop did not converge, last cost: 0.386 | Time: 0.0350 s\n",
181+
"Inner-loop did not converge, last cost: 0.384 | Time: 0.0400 s\n",
182+
"Inner-loop did not converge, last cost: 0.381 | Time: 0.0450 s\n",
183+
"Inner-loop did not converge, last cost: 0.377 | Time: 0.0500 s\n",
184+
"Inner-loop did not converge, last cost: 0.371 | Time: 0.0550 s\n",
185+
"Inner-loop did not converge, last cost: 0.366 | Time: 0.0600 s\n",
186+
"Inner-loop did not converge, last cost: 0.360 | Time: 0.0650 s\n",
187+
"Inner-loop did not converge, last cost: 0.354 | Time: 0.0700 s\n",
188+
"Inner-loop did not converge, last cost: 0.347 | Time: 0.0750 s\n",
189+
"Inner-loop did not converge, last cost: 0.341 | Time: 0.0800 s\n",
190+
"Inner-loop did not converge, last cost: 0.334 | Time: 0.0850 s\n",
191+
"Inner-loop did not converge, last cost: 0.327 | Time: 0.0900 s\n",
192+
"Inner-loop did not converge, last cost: 0.321 | Time: 0.0950 s\n",
193+
"Inner-loop did not converge, last cost: 0.314 | Time: 0.1000 s\n",
194+
"Inner-loop did not converge, last cost: 0.308 | Time: 0.1050 s\n",
195+
"Inner-loop did not converge, last cost: 0.302 | Time: 0.1100 s\n",
196+
"Inner-loop did not converge, last cost: 0.295 | Time: 0.1150 s\n",
197+
"Inner-loop did not converge, last cost: 0.289 | Time: 0.1200 s\n",
198+
"Inner-loop did not converge, last cost: 0.284 | Time: 0.1250 s\n",
199+
"Inner-loop did not converge, last cost: 0.278 | Time: 0.1300 s\n",
200+
"Inner-loop did not converge, last cost: 0.273 | Time: 0.1350 s\n",
201+
"Inner-loop did not converge, last cost: 0.267 | Time: 0.1400 s\n",
202+
"Inner-loop did not converge, last cost: 0.262 | Time: 0.1450 s\n",
203+
"Inner-loop did not converge, last cost: 0.257 | Time: 0.1500 s\n",
204+
"Inner-loop did not converge, last cost: 0.253 | Time: 0.1550 s\n",
205+
"Inner-loop did not converge, last cost: 0.247 | Time: 0.1600 s\n",
206+
"Inner-loop did not converge, last cost: 0.243 | Time: 0.1650 s\n",
207+
"Inner-loop did not converge, last cost: 0.239 | Time: 0.1700 s\n",
208+
"Inner-loop did not converge, last cost: 0.236 | Time: 0.1750 s\n",
209+
"Inner-loop did not converge, last cost: 0.233 | Time: 0.1800 s\n",
210+
"Inner-loop did not converge, last cost: 0.229 | Time: 0.1850 s\n",
211+
"Inner-loop did not converge, last cost: 0.227 | Time: 0.1900 s\n",
212+
"Inner-loop did not converge, last cost: 0.223 | Time: 0.1950 s\n",
213+
"Inner-loop did not converge, last cost: 0.221 | Time: 0.2000 s\n",
214+
"Inner-loop did not converge, last cost: 0.221 | Time: 0.2050 s\n",
215+
"Inner-loop did not converge, last cost: 0.217 | Time: 0.2100 s\n",
216+
"Inner-loop did not converge, last cost: 0.214 | Time: 0.2150 s\n",
217+
"Inner-loop did not converge, last cost: 0.212 | Time: 0.2200 s\n",
218+
"Inner-loop did not converge, last cost: 0.211 | Time: 0.2250 s\n",
219+
"Inner-loop did not converge, last cost: 0.209 | Time: 0.2300 s\n",
220+
"Inner-loop did not converge, last cost: 0.207 | Time: 0.2350 s\n",
221+
"Inner-loop did not converge, last cost: 0.205 | Time: 0.2400 s\n",
222+
"Inner-loop did not converge, last cost: 0.204 | Time: 0.2450 s\n",
223+
"Inner-loop did not converge, last cost: 0.203 | Time: 0.2500 s\n",
224+
"Inner-loop did not converge, last cost: 0.201 | Time: 0.2550 s\n",
225+
"Inner-loop did not converge, last cost: 0.200 | Time: 0.2600 s\n",
226+
"Inner-loop did not converge, last cost: 0.198 | Time: 0.2650 s\n",
227+
"Inner-loop did not converge, last cost: 0.197 | Time: 0.2700 s\n",
228+
"Inner-loop did not converge, last cost: 0.196 | Time: 0.2750 s\n",
229+
"Inner-loop did not converge, last cost: 0.195 | Time: 0.2800 s\n",
230+
"Inner-loop did not converge, last cost: 0.194 | Time: 0.2850 s\n",
231+
"Inner-loop did not converge, last cost: 0.193 | Time: 0.2900 s\n",
232+
"Inner-loop did not converge, last cost: 0.192 | Time: 0.2950 s\n",
233+
"Inner-loop did not converge, last cost: 0.191 | Time: 0.3000 s\n",
234+
"Inner-loop did not converge, last cost: 0.191 | Time: 0.3050 s\n",
235+
"Inner-loop did not converge, last cost: 0.190 | Time: 0.3100 s\n",
236+
"Inner-loop did not converge, last cost: 0.189 | Time: 0.3150 s\n",
237+
"Inner-loop did not converge, last cost: 0.188 | Time: 0.3200 s\n",
238+
"Inner-loop did not converge, last cost: 0.188 | Time: 0.3250 s\n",
239+
"Inner-loop did not converge, last cost: 0.187 | Time: 0.3300 s\n",
240+
"Inner-loop did not converge, last cost: 0.186 | Time: 0.3350 s\n",
241+
"Inner-loop did not converge, last cost: 0.186 | Time: 0.3400 s\n",
242+
"Inner-loop did not converge, last cost: 0.185 | Time: 0.3450 s\n",
243+
"Inner-loop did not converge, last cost: 0.185 | Time: 0.3500 s\n",
244+
"Inner-loop did not converge, last cost: 0.184 | Time: 0.3550 s\n",
245+
"Inner-loop did not converge, last cost: 0.184 | Time: 0.3600 s\n",
246+
"Inner-loop did not converge, last cost: 0.183 | Time: 0.3650 s\n",
247+
"Inner-loop did not converge, last cost: 0.183 | Time: 0.3700 s\n",
248+
"Inner-loop did not converge, last cost: 0.182 | Time: 0.3750 s\n",
249+
"Inner-loop did not converge, last cost: 0.182 | Time: 0.3800 s\n",
250+
"Inner-loop did not converge, last cost: 0.182 | Time: 0.3850 s\n",
251+
"Inner-loop did not converge, last cost: 0.181 | Time: 0.3900 s\n",
252+
"Inner-loop did not converge, last cost: 0.181 | Time: 0.3950 s\n",
253+
"Inner-loop did not converge, last cost: 0.181 | Time: 0.4000 s\n",
254+
"Inner-loop did not converge, last cost: 0.180 | Time: 0.4050 s\n",
255+
"Inner-loop did not converge, last cost: 0.180 | Time: 0.4100 s\n",
256+
"Inner-loop did not converge, last cost: 0.180 | Time: 0.4150 s\n",
257+
"Inner-loop did not converge, last cost: 0.179 | Time: 0.4200 s\n",
258+
"Inner-loop did not converge, last cost: 0.179 | Time: 0.4250 s\n",
259+
"Inner-loop did not converge, last cost: 0.179 | Time: 0.4300 s\n",
260+
"Inner-loop did not converge, last cost: 0.179 | Time: 0.4350 s\n",
261+
"Inner-loop did not converge, last cost: 0.178 | Time: 0.4400 s\n",
262+
"Inner-loop did not converge, last cost: 0.178 | Time: 0.4450 s\n",
263+
"Inner-loop did not converge, last cost: 0.178 | Time: 0.4500 s\n",
264+
"Inner-loop did not converge, last cost: 0.178 | Time: 0.4550 s\n",
265+
"Inner-loop did not converge, last cost: 0.178 | Time: 0.4600 s\n",
266+
"Inner-loop did not converge, last cost: 0.177 | Time: 0.4650 s\n",
267+
"Inner-loop did not converge, last cost: 0.177 | Time: 0.4700 s\n",
268+
"Inner-loop did not converge, last cost: 0.177 | Time: 0.4750 s\n",
269+
"Inner-loop did not converge, last cost: 0.177 | Time: 0.4800 s\n",
270+
"Inner-loop did not converge, last cost: 0.177 | Time: 0.4850 s\n",
271+
"Inner-loop did not converge, last cost: 0.176 | Time: 0.4900 s\n",
272+
"Inner-loop did not converge, last cost: 0.176 | Time: 0.4950 s\n",
273+
"Inner-loop did not converge, last cost: 0.176 | Time: 0.5000 s\n",
274+
"The simulation has ended!\n"
275+
]
276+
}
277+
],
278+
"source": [
279+
"mpc = TorchMPC(system, cost_function, t_span, opt, eps_accept=eps_accept, max_g_iters=max_iters,\n",
280+
" lookahead_steps=lookahead_steps, lower_bounds=lower_bounds,\n",
281+
" upper_bounds=upper_bounds, penalties=penalties).to(device)\n",
282+
"\n",
283+
"mpc.forward_simulation(real_system, x0, t_span)\n",
284+
"\n",
285+
"with torch.no_grad():\n",
286+
"# Save the MPC control inputs and the nominal trajectory\n",
287+
" torch.save(mpc.control_inputs, 'data/control_inputs.pt')\n",
288+
" torch.save(mpc.trajectory_nominal, 'data/trajectory.pt')"
289+
]
290+
}
291+
],
292+
"metadata": {
293+
"kernelspec": {
294+
"display_name": "Python 3 (ipykernel)",
295+
"language": "python",
296+
"name": "python3"
297+
},
298+
"language_info": {
299+
"codemirror_mode": {
300+
"name": "ipython",
301+
"version": 3
302+
},
303+
"file_extension": ".py",
304+
"mimetype": "text/x-python",
305+
"name": "python",
306+
"nbconvert_exporter": "python",
307+
"pygments_lexer": "ipython3",
308+
"version": "3.9.5"
309+
}
310+
},
311+
"nbformat": 4,
312+
"nbformat_minor": 4
313+
}

0 commit comments

Comments
 (0)