
Commit edd7308

christ1nepetermattson authored and committed

updating the object detection code (mlcommons#116)

* initial commit of ssd code
* some readme fixes
* some bugfixes, adding model download
* requirements.txt for dockerfile
* switching the backbone to R34
* updating ssd300.py file
* removing imports that are no longer needed
* bug fixes for resnet backbone update (mlcommons#112)

1 parent e91cec7 commit edd7308

13 files changed: +2159 -0 lines

README.md (+1)

```diff
@@ -14,6 +14,7 @@ We provide reference implementations for each of the 7 benchmarks in the MLPerf
 
 * image_classification - Resnet-50 v1 applied to Imagenet.
 * object_detection - Mask R-CNN applied to COCO.
+* single_stage_detector - SSD applied to COCO 2017.
 * speech_recognition - DeepSpeech2 applied to Librispeech.
 * translation - Transformer applied to WMT English-German.
 * recommendation - Neural Collaborative Filtering applied to MovieLens 20 Million (ml-20m).
```
single_stage_detector/Dockerfile (+20)

```dockerfile
FROM pytorch/pytorch:0.4_cuda9_cudnn7

# Set working directory
WORKDIR /mlperf

RUN apt-get update && \
    apt-get install -y python3-tk python-pip

# Necessary pip packages
RUN pip install --upgrade pip
RUN pip install Cython==0.28.4 \
    matplotlib==2.2.2
RUN python3 -m pip install pycocotools==2.0.0

# Copy SSD code
WORKDIR /mlperf
COPY . .
RUN pip install -r requirements.txt

WORKDIR /mlperf/ssd
```
single_stage_detector/download_dataset.sh (+7)

```bash
# Get COCO 2017 data sets
dir=$(pwd)
mkdir /coco; cd /coco
curl -O http://images.cocodataset.org/zips/train2017.zip; unzip train2017.zip
curl -O http://images.cocodataset.org/zips/val2017.zip; unzip val2017.zip
curl -O http://images.cocodataset.org/annotations/annotations_trainval2017.zip; unzip annotations_trainval2017.zip
cd $dir
```
single_stage_detector/requirements.txt (+12)

```
cycler==0.10.0
kiwisolver==1.0.1
matplotlib==2.2.2
numpy==1.14.5
Pillow==5.2.0
pycocotools==2.0.0
pyparsing==2.2.0
python-dateutil==2.7.3
pytz==2018.5
six==1.11.0
torch==0.4.0
torchvision==0.2.1
```

single_stage_detector/ssd/README.md (+71)

# 1. Problem
Object detection.

# 2. Directions

### Steps to configure machine
From Source

Standard script.

From Docker
1. Checkout the MLPerf repository
```
git clone https://github.com/mlperf/reference.git
```
2. Install CUDA and Docker
```
source reference/install_cuda_docker.sh
```
3. Build the docker image for the single stage detection task
```
# Build from Dockerfile
cd reference/single_stage_detector/
sudo docker build -t mlperf/single_stage_detector .
```

### Steps to download data
```
cd reference/single_stage_detector/
source download_dataset.sh
```

### Steps to run benchmark
From Source

Run the run_and_time.sh script
```
cd reference/single_stage_detector/ssd
source run_and_time.sh SEED TARGET
```
where SEED is the random seed for a run and TARGET is the quality target from Section 5 below, e.g. `source run_and_time.sh 1 0.212`.
Docker Image
```
sudo nvidia-docker run -v /coco:/coco -t -i --rm --ipc=host mlperf/single_stage_detector ./run_and_time.sh SEED TARGET
```

# 3. Dataset/Environment
### Publication/Attribution
Microsoft COCO: Common Objects in Context. 2017.

### Training and test data separation
Train on the 2017 COCO train data set; compute mAP on the 2017 COCO val data set.

# 4. Model
### Publication/Attribution
Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, Scott Reed, Cheng-Yang Fu, Alexander C. Berg. SSD: Single Shot MultiBox Detector. In the Proceedings of the European Conference on Computer Vision (ECCV), 2016.

The backbone is ResNet-34 pretrained on ILSVRC 2012 (from torchvision). Modifications to the backbone network: remove the conv_5x residual blocks, change the first 3x3 convolution of the conv_4x block from stride 2 to stride 1 (this doubles the resolution of the feature map to which the detector heads are attached: with a 300x300 input, conv_4x yields a 38x38 map instead of 19x19), and attach all 6 detector heads to the output of the last conv_4x residual block. Detections are thus attached to 38x38, 19x19, 10x10, 5x5, 3x3, and 1x1 feature maps. Convolutions in the detector layers are followed by batch normalization layers.
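The stride change is easy to verify. A minimal sketch (not part of the reference code, though it mirrors the `ResNet34` backbone class in this commit) that truncates torchvision's resnet34 and applies the same modification:

```python
import torch
from torchvision.models.resnet import resnet34

rn34 = resnet34(pretrained=False)
# keep everything up to and including the conv_4x stage (children 0..6)
backbone = torch.nn.Sequential(*list(rn34.children())[:7])
# change the first 3x3 conv (and its 1x1 downsample) of conv_4x from stride 2 to 1
backbone[6][0].conv1.stride = (1, 1)
backbone[6][0].downsample[0].stride = (1, 1)

x = torch.randn(1, 3, 300, 300)
print(backbone(x).shape)  # torch.Size([1, 256, 38, 38]) instead of 19x19
```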
# 5. Quality
### Quality metric
The metric is COCO box mAP (averaged over IoU thresholds 0.5:0.95), computed over the 2017 COCO val data.
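For reference, a sketch of how this metric can be computed with the pycocotools version pinned in requirements.txt; `detections.json` stands in for a hypothetical results file in COCO format:

```python
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

coco_gt = COCO("/coco/annotations/instances_val2017.json")  # from download_dataset.sh
coco_dt = coco_gt.loadRes("detections.json")                # hypothetical model output
ev = COCOeval(coco_gt, coco_dt, iouType="bbox")
ev.evaluate()
ev.accumulate()
ev.summarize()
print("box mAP @ IoU 0.5:0.95:", ev.stats[0])
```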
### Quality target
mAP of 0.212

### Evaluation frequency

### Evaluation thoroughness
All the images in the COCO 2017 val data set.
single_stage_detector/ssd/base_model.py (+206)

```python
"""
Backbone networks (ResNet-18/34), L2Norm, and the SSD training loss,
plus helpers for loading vgg16 weights and saving them to a separate file.
"""

#from torchvision.models.vgg import vgg16
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.autograd import Variable
from collections import OrderedDict

from torchvision.models.resnet import resnet18, resnet34, resnet50


def _ModifyConvStrideDilation(conv, stride=(1, 1), padding=None):
    conv.stride = stride

    if padding is not None:
        conv.padding = padding


def _ModifyBlock(block, bottleneck=False, **kwargs):
    for m in list(block.children()):
        if bottleneck:
            _ModifyConvStrideDilation(m.conv2, **kwargs)
        else:
            _ModifyConvStrideDilation(m.conv1, **kwargs)

        if m.downsample is not None:
            # need to make sure there is no padding for the 1x1 residual connection
            _ModifyConvStrideDilation(list(m.downsample.children())[0], **kwargs)


class ResNet18(nn.Module):
    def __init__(self):
        super().__init__()
        rn18 = resnet18(pretrained=True)

        # discard the last ResNet block, avgpool, and the classification FC layer
        # layer1 = up to and including the conv3 block
        self.layer1 = nn.Sequential(*list(rn18.children())[:6])
        # layer2 = conv4 block only
        self.layer2 = nn.Sequential(*list(rn18.children())[6:7])

        # modify conv4 if necessary
        # only the first block has stride 2; setting the rest to (1, 1) is a no-op
        modulelist = list(self.layer2.children())
        _ModifyBlock(modulelist[0], stride=(1, 1))

    def forward(self, data):
        layer1_activation = self.layer1(data)
        x = layer1_activation
        layer2_activation = self.layer2(x)

        # only need the output of conv4
        return [layer2_activation]


class ResNet34(nn.Module):
    def __init__(self):
        super().__init__()
        rn34 = resnet34(pretrained=True)

        # discard the last ResNet block, avgpool, and the classification FC layer
        self.layer1 = nn.Sequential(*list(rn34.children())[:6])
        self.layer2 = nn.Sequential(*list(rn34.children())[6:7])
        # modify conv4 if necessary
        # only the first block has stride 2; setting the rest to (1, 1) is a no-op
        modulelist = list(self.layer2.children())
        _ModifyBlock(modulelist[0], stride=(1, 1))

    def forward(self, data):
        layer1_activation = self.layer1(data)
        x = layer1_activation
        layer2_activation = self.layer2(x)

        return [layer2_activation]


class L2Norm(nn.Module):
    """
    The scale is learnable, per the original paper.
    scale: initial scale value
    chan_num: number of channels (the L2 norm is taken over all channels)
    """
    def __init__(self, scale=20, chan_num=512):
        super(L2Norm, self).__init__()
        # scale across channels
        self.scale = \
            nn.Parameter(torch.Tensor([scale]*chan_num).view(1, chan_num, 1, 1))

    def forward(self, data):
        # normalize across channels
        return self.scale*data*data.pow(2).sum(dim=1, keepdim=True).clamp(min=1e-12).rsqrt()
```
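As a quick sanity check (not part of the commit): right after initialization, the channel-wise L2 norm of `L2Norm`'s output equals the initial scale at every spatial position. A minimal sketch:

```python
import torch

norm = L2Norm(scale=20, chan_num=512)
x = torch.randn(2, 512, 38, 38)
y = norm(x)
# channel-wise norm is ~20.0 everywhere until training updates the scale
print(y.pow(2).sum(dim=1).sqrt()[0, :3, :3])
```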
base_model.py, continued:

```python
def tailor_module(src_model, src_dir, tgt_model, tgt_dir):
    state = torch.load(src_dir)
    src_model.load_state_dict(state)
    src_state = src_model.state_dict()
    # only need the "features" weights
    keys1 = [k for k in src_state.keys() if k.startswith("features")]
    keys2 = tgt_model.state_dict().keys()

    assert len(keys1) == len(keys2)
    state = OrderedDict()

    for k1, k2 in zip(keys1, keys2):
        # print(k1, k2)
        state[k2] = src_state[k1]
    #diff_keys = state.keys() - target_model.state_dict().keys()
    #print("Different Keys:", diff_keys)
    # remove unnecessary keys
    #for k in diff_keys:
    #    state.pop(k)
    tgt_model.load_state_dict(state)
    torch.save(tgt_model.state_dict(), tgt_dir)


# The default vgg16 in pytorch differs from the one SSD uses
def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        elif v == 'C':
            # notice ceil_mode is True
            layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return layers
```
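For illustration, a hypothetical call to `make_layers` with the VGG-16 feature configuration commonly used by SSD300 implementations (the config list is an assumption; it does not appear in this commit):

```python
# 'M' = max-pool, 'C' = ceil-mode max-pool, integers = conv output channels
vgg16_cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C',
             512, 512, 512, 'M', 512, 512, 512]
features = nn.Sequential(*make_layers(vgg16_cfg))
```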
base_model.py, continued:

```python
class Loss(nn.Module):
    """
    Implements the loss as the sum of the following:
        1. Confidence Loss: all labels, with hard negative mining
        2. Localization Loss: only on positive labels
    Assumes the input dboxes tensor has shape 8732x4.
    """

    def __init__(self, dboxes):
        super(Loss, self).__init__()
        self.scale_xy = 1.0/dboxes.scale_xy
        self.scale_wh = 1.0/dboxes.scale_wh

        self.sl1_loss = nn.SmoothL1Loss(reduce=False)
        # stored as (1, 4, 8732) so it broadcasts against batched predictions
        self.dboxes = nn.Parameter(dboxes(order="xywh").transpose(0, 1).unsqueeze(dim=0),
                                   requires_grad=False)
        # The two scale factors are from the following link:
        # http://jany.st/post/2017-11-05-single-shot-detector-ssd-from-scratch-in-tensorflow.html
        self.con_loss = nn.CrossEntropyLoss(reduce=False)

    def _loc_vec(self, loc):
        """
        Generate location vectors.
        """
        gxy = self.scale_xy*(loc[:, :2, :] - self.dboxes[:, :2, :])/self.dboxes[:, 2:, :]
        gwh = self.scale_wh*(loc[:, 2:, :]/self.dboxes[:, 2:, :]).log()

        return torch.cat((gxy, gwh), dim=1).contiguous()

    def forward(self, ploc, plabel, gloc, glabel):
        """
        ploc, plabel: Nx4x8732, Nxlabel_numx8732
            predicted locations and labels

        gloc, glabel: Nx4x8732, Nx8732
            ground truth locations and labels
        """

        mask = glabel > 0
        pos_num = mask.sum(dim=1)

        vec_gd = self._loc_vec(gloc)

        # sum over the four coordinates, then mask out negatives
        sl1 = self.sl1_loss(ploc, vec_gd).sum(dim=1)
        sl1 = (mask.float()*sl1).sum(dim=1)

        # hard negative mining
        con = self.con_loss(plabel, glabel)

        # positive entries will never be selected as negatives
        con_neg = con.clone()
        con_neg[mask] = 0
        _, con_idx = con_neg.sort(dim=1, descending=True)
        _, con_rank = con_idx.sort(dim=1)

        # keep three negatives per positive
        neg_num = torch.clamp(3*pos_num, max=mask.size(1)).unsqueeze(-1)
        neg_mask = con_rank < neg_num

        closs = (con*(mask.float() + neg_mask.float())).sum(dim=1)

        # guard against images with no objects
        total_loss = sl1 + closs
        num_mask = (pos_num > 0).float()
        pos_num = pos_num.float().clamp(min=1e-6)

        ret = (total_loss*num_mask/pos_num).mean(dim=0)
        return ret
```
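A minimal smoke test for `Loss` (not from the commit): `FakeDBoxes` is a hypothetical stand-in for the repo's default-box generator, which is not part of this diff; it only needs `scale_xy`/`scale_wh` attributes and a call returning the 8732x4 box tensor. Written against the torch==0.4.0 pinned in requirements.txt (`reduce=False` is deprecated but still accepted on newer PyTorch):

```python
import torch

class FakeDBoxes:
    # hypothetical stand-in for the real default boxes; values are made up
    scale_xy, scale_wh = 0.1, 0.2
    def __call__(self, order="xywh"):
        return torch.rand(8732, 4) + 0.1  # strictly positive widths/heights

loss_fn = Loss(FakeDBoxes())
ploc   = torch.randn(2, 4, 8732)           # predicted box offsets
plabel = torch.randn(2, 81, 8732)          # class scores: 80 COCO classes + background
gloc   = torch.rand(2, 4, 8732) + 0.1      # ground-truth boxes in xywh form
glabel = torch.randint(0, 81, (2, 8732))   # ground-truth labels, 0 = background
print(loss_fn(ploc, plabel, gloc, glabel)) # scalar training loss
```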
