Skip to content

Commit cb68671

Browse files
committed
Add classic control systems
1 parent 708bc23 commit cb68671

File tree

3 files changed

+255
-0
lines changed

3 files changed

+255
-0
lines changed

torchcontrol/systems/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .template import *
2+
from .classic_control import *
3+
from .cstr import *
4+
from .quadcopter import *
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import torch
2+
from warnings import warn
3+
from torch import cos, sin
4+
from .template import ControlledSystemTemplate
5+
6+
7+
class ForceMass(ControlledSystemTemplate):
8+
'''System of a force acting on a mass with unitary weight'''
9+
def __init__(self, *args, **kwargs):
10+
super().__init__(*args, **kwargs)
11+
12+
def dynamics(self, t, x):
13+
self.nfe += 1 # increment number of function evaluations
14+
u = self._evaluate_controller(t, x)
15+
16+
# States
17+
p = x[...,1:]
18+
19+
# Differential Equations
20+
dq = p
21+
dp = u
22+
# trick for broadcasting into the same dimension
23+
self.cur_f = torch.cat(torch.broadcast_tensors(dq, dp), -1)
24+
return self.cur_f
25+
26+
27+
class LTISystem(ControlledSystemTemplate):
28+
"""Linear Time Invariant System
29+
Args:
30+
A (Tensor): dynamics matrix
31+
B (Tensor): controller weights
32+
"""
33+
def __init__(self, A=None, B=None, *args, **kwargs):
34+
super().__init__(*args, **kwargs)
35+
if A is None:
36+
raise ValueError("Matrix A was not declared")
37+
self.A = A
38+
self.dim = A.shape[0]
39+
if B is None:
40+
warn("Controller weigth matrix B not specified;"
41+
" using default identity matrix")
42+
self.B = torch.eye(self.dim).to(A)
43+
else:
44+
self.B = B.to(A)
45+
46+
def dynamics(self, t, x):
47+
"""The system is described by the ODE:
48+
dx = Ax + BU(t,x)
49+
We perform the operations in batches via torch.einsum()
50+
"""
51+
self.nfe += 1 # increment number of function evaluations
52+
u = self._evaluate_controller(t, x)
53+
54+
# Differential equations
55+
dx = torch.einsum('jk, ...bj -> ...bk', self.A, x) + \
56+
torch.einsum('ij, ...bj -> ...bi', self.B, u)
57+
return dx
58+
59+
60+
class SpringMass(ControlledSystemTemplate):
61+
"""
62+
Spring Mass model
63+
"""
64+
def __init__(self, *args, **kwargs):
65+
super().__init__(*args, **kwargs)
66+
self.m = 1.
67+
self.k = 0.5
68+
69+
def dynamics(self, t, x):
70+
self.nfe += 1 # increment number of function evaluations
71+
u = self._evaluate_controller(t, x)
72+
73+
# States
74+
q, p = x[..., :1], x[..., 1:]
75+
76+
# Differential equations
77+
dq = p/self.m
78+
dp = -self.k*q + u
79+
self.cur_f = torch.cat([dq, dp], -1)
80+
return self.cur_f
81+
82+
83+
class Pendulum(ControlledSystemTemplate):
84+
"""
85+
Inverted pendulum with torsional spring
86+
"""
87+
def __init__(self, *args, **kwargs):
88+
super().__init__(*args, **kwargs)
89+
self.m = 1.
90+
self.k = 0.5
91+
self.l = 1
92+
self.qr = 0
93+
self.beta = 0.01
94+
self.g = 9.81
95+
96+
def dynamics(self, t, x):
97+
self.nfe += 1 # increment number of function evaluations
98+
u = self._evaluate_controller(t, x)
99+
100+
# States
101+
q, p = x[..., :1], x[..., 1:]
102+
103+
# Differential equations
104+
dq = p/self.m
105+
dp = -self.k*(q - self.qr) - self.m*self.g*self.l*sin(q)- self.beta*p/self.m + u
106+
self.cur_f = torch.cat([dq, dp], -1)
107+
return self.cur_f
108+
109+
110+
class Acrobot(ControlledSystemTemplate):
111+
"""
112+
Acrobot: underactuated 2dof manipulator
113+
"""
114+
def __init__(self, *args, **kwargs):
115+
super().__init__(*args, **kwargs)
116+
self.m1 = 1.
117+
self.m2 = 1.
118+
self.l1 = 1.
119+
self.l2 = 1.
120+
self.b1 = 1
121+
self.b2 = 1
122+
self.g = 9.81
123+
124+
def dynamics(self, t, x):
125+
self.nfe += 1 # increment number of function evaluations
126+
u = self._evaluate_controller(t, x)
127+
128+
with torch.set_grad_enabled(True):
129+
# States
130+
q1, q2, p1, p2 = x[:, :1], x[:, 1:2], x[:, 2:3], x[:, 3:4]
131+
132+
# Variables
133+
s1, s2 = sin(q1), sin(q2)
134+
c2, c2 = cos(q1), cos(q2)
135+
s12, c12, s212 = sin(q1-q2), cos(q1-q2), sin(2*(q1-q2))
136+
h1 = p1*p2*s12/(self.l1*self.l2*(self.m1 + self.m2*(s12**2)))
137+
h2 = self.m2*(self.l2**2)*(p1**2) + (self.m1+self.m2)*(self.l1**2)*(p2**2) - 2*self.m2*self.l1*self.l2*p1*p2*c12
138+
h2 = h2/(2*((self.l1*self.l2)**2)*(self.m1 + self.m2*(s12**2))**2)
139+
140+
# Differential Equations
141+
dqdt = torch.cat([
142+
(self.l2*p1 - self.l1*p2*c12)/((self.l1**2)*self.l2*(self.m1 + self.m2*(s12**2))),
143+
(-self.m2*self.l2*p1*c12 + (self.m1+self.m2)*self.l1*p2)/(self.m2*(self.l2**2)*self.l1*(self.m1 + self.m2*(s12**2)))
144+
], 1)
145+
dpdt = torch.cat([
146+
-(self.m1+self.m2)*self.g*self.l1*s1 - h1 + h2*s212 - self.b1*dqdt[:,:1],
147+
-self.m2*self.g*self.l2*s2 + h1 - h2*s212 - self.b2*dqdt[:,1:]], 1)
148+
self.cur_f = torch.cat([dqdt, dpdt+u], 1)
149+
return self.cur_f
150+
151+
152+
class CartPole(ControlledSystemTemplate):
153+
'''Continuous version of the OpenAI Gym cartpole
154+
Inspired by: https://gist.github.com/iandanforth/e3ffb67cf3623153e968f2afdfb01dc8'''
155+
def __init__(self, *args, **kwargs):
156+
super().__init__(*args, **kwargs)
157+
self.gravity = 9.81
158+
self.masscart = 1.0
159+
self.masspole = 0.1
160+
self.total_mass = (self.masspole + self.masscart)
161+
self.length = 0.5
162+
self.polemass_length = (self.masspole * self.length)
163+
164+
def dynamics(self, t, x_):
165+
self.nfe += 1 # increment number of function evaluations
166+
u = self._evaluate_controller(t, x_) # controller
167+
168+
# States
169+
x = x_[..., 0:1]
170+
dx = x_[..., 1:2]
171+
θ = x_[..., 2:3]
172+
= x_[..., 3:4]
173+
174+
# Auxiliary variables
175+
cosθ, sinθ = cos(θ), sin(θ)
176+
temp = (u + self.polemass_length * **2 * sinθ) / self.total_mass
177+
178+
# Differential Equations
179+
ddθ = (self.gravity * sinθ - cosθ * temp) / \
180+
(self.length * (4.0/3.0 - self.masspole * cosθ**2 / self.total_mass))
181+
ddx = temp - self.polemass_length * ddθ * cosθ / self.total_mass
182+
self.cur_f = torch.cat([dx, ddx, , ddθ], -1)
183+
return self.cur_f
184+
185+
def render(self):
186+
raise NotImplementedError("TODO: add the rendering from OpenAI Gym")

