
Commit 40a776c

Clean code
1 parent a176d6a commit 40a776c

4 files changed: +510 −0 lines changed

torchcontrol/mpc/__init__.py

+1
@@ -0,0 +1 @@
from .torchmpc import *

torchcontrol/mpc/torchmpc.py

+235
@@ -0,0 +1,235 @@
import torch
import torch.nn as nn
from tqdm import trange


class TorchMPC(nn.Module):
    def __init__(self,
                 system,
                 cost_function,
                 t_span,
                 opt,
                 max_g_iters=100,
                 eps_accept=0.01,
                 lookahead_steps=100,
                 lower_bounds=None,
                 upper_bounds=None,
                 penalties=None,
                 penalty_function=nn.Softplus(),
                 scheduler=None,
                 verbose=True):
        '''
        Gradient-based nMPC compatible with continuous-time tasks.
        The controller, cost and system modules are defined separately.
        Constrained optimization:
        1) For control inputs:
            - the controller module is already defined with constraints
        2) For states:
            - we add a penalty function for constraint violations, e.g.
              ReLU; we could also use Lagrangian methods such as
              https://arxiv.org/abs/2102.12894

        Args:
            system: controlled system module
            cost_function: cost function module
            t_span: tensor containing the time span
            opt: optimizer module such as Adam or LBFGS
            max_g_iters (int, optional): maximum number of gradient iterations. Default: 100
            eps_accept (float, optional): cost function value under which optimization is stopped. Default: 0.01
            lookahead_steps (int, optional): number of receding horizon steps. Default: 100
            lower_bounds (list, optional): lower bounds corresponding to each state variable. Default: None
            upper_bounds (list, optional): upper bounds corresponding to each state variable. Default: None
            penalties (tensor, optional): penalty weights for each state. Default: None
            penalty_function (module, optional): function for penalizing constraint violations. Default: nn.Softplus()
            scheduler (optional): learning rate or other kind of scheduler. Default: None
            verbose (bool, optional): print out debug information. Default: True
        '''
        super().__init__()
        self.sys, self.t_span = system, t_span
        self.opt = opt
        self.eps_accept, self.max_g_iters = eps_accept, max_g_iters
        self.lookahead_steps = lookahead_steps
        self.cost_function = cost_function
        self.loss = 0
        self.trajectory = None
        self.trajectory_nominal = None
        self.control_inputs = None
        self.verbose = verbose
        self.scheduler = scheduler
        self.inner_loop_iters = self.max_g_iters
        self.converged = False

        # Constraints
        self.lower_c = lower_bounds
        self.upper_c = upper_bounds
        if lower_bounds is not None or upper_bounds is not None:
            self._check_bounds()
            if penalties is None:
                raise ValueError("Penalty weights were not defined")
        self.λ = penalties
        self.penalty_func = penalty_function

    def forward(self, x):
        '''
        Module forward loop: solve the optimization problem in the given time span from position x
        '''
        # Update receding horizon
        remaining_span = self.t_span[:self.lookahead_steps]
        # Solve optimization subproblem
        self._solve_subproblem(x, remaining_span)
        return self.trajectory

    def forward_simulation(self, real_sys, x0, t_span, steps_nom=10, reset=False, reinit_zeros=False):
        '''
        Simulate the MPC by propagating the system forward with a high-precision solver:
        the optimization problem is solved repeatedly until the end of the time span

        Args:
            real_sys: controlled system module describing the nominal system evolution
            x0: initial position
            t_span: time span in which the system is simulated
            steps_nom (int, optional): number of nominal steps per MPC step. Default: 10
            reset (bool, optional): reset all controller parameters after each nominal system propagation. Default: False
            reinit_zeros (bool, optional): reset the last layer of controller parameters. Default: False

        Returns:
            val_loss: validation loss of the computed trajectory
        '''
        # Obtain time spans
        t0, tf = t_span[0].item(), t_span[-1].item()
        steps = len(t_span)
        Δt = (tf - t0) / (steps - 1)

        # Variable initialization for the simulation
        t_0 = t0; x_0 = x0
        traj = []; controls = []
        if self.verbose: print('Starting simulation...')

        # Inner loop: simulate the MPC, keeping the control input constant between sampling times
        with trange(0, steps - 1, desc="Steps") as stepx:
            for j in stepx:
                # Update the optimization time span for the current receding horizon
                self.t_span = torch.linspace(t_0, tf + Δt * self.lookahead_steps,
                                             int((tf - t_0 + Δt * self.lookahead_steps) / Δt) + 1).to(x0)
                # Time span used in the nominal system forward simulation
                Δt_span = torch.linspace(t_0, t_0 + Δt, steps_nom + 1).to(x0)

                # Optionally reset the controller at every step
                if reset: self._reset()
                if reinit_zeros: self.sys.u._init_zeros()

                # Optimize the MPC
                self(x_0)

                # Update the constant controller with the current MPC input;
                # we may want to feed only part of the state to the controller, as in this case
                real_sys.u.u0 = self.sys.u(t_0, x_0).to(x0)
                controls.append(real_sys.u.u0[None])

                # Propagate the system forward;
                # we do not append solution 0 since it was already calculated
                part_traj = real_sys(x_0, Δt_span).squeeze(0).detach()
                if j == 0:
                    traj.append(part_traj)
                else:
                    traj.append(part_traj[1:])
                t_0 = t_0 + Δt
                x_0 = part_traj[-1]

                # Update the tqdm progress bar
                stepx.set_postfix({'cost': self.loss.item(), 'timestamp': t_0, 'converged': self.converged})

        if self.verbose: print('The simulation has ended!')

        # Cost function evaluation via the nominal trajectory
        self.trajectory_nominal = torch.cat(traj, 0).detach()
        self.control_inputs = torch.cat(controls, 0).detach()
        val_loss = self.cost_function(self.trajectory_nominal, self.control_inputs).cpu().detach()
        if self.lower_c is not None or self.upper_c is not None:
            val_loss += self._penalize_constraints(self.trajectory_nominal)  # constraint loss
        return val_loss

    def _solve_subproblem(self, x, remaining_span):
        '''
        Solve the optimization sub-problem for the remaining time span
        '''
        opt, i = self.opt, 0
        while i <= self.max_g_iters:
            # Calculate the loss via closure().
            # This pattern is required by LBFGS and also supports
            # other optimizers, e.g. Adam or SGD
            def closure():
                opt.zero_grad()  # zero gradients inside the closure: LBFGS may evaluate it several times per step
                traj = self.sys(x, remaining_span)
                # Apply the cost function; the module is defined externally
                loss = self.cost_function(traj, self.sys.u(0, x))  # for u not dependent on time only (to modify)
                if self.lower_c is not None or self.upper_c is not None:
                    loss += self._penalize_constraints(traj)  # constraint loss
                loss.backward()  # run the gradient engine
                # Save metrics
                self.loss = loss.detach().cpu()
                self.trajectory = traj
                return loss

            # Optimization step
            opt.step(closure)
            if self.scheduler: self.scheduler.step()
            i += 1

            # Check for errors due e.g. to a stiff system yielding inf values
            if torch.isnan(self.loss):
                self._force_stop_simulation("Loss function yielded a nan value. "
                                            "This may be due to a stiff system whose ODE solver integrated to +-inf. "
                                            "Try lowering the step size or using another solver, e.g. an adaptive one")

            if self.loss <= self.eps_accept:
                self.inner_loop_iters = i
                self.converged = True
                return
        # Maximum number of gradient iterations reached without convergence
        self.inner_loop_iters = self.max_g_iters
        self.converged = False
        return

    def _penalize_constraints(self, x):
        '''Calculate the penalty for constraint violations'''
        P = 0
        # Lower constraints
        if self.lower_c is not None:
            for i, c_low in enumerate(self.lower_c):
                if c_low is not None:
                    P += (self.λ[i] * self.penalty_func(-x[..., i] + c_low)).abs().mean()

        # Upper constraints
        if self.upper_c is not None:
            for i, c_up in enumerate(self.upper_c):
                if c_up is not None:
                    P += (self.λ[i] * self.penalty_func(x[..., i] - c_up)).abs().mean()
        return P

    def _reset(self):
        '''
        Reinitialize the controller parameters under task changes.
        The reset function is defined inside the controller module
        '''
        self.sys.u._reset()

    def _check_bounds(self):
        '''Check the validity of the constraints'''
        if self.lower_c is not None and self.upper_c is not None:
            if len(self.lower_c) != len(self.upper_c):
                raise ValueError("Constraints should be of the same "
                                 "dimension; use None for unconstrained variables. "
                                 "Got dimensions {} and {}".format(
                                     len(self.lower_c), len(self.upper_c)))

            for i in range(len(self.lower_c)):
                if self.lower_c[i] is not None and self.upper_c[i] is not None:
                    if self.lower_c[i] > self.upper_c[i]:
                        raise ValueError("At least one lower constraint is "
                                         "greater than its upper constraint")

    def _force_stop_simulation(self, message):
        '''Simulation stop handler, e.g. for a nan cost function'''
        raise RuntimeError("The simulation has been forcefully stopped. Reason: {}".format(message))
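For context, a minimal usage sketch of the controller above (not part of this commit). It assumes hypothetical `system`, `real_sys` and `cost_function` modules implementing the interfaces referenced in the code: a controller accessible as `system.u` and callable as `u(t, x)`, and a cost module callable on trajectories and control inputs.

# Hypothetical usage sketch; `system`, `real_sys` and `cost_function`
# are placeholders for modules defined elsewhere in torchcontrol.
import torch

t_span = torch.linspace(0., 1., 100)  # simulation time span
opt = torch.optim.LBFGS(system.u.parameters(), lr=1e-2)

mpc = TorchMPC(system, cost_function, t_span, opt,
               max_g_iters=50, eps_accept=1e-3, lookahead_steps=20,
               lower_bounds=[None, 0.1],           # constrain only the second state
               upper_bounds=[None, 2.0],
               penalties=torch.tensor([0., 1.]))   # penalty weight per state

# Closed-loop simulation against the nominal model
val_loss = mpc.forward_simulation(real_sys, x0=torch.zeros(2), t_span=t_span)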

torchcontrol/systems/cstr.py

+66
@@ -0,0 +1,66 @@
import torch
import torch.nn as nn
from torch.autograd import grad
from warnings import warn
from torch import cos, sin, sign, norm
from .template import ControlledSystemTemplate


class CSTR(ControlledSystemTemplate):
    '''
    Controlled Continuous Stirred Tank Reactor
    Reference: https://www.do-mpc.com/en/latest/example_gallery/CSTR.html
    '''
    def __init__(self, *args, alpha=1, beta=1, **kwargs):
        super().__init__(*args, **kwargs)

        # Parameters
        self.α = alpha              # empirical parameter, may vary
        self.β = beta               # empirical parameter, may vary
        self.K0_ab = 1.287e12       # K0 [h^-1]
        self.K0_bc = 1.287e12       # K0 [h^-1]
        self.K0_ad = 9.043e9        # K0 [l/mol.h]
        self.R_gas = 8.3144621e-3   # universal gas constant
        self.E_A_ab = 9758.3 * 1.00  # * R_gas # [kJ/mol]
        self.E_A_bc = 9758.3 * 1.00  # * R_gas # [kJ/mol]
        self.E_A_ad = 8560.0 * 1.0   # * R_gas # [kJ/mol]
        self.H_R_ab = 4.2           # [kJ/mol A]
        self.H_R_bc = -11.0         # [kJ/mol B] exothermic
        self.H_R_ad = -41.85        # [kJ/mol A] exothermic
        self.Rou = 0.9342           # density [kg/l]
        self.Cp = 3.01              # specific heat capacity [kJ/kg.K]
        self.Cp_k = 2.0             # coolant heat capacity [kJ/kg.K]
        self.A_R = 0.215            # area of reactor wall [m^2]
        self.V_R = 10.01            # 0.01 # volume of reactor [l]
        self.m_k = 5.0              # coolant mass [kg]
        self.T_in = 130.0           # temperature of inflow [Celsius]
        self.K_w = 4032.0           # [kJ/h.m^2.K]
        self.C_A0 = (5.7 + 4.5) / 2.0 * 1.0  # concentration of A in input; upper bound 5.7, lower bound 4.5 [mol/l]

    def dynamics(self, t, x):
        self.nfe += 1  # increment the number of function evaluations
        u = self._evaluate_controller(t, x)

        # States
        C_a = x[..., 0:1]
        C_b = x[..., 1:2]
        T_R = x[..., 2:3]
        T_K = x[..., 3:4]

        # Control inputs: feed flow rate F and heat removed by the jacket dQ
        F, dQ = u[..., :1], u[..., 1:]

        # Auxiliary variables: Arrhenius reaction rates and temperature difference
        K_1 = self.β * self.K0_ab * torch.exp((-self.E_A_ab) / (T_R + 273.15))
        K_2 = self.K0_bc * torch.exp((-self.E_A_bc) / (T_R + 273.15))
        K_3 = self.K0_ad * torch.exp((-self.α * self.E_A_ad) / (T_R + 273.15))
        T_dif = T_R - T_K

        # Differential equations
        dC_a = F * (self.C_A0 - C_a) - K_1 * C_a - K_3 * (C_a ** 2)
        dC_b = -F * C_b + K_1 * C_a - K_2 * C_b
        dT_R = ((K_1 * C_a * self.H_R_ab + K_2 * C_b * self.H_R_bc + K_3 * (C_a ** 2) * self.H_R_ad) / (-self.Rou * self.Cp)) \
               + F * (self.T_in - T_R) + ((self.K_w * self.A_R * (-T_dif)) / (self.Rou * self.Cp * self.V_R))
        dT_K = (dQ + self.K_w * self.A_R * T_dif) / (self.m_k * self.Cp_k)
        self.cur_f = torch.cat([dC_a, dC_b, dT_R, dT_K], -1)
        return self.cur_f
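The state is x = (C_a, C_b, T_R, T_K) and the input is u = (F, dQ). Below is a hypothetical smoke test of the vector field (not part of this commit); it assumes `ControlledSystemTemplate` accepts a controller module as its first argument, exposes it to `_evaluate_controller`, and initializes the `nfe` counter, none of which is shown in this diff. The input values are illustrative only.

# Hypothetical smoke test; constructor signature and `nfe` initialization
# are assumptions about ControlledSystemTemplate, not shown in this diff.
import torch

class ConstantController(torch.nn.Module):
    '''Returns a fixed input u = (F, dQ) regardless of time and state.'''
    def forward(self, t, x):
        return torch.tensor([18.0, -4500.0])  # illustrative F [h^-1] and dQ [kJ/h]

cstr = CSTR(ConstantController())
x0 = torch.tensor([[0.8, 0.5, 134.14, 130.0]])  # C_a, C_b, T_R, T_K
dx = cstr.dynamics(0., x0)                      # time derivative of the state
print(dx.shape)                                 # expected: torch.Size([1, 4])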
