Skip to content

Commit 9271e75

Browse files
committed
Add plotting utils
1 parent e1ceefb commit 9271e75

File tree

3 files changed

+269
-0
lines changed

3 files changed

+269
-0
lines changed

torchcontrol/plotting/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .cstr import *
2+
from .quadcopter import *

torchcontrol/plotting/cstr.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import matplotlib.pyplot as plt
2+
import torch
3+
4+
def plot_cstr_trajectories_controls(traj, controls, tf=0.5):
5+
fig, axs = plt.subplots(4, 1, figsize=(12, 12))
6+
alpha = 1
7+
dummy=torch.linspace(0, tf, controls.shape[0])
8+
axs[0].plot(traj[:, 0], color='blue', alpha=alpha, label=r'$C_a$')
9+
axs[0].plot(traj[:, 1], color='orange', alpha=alpha, label=r'$C_b$')
10+
axs[1].plot(traj[:, 2], color='blue', alpha=alpha, label=r'$T_R$')
11+
axs[1].plot(traj[:, 3], color='orange', alpha=alpha, label=r'$T_K$')
12+
axs[2].step(dummy, controls[:, 0], alpha=alpha, label=r'$F$')
13+
axs[3].step(dummy, controls[:, 1], alpha=alpha, label=r'$\dot{Q}$')
14+
15+
axs[0].set_ylabel('$Concentration~[mol/l]$')
16+
axs[1].set_ylabel('$Temperature~[°C]$')
17+
axs[2].set_ylabel('$Flow~[l/h]$')
18+
axs[3].set_xlabel('$Time~[h]$')
19+
axs[3].set_ylabel('$Heat~[kW]$')
20+
21+
for ax in axs:
22+
ax.legend()
23+
ax.label_outer()

