import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from .correlation_package.correlation import Correlation


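# `Correlation` above is the custom CUDA cost-volume layer that ships with the
# repository (correlation_package); it typically has to be compiled and
# installed before this module can be imported.

# Building blocks: a conv followed by LeakyReLU(0.1), and a stride-2 transposed
# conv used to upsample the 2-channel flow between pyramid levels.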
def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias),
        nn.LeakyReLU(0.1, inplace=True)
    )


def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    return nn.ConvTranspose2d(in_planes, out_planes, kernel_size, stride, padding, bias=True)


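# Flow decoder shared by all pyramid levels. Its 87 input channels are the
# concatenation of the 53-channel sparse cost volume, a 32-channel reduced
# feature map and the 2-channel flow upsampled from the coarser level; the
# output is a 2-channel flow field. With groups > 1, the middle convolutions
# are group convolutions interleaved with channel shuffles (ShuffleNet style)
# to reduce computation.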
class Decoder(nn.Module):
    def __init__(self, in_channels, groups):
        super(Decoder, self).__init__()
        self.in_channels = in_channels
        self.groups = groups
        self.conv1 = convrelu(in_channels, 96, 3, 1)
        self.conv2 = convrelu(96, 96, 3, 1, groups=groups)
        self.conv3 = convrelu(96, 96, 3, 1, groups=groups)
        self.conv4 = convrelu(96, 96, 3, 1, groups=groups)
        self.conv5 = convrelu(96, 64, 3, 1)
        self.conv6 = convrelu(64, 32, 3, 1)
        self.conv7 = nn.Conv2d(32, 2, 3, 1, 1)

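    # ShuffleNet-style channel shuffle: after a group convolution, reorder
    # channels so that the next group convolution mixes information across
    # groups.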
    def channel_shuffle(self, x, groups):
        b, c, h, w = x.size()
        channels_per_group = c // groups
        x = x.view(b, groups, channels_per_group, h, w)
        x = x.transpose(1, 2).contiguous()
        x = x.view(b, -1, h, w)
        return x

    def forward(self, x):
        if self.groups == 1:
            out = self.conv7(self.conv6(self.conv5(self.conv4(self.conv3(self.conv2(self.conv1(x)))))))
        else:
            out = self.conv1(x)
            out = self.channel_shuffle(self.conv2(out), self.groups)
            out = self.channel_shuffle(self.conv3(out), self.groups)
            out = self.channel_shuffle(self.conv4(out), self.groups)
            out = self.conv7(self.conv6(self.conv5(out)))
        return out


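# FastFlowNet: a lightweight, PWC-Net-style optical flow network. A shared
# convolutional encoder builds a feature pyramid down to 1/8 resolution and
# extends it to 1/64 with average pooling; flow is then estimated
# coarse-to-fine from 1/64 up to 1/4 resolution, using feature warping, a
# sparse correlation cost volume and a group-convolution decoder at each
# level.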
class FastFlowNet(nn.Module):
    def __init__(self, groups=3):
        super(FastFlowNet, self).__init__()
        self.groups = groups
        self.pconv1_1 = convrelu(3, 16, 3, 2)
        self.pconv1_2 = convrelu(16, 16, 3, 1)
        self.pconv2_1 = convrelu(16, 32, 3, 2)
        self.pconv2_2 = convrelu(32, 32, 3, 1)
        self.pconv2_3 = convrelu(32, 32, 3, 1)
        self.pconv3_1 = convrelu(32, 64, 3, 2)
        self.pconv3_2 = convrelu(64, 64, 3, 1)
        self.pconv3_3 = convrelu(64, 64, 3, 1)

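        # Cost volume: correlation with max displacement 4 yields 81 channels
        # (a 9x9 search window); self.index keeps 53 of them, sampling the
        # window densely near the center and sparsely towards the border.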
        self.corr = Correlation(pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1, corr_multiply=1)
        self.index = torch.tensor([0, 2, 4, 6, 8,
                                   10, 12, 14, 16,
                                   18, 20, 21, 22, 23, 24, 26,
                                   28, 29, 30, 31, 32, 33, 34,
                                   36, 38, 39, 40, 41, 42, 44,
                                   46, 47, 48, 49, 50, 51, 52,
                                   54, 56, 57, 58, 59, 60, 62,
                                   64, 66, 68, 70,
                                   72, 74, 76, 78, 80])

        self.rconv2 = convrelu(32, 32, 3, 1)
        self.rconv3 = convrelu(64, 32, 3, 1)
        self.rconv4 = convrelu(64, 32, 3, 1)
        self.rconv5 = convrelu(64, 32, 3, 1)
        self.rconv6 = convrelu(64, 32, 3, 1)

        self.up3 = deconv(2, 2)
        self.up4 = deconv(2, 2)
        self.up5 = deconv(2, 2)
        self.up6 = deconv(2, 2)

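        # One decoder per pyramid level; each takes 53 (cost volume) + 32
        # (reduced feature) + 2 (upsampled flow) = 87 input channels.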
        self.decoder2 = Decoder(87, groups)
        self.decoder3 = Decoder(87, groups)
        self.decoder4 = Decoder(87, groups)
        self.decoder5 = Decoder(87, groups)
        self.decoder6 = Decoder(87, groups)

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

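    # Backward-warps feature map x by the flow field flo: build a pixel grid,
    # offset it by flo, normalize to [-1, 1] and sample bilinearly with
    # F.grid_sample. Note that PyTorch >= 1.3 defaults to align_corners=False
    # here, which changes the sampling slightly; pass align_corners explicitly
    # if exact behaviour matters.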
    def warp(self, x, flo):
        B, C, H, W = x.size()
        xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
        yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
        xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
        yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
        grid = torch.cat([xx, yy], 1).to(x)
        vgrid = grid + flo
        vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :] / max(W - 1, 1) - 1.0
        vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :] / max(H - 1, 1) - 1.0
        vgrid = vgrid.permute(0, 2, 3, 1)
        output = F.grid_sample(x, vgrid, mode='bilinear')
        return output

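    # x holds the two input frames concatenated along the channel dimension
    # (6 channels). Pyramid levels 1-3 (1/2, 1/4, 1/8 resolution) come from
    # the convolutional encoder; levels 4-6 (1/16, 1/32, 1/64) are obtained by
    # average pooling.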
    def forward(self, x):
        img1 = x[:, :3, :, :]
        img2 = x[:, 3:6, :, :]
        f11 = self.pconv1_2(self.pconv1_1(img1))
        f21 = self.pconv1_2(self.pconv1_1(img2))
        f12 = self.pconv2_3(self.pconv2_2(self.pconv2_1(f11)))
        f22 = self.pconv2_3(self.pconv2_2(self.pconv2_1(f21)))
        f13 = self.pconv3_3(self.pconv3_2(self.pconv3_1(f12)))
        f23 = self.pconv3_3(self.pconv3_2(self.pconv3_1(f22)))
        f14 = F.avg_pool2d(f13, kernel_size=(2, 2), stride=(2, 2))
        f24 = F.avg_pool2d(f23, kernel_size=(2, 2), stride=(2, 2))
        f15 = F.avg_pool2d(f14, kernel_size=(2, 2), stride=(2, 2))
        f25 = F.avg_pool2d(f24, kernel_size=(2, 2), stride=(2, 2))
        f16 = F.avg_pool2d(f15, kernel_size=(2, 2), stride=(2, 2))
        f26 = F.avg_pool2d(f25, kernel_size=(2, 2), stride=(2, 2))

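        # Coarse-to-fine estimation from level 6 (1/64) down to level 2 (1/4).
        # At each level the upsampled flow is rescaled by 20 / 2**level
        # (0.625, 1.25, 2.5, 5.0) before warping, matching the PWC-Net
        # convention of predicting flow at 1/20 of full-resolution pixel units.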
        flow7_up = torch.zeros(f16.size(0), 2, f16.size(2), f16.size(3)).to(f15)
        cv6 = torch.index_select(self.corr(f16, f26), dim=1, index=self.index.to(f16).long())
        r16 = self.rconv6(f16)
        cat6 = torch.cat([cv6, r16, flow7_up], 1)
        flow6 = self.decoder6(cat6)

        flow6_up = self.up6(flow6)
        f25_w = self.warp(f25, flow6_up * 0.625)
        cv5 = torch.index_select(self.corr(f15, f25_w), dim=1, index=self.index.to(f15).long())
        r15 = self.rconv5(f15)
        cat5 = torch.cat([cv5, r15, flow6_up], 1)
        flow5 = self.decoder5(cat5) + flow6_up

        flow5_up = self.up5(flow5)
        f24_w = self.warp(f24, flow5_up * 1.25)
        cv4 = torch.index_select(self.corr(f14, f24_w), dim=1, index=self.index.to(f14).long())
        r14 = self.rconv4(f14)
        cat4 = torch.cat([cv4, r14, flow5_up], 1)
        flow4 = self.decoder4(cat4) + flow5_up

        flow4_up = self.up4(flow4)
        f23_w = self.warp(f23, flow4_up * 2.5)
        cv3 = torch.index_select(self.corr(f13, f23_w), dim=1, index=self.index.to(f13).long())
        r13 = self.rconv3(f13)
        cat3 = torch.cat([cv3, r13, flow4_up], 1)
        flow3 = self.decoder3(cat3) + flow4_up

        flow3_up = self.up3(flow3)
        f22_w = self.warp(f22, flow3_up * 5.0)
        cv2 = torch.index_select(self.corr(f12, f22_w), dim=1, index=self.index.to(f12).long())
        r12 = self.rconv2(f12)
        cat2 = torch.cat([cv2, r12, flow3_up], 1)
        flow2 = self.decoder2(cat2) + flow3_up

        if self.training:
            return flow2, flow3, flow4, flow5, flow6
        else:
            return flow2
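

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). Assumptions: the
# correlation_package extension is built, a CUDA device is available, the
# input is two RGB frames concatenated along the channel dimension with
# spatial sizes that are multiples of 64, and the predicted flow follows the
# common 1/20 scaling so that multiplying by 20 and upsampling 4x recovers a
# full-resolution flow field. Because of the relative import above, this has
# to be run from within the package (e.g. `python -m <package>.model`).
if __name__ == '__main__':
    model = FastFlowNet().cuda().eval()
    img_pair = torch.randn(1, 6, 448, 1024).cuda()  # [img1; img2] along channels
    with torch.no_grad():
        flow_quarter = model(img_pair)              # 2-channel flow at 1/4 resolution
    flow_full = 20.0 * F.interpolate(flow_quarter, scale_factor=4,
                                     mode='bilinear', align_corners=False)
    print(flow_full.shape)                          # expected: [1, 2, 448, 1024]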