Skip to content

Commit 2dce76c

Browse files
author
Antoine Miech "WILLOW"
committed
first commit
0 parents  commit 2dce76c

11 files changed

+1208
-0
lines changed

args.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import argparse
2+
3+
def get_args(description='Youtube-Text-Video'):
4+
parser = argparse.ArgumentParser(description=description)
5+
parser.add_argument(
6+
'--train_csv',
7+
type=str,
8+
default='/sequoia/data2/amiech/iccv19/csv/dataset_v10_1300k.csv',
9+
help='train csv')
10+
parser.add_argument(
11+
'--features_path',
12+
type=str,
13+
default='/local/dataset/youtube_text/2D_features',
14+
help='feature path')
15+
parser.add_argument(
16+
'--features_path_3D',
17+
type=str,
18+
default='/local/dataset/youtube_text/3D_features',
19+
help='feature path')
20+
parser.add_argument(
21+
'--caption_path',
22+
type=str,
23+
default='/sequoia/data2/amiech/iccv19/caption_pickle/caption_howto100m.pickle',
24+
help='caption csv path')
25+
parser.add_argument(
26+
'--word2vec_path',
27+
type=str,
28+
default='/local/dataset/youtube_text/GoogleNews-vectors-negative300.bin',
29+
help='word embedding path')
30+
parser.add_argument(
31+
'--pretrain_path',
32+
type=str,
33+
default='',
34+
help='pre train model path')
35+
parser.add_argument(
36+
'--checkpoint_dir',
37+
type=str,
38+
default='',
39+
help='checkpoint model folder')
40+
parser.add_argument('--sentence_pooling', type=str, default='max',
41+
help='sentence representation')
42+
parser.add_argument('--num_thread_reader', type=int, default=1,
43+
help='')
44+
parser.add_argument('--embd_dim', type=int, default=2048,
45+
help='embedding dim')
46+
parser.add_argument('--lr', type=float, default=0.0001,
47+
help='initial learning rate')
48+
parser.add_argument('--epochs', type=int, default=20,
49+
help='upper epoch limit')
50+
parser.add_argument('--batch_size', type=int, default=256,
51+
help='batch size')
52+
parser.add_argument('--batch_size_val', type=int, default=3500,
53+
help='batch size eval')
54+
parser.add_argument('--lr_decay', type=float, default=0.9,
55+
help='Learning rate exp epoch decay')
56+
parser.add_argument('--n_display', type=int, default=10,
57+
help='Information display frequence')
58+
parser.add_argument('--feature_dim', type=int, default=4096,
59+
help='video feature dimension')
60+
parser.add_argument('--we_dim', type=int, default=300,
61+
help='word embedding dimension')
62+
parser.add_argument('--seed', type=int, default=1,
63+
help='random seed')
64+
parser.add_argument('--verbose', type=int, default=1,
65+
help='')
66+
parser.add_argument('--max_words', type=int, default=20,
67+
help='')
68+
parser.add_argument('--min_words', type=int, default=0,
69+
help='')
70+
parser.add_argument('--feature_framerate', type=int, default=1,
71+
help='')
72+
parser.add_argument('--min_time', type=float, default=5.0,
73+
help='')
74+
parser.add_argument('--margin', type=float, default=0.1,
75+
help='')
76+
parser.add_argument('--hard_negative_rate', type=float, default=0.5,
77+
help='')
78+
parser.add_argument('--negative_weighting', type=int, default=1,
79+
help='')
80+
parser.add_argument('--gpu_mode', type=int, default=1,
81+
help='')
82+
parser.add_argument('--n_pair', type=int, default=1,
83+
help='Num of pair to output from data loader')
84+
parser.add_argument('--eval_lsmdc', type=int, default=0,
85+
help='Evaluate on LSMDC data')
86+
parser.add_argument('--eval_msrvtt', type=int, default=0,
87+
help='Evaluate on MSRVTT data')
88+
parser.add_argument('--lsmdc', type=int, default=0,
89+
help='LSMDC training')
90+
parser.add_argument('--sentence_dim', type=int, default=-1,
91+
help='sentence dimension')
92+
parser.add_argument(
93+
'--youcook_train_path',
94+
type=str,
95+
default='/sequoia/data2/dataset/YouCook2/scripts/train.pkl',
96+
help='')
97+
parser.add_argument(
98+
'--youcook_val_path',
99+
type=str,
100+
default='/sequoia/data2/dataset/YouCook2/scripts/val.pkl',
101+
help='')
102+
parser.add_argument('--youcook', type=int, default=0,
103+
help='')
104+
parser.add_argument('--msrvtt', type=int, default=0,
105+
help='')
106+
parser.add_argument('--eval_youcook', type=int, default=0,
107+
help='')
108+
parser.add_argument(
109+
'--msrvtt_test_csv_path',
110+
type=str,
111+
default='/sequoia/data2/dataset/MSR-VTT_Dataset/test_sentences.csv',
112+
help='')
113+
parser.add_argument(
114+
'--msrvtt_test_features_path',
115+
type=str,
116+
default='/sequoia/data2/dataset/MSR-VTT_Dataset/features.pth',
117+
help='')
118+
parser.add_argument(
119+
'--lsmdc_test_csv_path',
120+
type=str,
121+
default='/sequoia/data2/amiech/MPII/LSMDC16_challenge_1000_publictect.csv',
122+
help='')
123+
parser.add_argument(
124+
'--lsmdc_test_features_path',
125+
type=str,
126+
default='/sequoia/data2/amiech/MPII/features/retrieval_features.pth',
127+
help='')
128+
parser.add_argument(
129+
'--lsmdc_train_csv_path',
130+
type=str,
131+
default='/sequoia/data2/amiech/MPII/LSMDC16_annos_training.csv',
132+
help='')
133+
parser.add_argument(
134+
'--lsmdc_train_features_path',
135+
type=str,
136+
default='/sequoia/data2/amiech/MPII/features/train_features.pth',
137+
help='')
138+
parser.add_argument(
139+
'--lsmdc_val_csv_path',
140+
type=str,
141+
default='/sequoia/data2/amiech/MPII/LSMDC16_annos_val.csv',
142+
help='')
143+
parser.add_argument(
144+
'--lsmdc_val_features_path',
145+
type=str,
146+
default='/sequoia/data2/amiech/MPII/features/val_features.pth',
147+
help='')
148+
args = parser.parse_args()
149+
return args

