import torch
from torch import nn
from warnings import warn

tanh = nn.Tanh()

class BoxConstrainedController(nn.Module):
    """Simple controller based on a neural network with
    bounded control inputs.

    Args:
        in_dim: input dimension
        out_dim: output dimension
        h_dim: hidden dimension
        num_layers: number of hidden layers
        zero_init: initialize the last layer to zeros
        input_scaling: elementwise scaling applied to the network input
        output_scaling: per-dimension (min, max) bounds for the control output
        constrained: if True, bound the output via tanh and rescale to the bounds
    """
    def __init__(self,
                 in_dim,
                 out_dim,
                 h_dim=64,
                 num_layers=2,
                 zero_init=True,
                 input_scaling=None,
                 output_scaling=None,
                 constrained=False):

        super().__init__()
        # Create the neural network
        layers = []
        layers.append(nn.Linear(in_dim, h_dim))
        for i in range(num_layers):
            if i < num_layers - 1:
                layers.append(nn.Softplus())
            else:
                # the last hidden activation is a tanh,
                # which acts as a regulator
                layers.append(nn.Tanh())
                break
            layers.append(nn.Linear(h_dim, h_dim))
        layers.append(nn.Linear(h_dim, out_dim))
        self.layers = nn.Sequential(*layers)

        # Initialize the controller with zeros in the last layer
        if zero_init:
            self._init_zeros()
        self.zero_init = zero_init

        # Scaling
        if constrained is False and output_scaling is not None:
            warn("Output scaling has no effect unless `constrained` is set to True")
        if input_scaling is None:
            input_scaling = torch.ones(in_dim)
        if output_scaling is None:
            # scaling[..., 0] -> min value
            # scaling[..., 1] -> max value
            # Use stack (not cat) so each output dimension gets its own
            # (min, max) pair, matching the indexing in `_rescale`
            output_scaling = torch.stack([-torch.ones(out_dim),
                                          torch.ones(out_dim)], -1)
        self.in_scaling = input_scaling
        self.out_scaling = output_scaling
        self.constrained = constrained

    def forward(self, t, x):
        x = self.layers(self.in_scaling.to(x) * x)
        if self.constrained:
            # Squash the output into [-1, 1] with a tanh,
            # then rescale it to the desired output bounds
            x = tanh(x)
            # TODO: fix the tanh to clamp
            # x = torch.clamp(x, -1, 1)  # not working in some applications
            x = self._rescale(x)
        return x

    def _rescale(self, x):
        # Map x from [-1, 1] to [min, max] per output dimension
        s = self.out_scaling.to(x)
        return 0.5 * (x + 1) * (s[..., 1] - s[..., 0]) + s[..., 0]

    def _reset(self):
        '''Reinitialize layers'''
        for p in self.layers.children():
            if hasattr(p, 'reset_parameters'):
                p.reset_parameters()
        if self.zero_init:
            self._init_zeros()

    def _init_zeros(self):
        '''Reinitialize the last layer with zeros'''
        for p in self.layers[-1].parameters():
            nn.init.zeros_(p)
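

# A minimal usage sketch (editor's addition, hypothetical, not part of the
# original module): builds a constrained controller with made-up bounds and
# checks that the outputs respect them.
def _demo_box_controller():
    # bounds[..., 0] = 0 (min), bounds[..., 1] = 2 (max) for each output
    bounds = torch.stack([torch.zeros(2), 2 * torch.ones(2)], -1)
    u_net = BoxConstrainedController(in_dim=3, out_dim=2,
                                     output_scaling=bounds, constrained=True)
    u = u_net(0.0, torch.randn(5, 3))
    assert u.shape == (5, 2)
    assert torch.all(u >= 0) and torch.all(u <= 2)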

class RandConstController(nn.Module):
    """Constant controller.
    We can use this for residual propagation and MPC steps (forward propagation)."""
    def __init__(self, shape=(1, 1), u_min=-1, u_max=1):
        super().__init__()
        # Sample a constant control uniformly from [u_min, u_max]
        self.u0 = torch.empty(shape).uniform_(u_min, u_max)

    def forward(self, t, x):
        return self.u0
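

# Editor's sketch (hypothetical, not in the original file): the constant
# controller returns the same randomly drawn control for any (t, x), which
# the quick check below verifies alongside the box-constraint demo above.
if __name__ == "__main__":
    _demo_box_controller()
    const_ctrl = RandConstController(shape=(1, 2), u_min=-0.5, u_max=0.5)
    u1 = const_ctrl(0.0, torch.randn(4, 2))
    u2 = const_ctrl(1.0, torch.randn(4, 2))
    assert torch.equal(u1, u2)  # constant in both time and state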