Commit 7dc015b

Add files via upload

1 parent 56c7fb4 commit 7dc015b

25 files changed: +1300 -0 lines

benchmark.py

+26
import time
import torch
from models.FastFlowNet import FastFlowNet

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

input_t = torch.randn(1, 6, 448, 1024).cuda()
print(input_t.shape)

model = FastFlowNet().cuda().eval()
model.load_state_dict(torch.load('./checkpoints/fastflownet_ft_mix.pth'))

with torch.no_grad():
    output_t = model(input_t)
print(output_t.shape)

# Time 1000 forward passes. CUDA kernels launch asynchronously, so
# synchronize before reading the clock on either side; with 1000 iterations
# the elapsed seconds equal the per-inference latency in milliseconds.
torch.cuda.synchronize()
start = time.time()
with torch.no_grad():
    for x in range(1000):
        output_t = model(input_t)
torch.cuda.synchronize()
end = time.time()
print('Time elapsed: {:.3f} ms'.format(end - start))

model = model.train()
print('Number of parameters: {:.2f} M'.format(count_parameters(model) / 1e6))
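
For a lower-variance measurement, CUDA events time the kernels on-device rather than on the host clock. The following is a minimal sketch, not part of this commit, reusing the model and input_t defined in benchmark.py above.

# Sketch: per-inference latency with CUDA events (assumes `model` and
# `input_t` from benchmark.py are in scope).
starter = torch.cuda.Event(enable_timing=True)
ender = torch.cuda.Event(enable_timing=True)
with torch.no_grad():
    for _ in range(10):      # warm-up so lazy initialization is excluded
        model(input_t)
    starter.record()
    for _ in range(100):
        model(input_t)
    ender.record()
torch.cuda.synchronize()
print('Mean latency: {:.3f} ms'.format(starter.elapsed_time(ender) / 100))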

checkpoints/fastflownet_chairs.pth (5.23 MB, binary file not shown)
checkpoints/fastflownet_ft_kitti.pth (5.23 MB, binary file not shown)
checkpoints/fastflownet_ft_mix.pth (5.23 MB, binary file not shown)
checkpoints/fastflownet_ft_sintel.pth (5.23 MB, binary file not shown)
checkpoints/fastflownet_things3d.pth (5.23 MB, binary file not shown)

data/000038_10.png (770 KB)
data/000038_10_flow.png (189 KB)
data/000038_11.png (765 KB)
data/fastflownet.png (427 KB)
data/frame_0006.png (516 KB)
data/frame_0006_flow.png (226 KB)
data/frame_0007.png (508 KB)
data/img_050.jpg (182 KB)
data/img_050_flow.png (173 KB)
data/img_051.jpg (183 KB)
data/tx2_demo.gif (18 MB)

demo.py

+60
import numpy as np
import cv2
import torch
import torch.nn.functional as F
from models.FastFlowNet import FastFlowNet
from flow_vis import flow_to_color

div_flow = 20.0   # checkpoints store flow scaled down by 20 (FlowNet convention)
div_size = 64     # input sides must be divisible by 64 (pyramid reaches stride 64)

def centralize(img1, img2):
    # Subtract the mean RGB of both frames, as done during training.
    b, c, h, w = img1.shape
    rgb_mean = torch.cat([img1, img2], dim=2).view(b, c, -1).mean(2).view(b, c, 1, 1)
    return img1 - rgb_mean, img2 - rgb_mean, rgb_mean

model = FastFlowNet().cuda().eval()
model.load_state_dict(torch.load('./checkpoints/fastflownet_ft_mix.pth'))

# img1_path = './data/img_050.jpg'
# img2_path = './data/img_051.jpg'
# img1_path = './data/frame_0006.png'
# img2_path = './data/frame_0007.png'
img1_path = './data/000038_10.png'
img2_path = './data/000038_11.png'

img1 = torch.from_numpy(cv2.imread(img1_path)).float().permute(2, 0, 1).unsqueeze(0) / 255.0
img2 = torch.from_numpy(cv2.imread(img2_path)).float().permute(2, 0, 1).unsqueeze(0) / 255.0
img1, img2, _ = centralize(img1, img2)

height, width = img1.shape[-2:]
orig_size = (int(height), int(width))

# Resize up to the nearest multiple of div_size if needed.
if height % div_size != 0 or width % div_size != 0:
    input_size = (
        int(div_size * np.ceil(height / div_size)),
        int(div_size * np.ceil(width / div_size))
    )
    img1 = F.interpolate(img1, size=input_size, mode='bilinear', align_corners=False)
    img2 = F.interpolate(img2, size=input_size, mode='bilinear', align_corners=False)
else:
    input_size = orig_size

input_t = torch.cat([img1, img2], 1).cuda()

with torch.no_grad():
    output = model(input_t)

# The network predicts flow at 1/4 resolution, scaled down by div_flow.
flow = div_flow * F.interpolate(output, size=input_size, mode='bilinear', align_corners=False)

# If the input was resized, map the flow back and rescale its components.
if input_size != orig_size:
    scale_h = orig_size[0] / input_size[0]
    scale_w = orig_size[1] / input_size[1]
    flow = F.interpolate(flow, size=orig_size, mode='bilinear', align_corners=False)
    flow[:, 0, :, :] *= scale_w
    flow[:, 1, :, :] *= scale_h

flow = flow[0].cpu().permute(1, 2, 0).numpy()

flow_color = flow_to_color(flow, convert_to_bgr=True)

cv2.imwrite('./data/flow.png', flow_color)
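
The pipeline above can be factored into a reusable helper. This is a sketch under the same assumptions (CUDA model, div_flow=20, div_size=64); run_pair is a hypothetical name, not an API of this repository.

def run_pair(model, img1_path, img2_path, div_flow=20.0, div_size=64):
    # Hypothetical wrapper around the demo pipeline above; returns an
    # [H, W, 2] numpy flow field at the original image resolution.
    img1 = torch.from_numpy(cv2.imread(img1_path)).float().permute(2, 0, 1).unsqueeze(0) / 255.0
    img2 = torch.from_numpy(cv2.imread(img2_path)).float().permute(2, 0, 1).unsqueeze(0) / 255.0
    img1, img2, _ = centralize(img1, img2)
    h, w = img1.shape[-2:]
    size = (int(div_size * np.ceil(h / div_size)), int(div_size * np.ceil(w / div_size)))
    if size != (h, w):
        img1 = F.interpolate(img1, size=size, mode='bilinear', align_corners=False)
        img2 = F.interpolate(img2, size=size, mode='bilinear', align_corners=False)
    with torch.no_grad():
        output = model(torch.cat([img1, img2], 1).cuda())
    flow = div_flow * F.interpolate(output, size=size, mode='bilinear', align_corners=False)
    if size != (h, w):
        flow = F.interpolate(flow, size=(h, w), mode='bilinear', align_corners=False)
        flow[:, 0] *= w / size[1]
        flow[:, 1] *= h / size[0]
    return flow[0].cpu().permute(1, 2, 0).numpy()

# e.g. flow = run_pair(model, './data/000038_10.png', './data/000038_11.png')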

flow_vis.py

+129
# MIT License
#
# Copyright (c) 2018 Tom Runia
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to conditions.
#
# Author: Tom Runia
# Date Created: 2018-08-03

import numpy as np

def make_colorwheel():
    """
    Generates a color wheel for optical flow visualization as presented in:
    Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
    URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf

    Code follows the original C++ source code of Daniel Scharstein.
    Code follows the Matlab source code of Deqing Sun.

    Returns:
        np.ndarray: Color wheel
    """

    RY = 15
    YG = 6
    GC = 4
    CB = 11
    BM = 13
    MR = 6

    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = np.zeros((ncols, 3))
    col = 0

    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = np.floor(255*np.arange(0, RY)/RY)
    col = col + RY
    # YG
    colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0, YG)/YG)
    colorwheel[col:col+YG, 1] = 255
    col = col + YG
    # GC
    colorwheel[col:col+GC, 1] = 255
    colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0, GC)/GC)
    col = col + GC
    # CB
    colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
    colorwheel[col:col+CB, 2] = 255
    col = col + CB
    # BM
    colorwheel[col:col+BM, 2] = 255
    colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0, BM)/BM)
    col = col + BM
    # MR
    colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
    colorwheel[col:col+MR, 0] = 255
    return colorwheel


def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v.

    According to the C++ source code of Daniel Scharstein
    According to the Matlab source code of Deqing Sun

    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    colorwheel = make_colorwheel()  # shape [55x3]
    ncols = colorwheel.shape[0]
    rad = np.sqrt(np.square(u) + np.square(v))
    a = np.arctan2(-v, -u)/np.pi
    fk = (a+1) / 2*(ncols-1)
    k0 = np.floor(fk).astype(np.int32)
    k1 = k0 + 1
    k1[k1 == ncols] = 0
    f = fk - k0
    for i in range(colorwheel.shape[1]):
        tmp = colorwheel[:, i]
        col0 = tmp[k0] / 255.0
        col1 = tmp[k1] / 255.0
        col = (1-f)*col0 + f*col1
        idx = (rad <= 1)
        col[idx] = 1 - rad[idx] * (1-col[idx])
        col[~idx] = col[~idx] * 0.75   # out of range
        # Note the 2-i => BGR instead of RGB
        ch_idx = 2-i if convert_to_bgr else i
        flow_image[:, :, ch_idx] = np.floor(255 * col)
    return flow_image


def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Expects a two dimensional flow image of shape [H,W,2].

    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
        clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        flow_uv = np.clip(flow_uv, 0, clip_flow)
    u = flow_uv[:, :, 0]
    v = flow_uv[:, :, 1]
    rad = np.sqrt(np.square(u) + np.square(v))
    rad_max = np.max(rad)
    epsilon = 1e-5
    u = u / (rad_max + epsilon)
    v = v / (rad_max + epsilon)
    return flow_uv_to_colors(u, v, convert_to_bgr)
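
A quick sanity check of the color coding is to visualize a synthetic radial flow field, which should reproduce the Middlebury color wheel. This is a minimal sketch, not part of this commit; the output path is arbitrary.

import numpy as np
import cv2
from flow_vis import flow_to_color

# Each pixel's flow points away from the image center.
h, w = 256, 256
ys, xs = np.mgrid[0:h, 0:w].astype(np.float32)
flow = np.stack([xs - w / 2, ys - h / 2], axis=-1)   # [H, W, 2] in (u, v) order
cv2.imwrite('./data/colorwheel_check.png', flow_to_color(flow, convert_to_bgr=True))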

models/FastFlowNet.py

+170
import torch
import torch.nn as nn
import torch.nn.functional as F
from .correlation_package.correlation import Correlation


def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias),
        nn.LeakyReLU(0.1, inplace=True)
    )


def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    return nn.ConvTranspose2d(in_planes, out_planes, kernel_size, stride, padding, bias=True)


class Decoder(nn.Module):
    def __init__(self, in_channels, groups):
        super(Decoder, self).__init__()
        self.in_channels = in_channels
        self.groups = groups
        self.conv1 = convrelu(in_channels, 96, 3, 1)
        self.conv2 = convrelu(96, 96, 3, 1, groups=groups)
        self.conv3 = convrelu(96, 96, 3, 1, groups=groups)
        self.conv4 = convrelu(96, 96, 3, 1, groups=groups)
        self.conv5 = convrelu(96, 64, 3, 1)
        self.conv6 = convrelu(64, 32, 3, 1)
        self.conv7 = nn.Conv2d(32, 2, 3, 1, 1)

    def channel_shuffle(self, x, groups):
        # ShuffleNet-style channel shuffle so information mixes across
        # the grouped 3x3 convolutions.
        b, c, h, w = x.size()
        channels_per_group = c // groups
        x = x.view(b, groups, channels_per_group, h, w)
        x = x.transpose(1, 2).contiguous()
        x = x.view(b, -1, h, w)
        return x

    def forward(self, x):
        if self.groups == 1:
            out = self.conv7(self.conv6(self.conv5(self.conv4(self.conv3(self.conv2(self.conv1(x)))))))
        else:
            out = self.conv1(x)
            out = self.channel_shuffle(self.conv2(out), self.groups)
            out = self.channel_shuffle(self.conv3(out), self.groups)
            out = self.channel_shuffle(self.conv4(out), self.groups)
            out = self.conv7(self.conv6(self.conv5(out)))
        return out


class FastFlowNet(nn.Module):
    def __init__(self, groups=3):
        super(FastFlowNet, self).__init__()
        self.groups = groups
        self.pconv1_1 = convrelu(3, 16, 3, 2)
        self.pconv1_2 = convrelu(16, 16, 3, 1)
        self.pconv2_1 = convrelu(16, 32, 3, 2)
        self.pconv2_2 = convrelu(32, 32, 3, 1)
        self.pconv2_3 = convrelu(32, 32, 3, 1)
        self.pconv3_1 = convrelu(32, 64, 3, 2)
        self.pconv3_2 = convrelu(64, 64, 3, 1)
        self.pconv3_3 = convrelu(64, 64, 3, 1)

        self.corr = Correlation(pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1, corr_multiply=1)
        # 53 of the 81 channels of the dense 9x9 cost volume are kept;
        # together with 32 feature channels and 2 flow channels this gives
        # the 87 decoder input channels below.
        self.index = torch.tensor([0, 2, 4, 6, 8,
                                   10, 12, 14, 16,
                                   18, 20, 21, 22, 23, 24, 26,
                                   28, 29, 30, 31, 32, 33, 34,
                                   36, 38, 39, 40, 41, 42, 44,
                                   46, 47, 48, 49, 50, 51, 52,
                                   54, 56, 57, 58, 59, 60, 62,
                                   64, 66, 68, 70,
                                   72, 74, 76, 78, 80])

        self.rconv2 = convrelu(32, 32, 3, 1)
        self.rconv3 = convrelu(64, 32, 3, 1)
        self.rconv4 = convrelu(64, 32, 3, 1)
        self.rconv5 = convrelu(64, 32, 3, 1)
        self.rconv6 = convrelu(64, 32, 3, 1)

        self.up3 = deconv(2, 2)
        self.up4 = deconv(2, 2)
        self.up5 = deconv(2, 2)
        self.up6 = deconv(2, 2)

        self.decoder2 = Decoder(87, groups)
        self.decoder3 = Decoder(87, groups)
        self.decoder4 = Decoder(87, groups)
        self.decoder5 = Decoder(87, groups)
        self.decoder6 = Decoder(87, groups)

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def warp(self, x, flo):
        # Backward-warp feature map x by flow flo via a normalized sampling grid.
        B, C, H, W = x.size()
        xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
        yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
        xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
        yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
        grid = torch.cat([xx, yy], 1).to(x)
        vgrid = grid + flo
        # Normalize to [-1, 1]; the (W-1)/(H-1) scaling corresponds to
        # align_corners=True, so it is passed explicitly (the grid_sample
        # default changed to False in PyTorch 1.3).
        vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :] / max(W-1, 1) - 1.0
        vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :] / max(H-1, 1) - 1.0
        vgrid = vgrid.permute(0, 2, 3, 1)
        output = F.grid_sample(x, vgrid, mode='bilinear', align_corners=True)
        return output

    def forward(self, x):
        img1 = x[:, :3, :, :]
        img2 = x[:, 3:6, :, :]
        # Shared convolutional pyramid (strides 2, 4, 8), extended to
        # strides 16, 32, 64 by average pooling.
        f11 = self.pconv1_2(self.pconv1_1(img1))
        f21 = self.pconv1_2(self.pconv1_1(img2))
        f12 = self.pconv2_3(self.pconv2_2(self.pconv2_1(f11)))
        f22 = self.pconv2_3(self.pconv2_2(self.pconv2_1(f21)))
        f13 = self.pconv3_3(self.pconv3_2(self.pconv3_1(f12)))
        f23 = self.pconv3_3(self.pconv3_2(self.pconv3_1(f22)))
        f14 = F.avg_pool2d(f13, kernel_size=(2, 2), stride=(2, 2))
        f24 = F.avg_pool2d(f23, kernel_size=(2, 2), stride=(2, 2))
        f15 = F.avg_pool2d(f14, kernel_size=(2, 2), stride=(2, 2))
        f25 = F.avg_pool2d(f24, kernel_size=(2, 2), stride=(2, 2))
        f16 = F.avg_pool2d(f15, kernel_size=(2, 2), stride=(2, 2))
        f26 = F.avg_pool2d(f25, kernel_size=(2, 2), stride=(2, 2))

        # Coarse-to-fine decoding: at each level, warp the second frame's
        # features with the upsampled flow, take the pruned cost volume,
        # and add a predicted residual. The warp scales (0.625 = 20/32,
        # 1.25 = 20/16, 2.5 = 20/8, 5.0 = 20/4) convert the div_flow-scaled
        # flow into pixel units at each pyramid stride.
        flow7_up = torch.zeros(f16.size(0), 2, f16.size(2), f16.size(3)).to(f15)
        cv6 = torch.index_select(self.corr(f16, f26), dim=1, index=self.index.to(f16).long())
        r16 = self.rconv6(f16)
        cat6 = torch.cat([cv6, r16, flow7_up], 1)
        flow6 = self.decoder6(cat6)

        flow6_up = self.up6(flow6)
        f25_w = self.warp(f25, flow6_up*0.625)
        cv5 = torch.index_select(self.corr(f15, f25_w), dim=1, index=self.index.to(f15).long())
        r15 = self.rconv5(f15)
        cat5 = torch.cat([cv5, r15, flow6_up], 1)
        flow5 = self.decoder5(cat5) + flow6_up

        flow5_up = self.up5(flow5)
        f24_w = self.warp(f24, flow5_up*1.25)
        cv4 = torch.index_select(self.corr(f14, f24_w), dim=1, index=self.index.to(f14).long())
        r14 = self.rconv4(f14)
        cat4 = torch.cat([cv4, r14, flow5_up], 1)
        flow4 = self.decoder4(cat4) + flow5_up

        flow4_up = self.up4(flow4)
        f23_w = self.warp(f23, flow4_up*2.5)
        cv3 = torch.index_select(self.corr(f13, f23_w), dim=1, index=self.index.to(f13).long())
        r13 = self.rconv3(f13)
        cat3 = torch.cat([cv3, r13, flow4_up], 1)
        flow3 = self.decoder3(cat3) + flow4_up

        flow3_up = self.up3(flow3)
        f22_w = self.warp(f22, flow3_up*5.0)
        cv2 = torch.index_select(self.corr(f12, f22_w), dim=1, index=self.index.to(f12).long())
        r12 = self.rconv2(f12)
        cat2 = torch.cat([cv2, r12, flow3_up], 1)
        flow2 = self.decoder2(cat2) + flow3_up

        if self.training:
            return flow2, flow3, flow4, flow5, flow6
        else:
            return flow2
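
As a smoke test, the sketch below checks the eval-mode output shape; it is not part of this commit and assumes the correlation_package CUDA extension is built, since Correlation runs on the GPU.

import torch
from models.FastFlowNet import FastFlowNet

model = FastFlowNet().cuda().eval()
x = torch.randn(1, 6, 448, 1024).cuda()
with torch.no_grad():
    flow = model(x)
print(flow.shape)   # flow2 at 1/4 input resolution: torch.Size([1, 2, 112, 256])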