eval.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import unicode_literals
4+
from __future__ import print_function
5+
6+
import torch as th
7+
from torch.utils.data import DataLoader
8+
import numpy as np
9+
from args import get_args
10+
import random
11+
import os
12+
from youcook_dataloader import Youcook_DataLoader
13+
from model import Net
14+
from metrics import compute_metrics, print_computed_metrics
15+
from gensim.models.keyedvectors import KeyedVectors
16+
import pickle
17+
import glob
18+
from lsmdc_dataloader import LSMDC_DataLoader
19+
from msrvtt_dataloader import MSRVTT_DataLoader
20+
21+
22+
args = get_args()
if args.verbose:
    print(args)

# A glob pattern of checkpoints to evaluate is required (see loop below).
assert args.pretrain_path != '', 'Need to specify pretrain_path argument'

# predefining random initial seeds
th.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

# Word embeddings are shared by every evaluation dataloader below.
print('Loading word vectors: {}'.format(args.word2vec_path))
we = KeyedVectors.load_word2vec_format(args.word2vec_path, binary=True)
print('done')


if args.eval_youcook:
    dataset_val = Youcook_DataLoader(
        data=args.youcook_val_path,
        we=we,
        max_words=args.max_words,
        we_dim=args.we_dim,
        n_pair=1,
    )
    dataloader_val = DataLoader(
        dataset_val,
        batch_size=args.batch_size_val,
        num_workers=args.num_thread_reader,
        shuffle=False,
    )