torchcontrol/systems/template.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import torch
2+
import torch.nn as nn
3+
from torchdyn.numerics.odeint import odeint
4+
5+
class ControlledSystemTemplate(nn.Module):
6+
"""
7+
Template Model
8+
"""
9+
def __init__(self, u,
10+
solver='euler',
11+
retain_u=False,
12+
**odeint_kwargs):
13+
super().__init__()
14+
self.u = u
15+
self.solver = solver
16+
self.retain_u = retain_u # use for retaining control input (e.g. MPC simulation)
17+
self.nfe = 0 # count number of function evaluations of the vector field
18+
self.cur_f = None # current dynamics evaluation
19+
self.cur_u = None # current controller value
20+
self._retain_flag = False # temporary flag for evaluating the controller only the first time
21+
self.odeint_kwargs = odeint_kwargs
22+
23+
def forward(self, x0, t_span):
24+
x = [x0[None]]
25+
xt = x0
26+
if self.retain_u:
27+
# Iterate over the t_span: evaluate the controller the first time only and then retain it
28+
# this is useful to simulate control with MPC
29+
for i in range(len(t_span)-1):
30+
self._retain_flag = False
31+
diff_span = torch.linspace(t_span[i], t_span[i+1], 2)
32+
odeint(self.dynamics, xt, diff_span, solver=self.solver, **self.odeint_kwargs)[1][-1]
33+
x.append(xt[None])
34+
traj = torch.cat(x)
35+
else:
36+
# Compute trajectory with odeint and base solvers
37+
traj = odeint(self.dynamics, xt, t_span, solver=self.solver, **self.odeint_kwargs)[1]
38+
return traj
39+
40+
def reset_nfe(self):
41+
"""Return number of function evaluation and reset"""
42+
cur_nfe = self.nfe; self.nfe = 0
43+
return cur_nfe
44+
45+
def _evaluate_controller(self, t, x):
46+
'''
47+
If we wish not to re-evaluate the control input, we set the retain
48+
flag to True so we do not re-evaluate next time
49+
'''
50+
if self.retain_u:
51+
if not self._retain_flag:
52+
self.cur_u = self.u(t, x)
53+
self._retain_flag = True
54+
else:
55+
pass # We do not re-evaluate the control input
56+
else:
57+
self.cur_u = self.u(t, x)
58+
return self.cur_u
59+
60+
61+
def dynamics(self, t, x):
62+
'''
63+
Model dynamics in the form xdot = f(t, x, u)
64+
'''
65+
raise NotImplementedError

0 commit comments

Comments
 (0)