torchcontrol/plotting/quadcopter.py

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
2+
from IPython.display import HTML
3+
import matplotlib.pyplot as plt
4+
from matplotlib.animation import FuncAnimation
5+
import numpy as np
6+
import torch
7+
from ..systems.quadcopter import euler_matrix
8+
9+
# Cube util function
10+
def cuboid_data2(pos, size=(1,1,1), rotation=None):
11+
X = [[[0, 1, 0], [0, 0, 0], [1, 0, 0], [1, 1, 0]],
12+
[[0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0]],
13+
[[1, 0, 1], [1, 0, 0], [1, 1, 0], [1, 1, 1]],
14+
[[0, 0, 1], [0, 0, 0], [0, 1, 0], [0, 1, 1]],
15+
[[0, 1, 0], [0, 1, 1], [1, 1, 1], [1, 1, 0]],
16+
[[0, 1, 1], [0, 0, 1], [1, 0, 1], [1, 1, 1]]]
17+
X = np.array(X).astype(float)
18+
for i in range(3):
19+
X[:,:,i] *= size[i]
20+
if rotation is not None:
21+
for i in range(4):
22+
X[:,i,:] = np.dot(rotation, X[:,i,:].T).T
23+
X += pos
24+
return X
25+
26+
# Plot cube for drone body
27+
def plot_cube(position,size=None,rotation=None,color=None, **kwargs):
28+
if not isinstance(color,(list,np.ndarray)): color=["C0"]*len(position)
29+
if not isinstance(size,(list,np.ndarray)): size=[(1,1,1)]*len(position)
30+
g = cuboid_data2(position, size=size, rotation=rotation)
31+
return Poly3DCollection(g,
32+
facecolor=np.repeat(color,6), **kwargs)
33+
34+
35+
36+
def plot_quadcopter_trajectories_3d(traj, x_star, i=0):
37+
'''
38+
Plot trajectory of the drone up to the i-th element
39+
Args
40+
traj: drone trajectory
41+
x_star: target state
42+
i: plot until i-th frame
43+
'''
44+
fig = plt.figure(figsize=(6, 6))
45+
ax = plt.axes(projection='3d')
46+
if isinstance(traj, torch.Tensor): traj = traj.numpy()
47+
# For visualization
48+
scale = 1.5
49+
s = 50
50+
dxm = scale*0.16 # arm length (m)
51+
dym = scale*0.16 # arm length (m)
52+
dzm = scale*0.05 # motor height (m)
53+
s_drone = scale*10 # drone body dimension
54+
lw = scale
55+
drone_size = [dxm/2, dym/2, dzm]
56+
drone_color = ["royalblue"]
57+
58+
lim = [0, x_star[2]*1.2]
59+
ax.set_xlim3d(lim[0], lim[1])
60+
ax.set_ylim3d(lim[0], lim[1])
61+
ax.set_zlim3d(lim[0], lim[1])
62+
63+
l1, = ax.plot([], [], [], lw=lw, color='red')
64+
l2, = ax.plot([], [], [], lw=lw, color='green')
65+
66+
body, = ax.plot([], [], [], marker='o', markersize=s_drone, color='black', markerfacecolor='grey')
67+
68+
initial = traj[0]
69+
70+
71+
init = ax.scatter(initial[0], initial[1], initial[2], marker='^', color='blue', label='Initial Position', s=s)
72+
fin = ax.scatter(x_star[0], x_star[1], x_star[2], marker='*', color='red', label='Target', s=s) # set linestyle to none
73+
74+
ax.plot(traj[:i, 0], traj[:i, 1], traj[:i, 2], alpha=1, linestyle='-')
75+
pos = traj[i-1]
76+
x = pos[0]
77+
y = pos[1]
78+
z = pos[2]
79+
80+
# Trick to reuse the same function
81+
R = euler_matrix(torch.Tensor([pos[3]]), torch.Tensor([pos[4]]), torch.Tensor([pos[5]])).numpy().squeeze(0)
82+
motorPoints = np.array([[dxm, -dym, dzm], [0, 0, 0], [dxm, dym, dzm], [-dxm, dym, dzm], [0, 0, 0], [-dxm, -dym, dzm], [-dxm, -dym, -dzm]])
83+
motorPoints = np.dot(R, np.transpose(motorPoints))
84+
motorPoints[0,:] += x
85+
motorPoints[1,:] += y
86+
motorPoints[2,:] += z
87+
88+
# Motors
89+
l1.set_data(motorPoints[0,0:3], motorPoints[1,0:3])
90+
l1.set_3d_properties(motorPoints[2,0:3])
91+
l2.set_data(motorPoints[0,3:6], motorPoints[1,3:6])
92+
l2.set_3d_properties(motorPoints[2,3:6])
93+
94+
# Body
95+
pos = ((motorPoints[:, 6] + 2*motorPoints[:, 1])/3)
96+
body = plot_cube(pos, drone_size, rotation=R, edgecolor="k")
97+
ax.add_collection3d(body)
98+
99+
ax.legend()
100+
ax.set_xlabel(f'$x~[m]$')
101+
ax.set_ylabel(f'$y~[m]$')
102+
ax.set_zlabel(f'$z~[m]$')
103+
104+
ax.legend(loc='upper center', bbox_to_anchor=(0.52, -0.05),
105+
fancybox=True, shadow=False, ncol=3)
106+
107+
108+
def animate_quadcopter_3d(traj, x_star, t_span, path='quadcopter_animation.gif', html_embed=False):
109+
'''
110+
Animate drone and save gif
111+
Args
112+
traj: drone trajectory
113+
x_star: target position
114+
t_span: time vector corresponding to each trajectory
115+
path: save path for
116+
html_embed: embed mp4 video in the page
117+
'''
118+
119+
fig = plt.figure(figsize=(10, 10))
120+
ax = plt.axes(projection='3d')
121+
122+
# For visualization
123+
scale = 1.5
124+
s = 50
125+
dxm = scale*0.16 # arm length (m)
126+
dym = scale*0.16 # arm length (m)
127+
dzm = scale*0.05 # motor height (m)
128+
s_drone = scale*10 # drone body dimension
129+
lw = scale
130+
drone_size = [dxm/2, dym/2, dzm]
131+
drone_color = ["royalblue"]
132+
133+
lim = [0, x_star[2]*1.2]
134+
ax.set_xlim3d(lim[0], lim[1])
135+
ax.set_ylim3d(lim[0], lim[1])
136+
ax.set_zlim3d(lim[0], lim[1])
137+
ax.set_xlabel('x[m]')
138+
ax.set_ylabel('y[m]')
139+
ax.set_zlabel('z[m]')
140+
141+
lines1, lines2 = [], []
142+
l1, = ax.plot([], [], [], lw=2, color='red')
143+
l2, = ax.plot([], [], [], lw=2, color='green')
144+
145+
body, = ax.plot([], [], [], marker='o', markersize=s_drone, color='black', markerfacecolor='black')
146+
147+
initial = traj[0]
148+
tr = traj
149+
150+
# Single frame plotting
151+
def get_frame(i):
152+
del ax.collections[:] # remove previous 3D elements
153+
init = ax.scatter(initial[0], initial[1], initial[2], marker='^', color='blue', label='Initial Position', s=s)
154+
fin = ax.scatter(x_star[0], x_star[1], x_star[2], marker='*', color='red', label='Target', s=s) # set linestyle to none
155+
ax.plot(tr[:i, 0], tr[:i, 1], tr[:i, 2], alpha=0.1, linestyle='-.', color='tab:blue')
156+
time = t_span[i]
157+
pos = tr[i]
158+
x = pos[0]
159+
y = pos[1]
160+
z = pos[2]
161+
162+
x_from0 = tr[0:i,0]
163+
y_from0 = tr[0:i,1]
164+
z_from0 = tr[0:i,2]
165+
166+
# Trick to reuse the same function
167+
R = euler_matrix(torch.Tensor([pos[3]]), torch.Tensor([pos[4]]), torch.Tensor([pos[5]])).numpy().squeeze(0)
168+
motorPoints = np.array([[dxm, -dym, dzm], [0, 0, 0], [dxm, dym, dzm], [-dxm, dym, dzm], [0, 0, 0], [-dxm, -dym, dzm], [-dxm, -dym, -dzm]])
169+
motorPoints = np.dot(R, np.transpose(motorPoints))
170+
motorPoints[0,:] += x
171+
motorPoints[1,:] += y
172+
motorPoints[2,:] += z
173+
174+
# Motors
175+
l1.set_data(motorPoints[0,0:3], motorPoints[1,0:3])
176+
l1.set_3d_properties(motorPoints[2,0:3])
177+
l2.set_data(motorPoints[0,3:6], motorPoints[1,3:6])
178+
l2.set_3d_properties(motorPoints[2,3:6])
179+
180+
# Body
181+
pos = ((motorPoints[:, 6] + 2*motorPoints[:, 1])/3)
182+
body = plot_cube(pos, drone_size, rotation=R, edgecolor="k")
183+
ax.add_collection3d(body)
184+
185+
ax.set_title("Quadcopter Trajectory, t = {:.2f} s".format(time))
186+
187+
# Unused for now
188+
def anim_callback(i, get_world_frame):
189+
frame = get_world_frame(i)
190+
set_frame(frame)
191+
192+
# Frame setting
193+
def set_frame(frame):
194+
# convert 3x6 world_frame matrix into three line_data objects which is 3x2 (row:point index, column:x,y,z)
195+
lines_data = [frame[:,[0,2]], frame[:,[1,3]], frame[:,[4,5]]]
196+
ax = plt.gca()
197+
lines = ax.get_lines()
198+
for line, line_data in zip(lines[:3], lines_data):
199+
x, y, z = line_data
200+
line.set_data(x, y)
201+
line.set_3d_properties(z)
202+
203+
an = FuncAnimation(fig,
204+
get_frame,
205+
init_func=None,
206+
frames=len(t_span)-1, interval=20, blit=False)
207+
208+
an.save(path, dpi=80, writer='imagemagick', fps=20)
209+
210+
if html_embed: HTML(an.to_html5_video())
211+
212+
213+
def plot_quadcopter_trajectories(traj):
214+
'''
215+
Simple plot with all variables in time
216+
'''
217+
218+
fig, axs = plt.subplots(12, 1, figsize=(10, 10))
219+
220+
axis_labels = ['$x$', '$y$', '$z$', '$\phi$', r'$\theta$', '$\psi$', '$\dot x$', '$\dot y$', '$\dot z$', '$\dot \phi$', '$\dot \theta$', '$\dot \psi$']
221+
222+
for ax, i, axis_label in zip(axs, range(len(axs)), axis_labels):
223+
ax.plot(traj[:, i].cpu().detach(), color='tab:red')
224+
ax.label_outer()
225+
ax.set_ylabel(axis_label)
226+
227+
fig.suptitle('Trajectories', y=0.92, fontweight='bold')
228+
229+
230+
def plot_quadcopter_controls(controls):
231+
'''
232+
Simple plot with all variables in time
233+
'''
234+
235+
fig, axs = plt.subplots(4, 1, figsize=(10, 5))
236+
237+
axis_labels = ['$u_0$ RPM', '$u_1$ RPM','$u_2$ RPM','$u_3$ RPM']
238+
239+
for ax, i, axis_label in zip(axs, range(len(axs)), axis_labels):
240+
ax.plot(controls[:, i].cpu().detach(), color='tab:red')
241+
ax.label_outer()
242+
ax.set_ylabel(axis_label)
243+
244+
fig.suptitle('Control inputs', y=0.94, fontweight='bold')

0 commit comments

Comments
 (0)