if args.eval_lsmdc:
    dataset_lsmdc = LSMDC_DataLoader(
        csv_path=args.lsmdc_test_csv_path,
        features_path=args.lsmdc_test_features_path,
        we=we,
        max_words=args.max_words,
        we_dim=args.we_dim,
    )
    dataloader_lsmdc = DataLoader(
        dataset_lsmdc,
        batch_size=args.batch_size_val,
        num_workers=args.num_thread_reader,
        shuffle=False,
    )
if args.eval_msrvtt:
    msrvtt_testset = MSRVTT_DataLoader(
        # Fix: use the CLI argument (its default is this exact path) instead
        # of a hard-coded copy, so --msrvtt_test_csv_path actually takes effect.
        csv_path=args.msrvtt_test_csv_path,
        features_path=args.msrvtt_test_features_path,
        we=we,
        max_words=args.max_words,
        we_dim=args.we_dim,
    )
    dataloader_msrvtt = DataLoader(
        msrvtt_testset,
        # NOTE(review): hard-coded batch size, presumably sized to cover the
        # whole MSRVTT test set in one batch — confirm before changing.
        batch_size=3000,
        num_workers=args.num_thread_reader,
        shuffle=False,
        drop_last=False,
    )
net = Net(
    video_dim=args.feature_dim,
    embd_dim=args.embd_dim,
    we_dim=args.we_dim,
    max_words=args.max_words,
)
net.eval()
# Optimizers + Loss
if args.gpu_mode:
    net.cuda()

if args.verbose:
    print('Starting evaluation loop ...')
95+
def Eval_msrvtt(model, eval_dataloader):
    """Run text-video retrieval evaluation on MSRVTT and print the metrics.

    model: the retrieval network; called as model(video, text) and expected
        to return a text-by-video similarity matrix.
    eval_dataloader: DataLoader yielding dicts with 'text' and 'video'
        tensors (moved to GPU when args.gpu_mode is set).
    """
    model.eval()
    print('Evaluating Text-Video retrieval on MSRVTT data')
    with th.no_grad():
        for i_batch, data in enumerate(eval_dataloader):
            text = data['text'].cuda() if args.gpu_mode else data['text']
            video = data['video'].cuda() if args.gpu_mode else data['video']
            # similarity matrix between every caption and every video clip
            m = model(video, text)
            m = m.cpu().detach().numpy()
            metrics = compute_metrics(m)
            print_computed_metrics(metrics)
107+
108+
def Eval_lsmdc(model, eval_dataloader):
    """Run text-video retrieval evaluation on LSMDC and print the metrics.

    model: the retrieval network; called as model(video, text) and expected
        to return a text-by-video similarity matrix.
    eval_dataloader: DataLoader yielding dicts with 'text' and 'video'
        tensors (moved to GPU when args.gpu_mode is set).
    """
    model.eval()
    print('Evaluating Text-Video retrieval on LSMDC data')
    with th.no_grad():
        for i_batch, data in enumerate(eval_dataloader):
            text = data['text'].cuda() if args.gpu_mode else data['text']
            video = data['video'].cuda() if args.gpu_mode else data['video']
            # similarity matrix between every caption and every video clip
            m = model(video, text)
            m = m.cpu().detach().numpy()
            metrics = compute_metrics(m)
            print_computed_metrics(metrics)
120+
121+
122+
def Eval_youcook(model, eval_dataloader):
    """Run text-video retrieval evaluation on YouCook2 and print the metrics.

    model: the retrieval network; called as model(video, text) and expected
        to return a text-by-video similarity matrix.
    eval_dataloader: DataLoader yielding dicts with 'text' and 'video'
        tensors (moved to GPU when args.gpu_mode is set).
    """
    model.eval()
    print('Evaluating Text-Video retrieval on Youcook data')
    with th.no_grad():
        for i_batch, data in enumerate(eval_dataloader):
            text = data['text'].cuda() if args.gpu_mode else data['text']
            video = data['video'].cuda() if args.gpu_mode else data['video']
            # similarity matrix between every caption and every video clip
            m = model(video, text)
            m = m.cpu().detach().numpy()
            metrics = compute_metrics(m)
            print_computed_metrics(metrics)
134+
135+
# --pretrain_path is treated as a glob pattern, so one run can evaluate a
# whole directory of checkpoints (e.g. 'ckpt/*.pth') as well as a single file.
all_checkpoints = glob.glob(args.pretrain_path)

for c in all_checkpoints:
    print('Eval checkpoint: {}'.format(c))
    print('Loading checkpoint: {}'.format(c))
    # load_checkpoint is defined on the project's Net model (see model.py)
    net.load_checkpoint(c)
    # Run only the evaluations requested on the command line; the matching
    # dataloaders were built above under the same flags.
    if args.eval_youcook:
        Eval_youcook(net, dataloader_val)
    if args.eval_msrvtt:
        Eval_msrvtt(net, dataloader_msrvtt)
    if args.eval_lsmdc:
        Eval_lsmdc(net, dataloader_lsmdc)

loss.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import unicode_literals
4+
from __future__ import print_function
5+
6+
import torch.nn.functional as F
7+
import torch as th
8+
import numpy as np
9+
10+
class MaxMarginRankingLoss(th.nn.Module):
    """Max-margin ranking loss over a pairwise similarity matrix.

    The input to forward() is a square score matrix whose diagonal holds the
    scores of the positive (matching) pairs; every off-diagonal entry is a
    negative. Each positive is pushed above the negatives in its row and
    column by at least `margin`.
    """

    def __init__(self,
                 margin=1.0,
                 negative_weighting=False,
                 batch_size=1,
                 n_pair=1,
                 hard_negative_rate=0.5,
                 ):
        """Create the loss.

        margin: required score gap between positives and negatives.
        negative_weighting: if truthy, reweight negatives so that clips from
            the same video ("hard" negatives) count more than clips from
            other videos.
        batch_size: number of distinct videos per batch (used only to build
            the weighting mask).
        n_pair: number of clip/caption pairs sampled per video.
        hard_negative_rate: fraction of the negative weight assigned to
            same-video negatives.
        """
        super(MaxMarginRankingLoss, self).__init__()
        self.margin = margin
        self.n_pair = n_pair
        self.batch_size = batch_size
        easy_negative_rate = 1 - hard_negative_rate
        self.easy_negative_rate = easy_negative_rate
        self.negative_weighting = negative_weighting
        # Precompute the weighting mask only when it is well defined. With
        # n_pair == 1 there are no same-video negatives to reweight (the
        # original code crashed in forward() with AttributeError because
        # self.mm_mask was never created), and batch_size == 1 would divide
        # by zero in `alpha` below.
        self.mm_mask = None
        if n_pair > 1 and batch_size > 1:
            alpha = easy_negative_rate / ((batch_size - 1) * (1 - easy_negative_rate))
            mm_mask = (1 - alpha) * np.eye(self.batch_size) + alpha
            # Expand the per-video mask to per-clip resolution.
            mm_mask = np.kron(mm_mask, np.ones((n_pair, n_pair)))
            mm_mask = th.tensor(mm_mask) * (batch_size * (1 - easy_negative_rate))
            mm_mask = mm_mask.float()
            # Keep the original GPU placement, but do not crash on CPU-only hosts.
            self.mm_mask = mm_mask.cuda() if th.cuda.is_available() else mm_mask

    def forward(self, x):
        """Return the mean max-margin ranking loss of score matrix `x`."""
        d = th.diag(x)
        # Row term: positive vs. other columns; column term: positive vs.
        # other rows. The diagonal contributes 2 * relu(margin) in both.
        max_margin = F.relu(self.margin + x - d.view(-1, 1)) + \
            F.relu(self.margin + x - d.view(1, -1))
        if self.negative_weighting and self.mm_mask is not None:
            max_margin = max_margin * self.mm_mask
        return max_margin.mean()

0 commit comments

Comments
 (0)