
Commit 556c00d

Fix optimizer for Reader training (#307)
* fix optimizer for Reader training
* bump version
* tmp fix for lr schedule
* fixed train and updated tuto
1 parent 4dd0ff4 commit 556c00d

File tree

4 files changed (+149, -23 lines changed)


cdqa/reader/bertqa_sklearn.py

+31-13
@@ -31,14 +31,15 @@
 import torch
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
 from torch.utils.data.distributed import DistributedSampler
+from torch.optim.lr_scheduler import LambdaLR
 from tqdm.autonotebook import tqdm, trange
 
 from transformers import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
 
 from transformers import BertForQuestionAnswering, DistilBertForQuestionAnswering
 from transformers import BertConfig, DistilBertConfig
 from transformers import BertTokenizer, DistilBertTokenizer
-from transformers import AdamW, WarmupLinearSchedule
+from transformers import AdamW
 from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
 
 from sklearn.base import BaseEstimator, TransformerMixin
@@ -934,6 +935,16 @@ def _n_best_predictions(final_predictions_sorted, n):
         final_prediction_list.append(curr_pred)
     return final_prediction_list
 
+def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
+    """ Create a schedule with a learning rate that decreases linearly after
+    linearly increasing during a warmup period.
+    """
+    def lr_lambda(current_step):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
 
 class BertProcessor(BaseEstimator, TransformerMixin):
     """
@@ -1052,6 +1063,10 @@ class BertQA(BaseEstimator):
     warmup_proportion : float, optional
         Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%%
         of training. (the default is 0.1)
+    warmup_steps : int, optional
+        Linear warmup over warmup_steps.
+    adam_epsilon : float
+        Epsilon for Adam optimizer. (default: 1e-8)
     n_best_size : int, optional
         The total number of n-best predictions to generate in the nbest_predictions.json
         output file. (the default is 20)
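Both new arguments surface directly on the estimator. A hedged usage sketch; the values here are illustrative, not defaults from the commit:

```python
from cdqa.reader.bertqa_sklearn import BertQA

reader = BertQA(
    learning_rate=3e-5,
    warmup_steps=100,   # linear LR warmup over the first 100 optimizer steps
    adam_epsilon=1e-8,  # forwarded to AdamW's eps argument
)
```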
@@ -1122,6 +1137,8 @@ def __init__(
         learning_rate=5e-5,
         num_train_epochs=3.0,
         warmup_proportion=0.1,
+        warmup_steps=0,
+        adam_epsilon=1e-8,
         n_best_size=20,
         max_answer_length=30,
         verbose_logging=False,
@@ -1145,6 +1162,8 @@ def __init__(
         self.learning_rate = learning_rate
         self.num_train_epochs = num_train_epochs
         self.warmup_proportion = warmup_proportion
+        self.warmup_steps = warmup_steps
+        self.adam_epsilon = adam_epsilon
         self.n_best_size = n_best_size
         self.max_answer_length = max_answer_length
         self.verbose_logging = verbose_logging
@@ -1344,12 +1363,8 @@ def fit(self, X, y=None):
                 warmup=self.warmup_proportion, t_total=num_train_optimization_steps
             )
         else:
-            optimizer = BertAdam(
-                optimizer_grouped_parameters,
-                lr=self.learning_rate,
-                warmup=self.warmup_proportion,
-                t_total=num_train_optimization_steps,
-            )
+            optimizer = AdamW(optimizer_grouped_parameters, lr=self.learning_rate, eps=self.adam_epsilon)
+            scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.warmup_steps, num_training_steps=num_train_optimization_steps)
 
         self.model.train()
         for _ in trange(int(self.num_train_epochs), desc="Epoch"):
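`optimizer_grouped_parameters` is built earlier in `fit` and is not shown in this hunk. A sketch of the overall wiring with the transformers releases this commit targets (>=2.1.1), assuming the usual transformers-style grouping that exempts biases and LayerNorm weights from weight decay; the grouping and the toy model are assumptions, only the `AdamW(..., eps=...)` call mirrors the commit:

```python
import torch
from transformers import AdamW  # transformers' AdamW, as imported at the top of the file

model = torch.nn.Linear(8, 2)   # stand-in for the BERT reader model
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {"params": [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5, eps=1e-8)
```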
@@ -1364,12 +1379,14 @@ def fit(self, X, y=None):
                 batch = tuple(
                     t.to(self.device) for t in batch
                 )  # multi-gpu does scattering it-self
-                input_ids, input_mask, segment_ids, start_positions, end_positions = (
-                    batch
-                )
-                loss = self.model(
-                    input_ids, segment_ids, input_mask, start_positions, end_positions
-                )
+                inputs = {'input_ids': batch[0],
+                          'attention_mask': batch[1],
+                          'start_positions': batch[3],
+                          'end_positions': batch[4]}
+                if 'distilbert' not in self.bert_model:
+                    inputs['token_type_ids'] = batch[2]
+                outputs = self.model(**inputs)
+                loss = outputs[0]
                 if self.n_gpu > 1:
                     loss = loss.mean()  # mean() to average on multi-gpu.
                 if self.gradient_accumulation_steps > 1:
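Switching to keyword arguments is what lets the same loop drive both model families: DistilBERT has no segment embeddings, so its forward does not accept `token_type_ids`. A runnable sketch of the same call pattern (downloads pretrained weights on first run; the question/context strings and dummy answer span are made up):

```python
import torch
from transformers import DistilBertForQuestionAnswering, DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")

input_ids = torch.tensor([tokenizer.encode("Who maintains cdqa?", "cdqa is maintained by its contributors.")])
inputs = {
    "input_ids": input_ids,
    "attention_mask": torch.ones_like(input_ids),
    "start_positions": torch.tensor([0]),  # dummy gold span, just to get a loss
    "end_positions": torch.tensor([0]),
}
# For a BERT model we would additionally pass inputs["token_type_ids"].
outputs = model(**inputs)
loss = outputs[0]  # the loss comes first when start/end positions are supplied
```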
@@ -1389,6 +1406,7 @@ def fit(self, X, y=None):
                     for param_group in optimizer.param_groups:
                         param_group["lr"] = lr_this_step
                 optimizer.step()
+                scheduler.step()  # Update learning rate schedule
                 optimizer.zero_grad()
                 global_step += 1
 
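With a native PyTorch scheduler in place, the per-step learning-rate bookkeeping reduces to the standard three-call sequence; the manual `param_group["lr"]` rewrite above remains only for the fp16 path. A minimal sketch; the model and the decay schedule here are placeholders:

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(8, 2)    # stand-in for the reader model
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
scheduler = LambdaLR(optimizer, lambda step: max(0.0, 1.0 - step / 100.0))

for _ in range(3):               # toy training steps
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()
    scheduler.step()             # update learning rate schedule
    optimizer.zero_grad()
```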

examples/tutorial-train-reader-squad.ipynb

+116-8
@@ -35,10 +35,9 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "/home/supercalculateur/source/andre/cdqa-dev/env-cdqa/lib/python3.6/site-packages/sklearn/externals/joblib/__init__.py:15: DeprecationWarning: sklearn.externals.joblib is deprecated in 0.21 and will be removed in 0.23. Please import this functionality directly from joblib, which can be installed with: pip install joblib. If this warning is raised when loading pickled models, you may need to re-serialize those models with scikit-learn 0.21+.\n",
-      "  warnings.warn(msg, category=DeprecationWarning)\n",
       "/home/supercalculateur/source/andre/cdqa-dev/env-cdqa/lib/python3.6/site-packages/tqdm/autonotebook/__init__.py:18: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
-      "  \" (e.g. in jupyter console)\", TqdmExperimentalWarning)\n"
+      "  \" (e.g. in jupyter console)\", TqdmExperimentalWarning)\n",
+      "I1120 11:43:47.615704 140657575868224 file_utils.py:39] PyTorch version 1.2.0 available.\n"
      ]
     }
    ],
@@ -99,11 +98,20 @@
     "ExecuteTime": {
      "end_time": "2019-07-20T13:58:36.512980Z",
      "start_time": "2019-07-20T13:46:44.792080Z"
-    }
+    },
+    "collapsed": true
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "I1120 11:43:48.194295 140657575868224 tokenization_utils.py:375] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/supercalculateur/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n"
+     ]
+    }
+   ],
    "source": [
-    "train_processor = BertProcessor(do_lower_case=True, is_training=True, n_jobs=-1)\n",
+    "train_processor = BertProcessor(do_lower_case=True, is_training=True)\n",
     "train_examples, train_features = train_processor.fit_transform(X='./data/SQuAD_1.1/train-v1.1.json')"
    ]
   },
@@ -116,9 +124,109 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 4,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "I1120 11:43:53.164162 140657575868224 configuration_utils.py:152] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /home/supercalculateur/.cache/torch/transformers/distributed_-1/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.bf3b9ea126d8c0001ee8a1e8b92229871d06d36d8808208cc2449280da87785c\n",
+      "I1120 11:43:53.165523 140657575868224 configuration_utils.py:169] Model config {\n",
+      "  \"attention_probs_dropout_prob\": 0.1,\n",
+      "  \"finetuning_task\": null,\n",
+      "  \"hidden_act\": \"gelu\",\n",
+      "  \"hidden_dropout_prob\": 0.1,\n",
+      "  \"hidden_size\": 768,\n",
+      "  \"initializer_range\": 0.02,\n",
+      "  \"intermediate_size\": 3072,\n",
+      "  \"is_decoder\": false,\n",
+      "  \"layer_norm_eps\": 1e-12,\n",
+      "  \"max_position_embeddings\": 512,\n",
+      "  \"num_attention_heads\": 12,\n",
+      "  \"num_hidden_layers\": 12,\n",
+      "  \"num_labels\": 2,\n",
+      "  \"output_attentions\": false,\n",
+      "  \"output_hidden_states\": false,\n",
+      "  \"output_past\": true,\n",
+      "  \"pruned_heads\": {},\n",
+      "  \"torchscript\": false,\n",
+      "  \"type_vocab_size\": 2,\n",
+      "  \"use_bfloat16\": false,\n",
+      "  \"vocab_size\": 30522\n",
+      "}\n",
+      "\n",
+      "I1120 11:43:53.591548 140657575868224 modeling_utils.py:383] loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin from cache at /home/supercalculateur/.cache/torch/transformers/distributed_-1/aa1ef1aede4482d0dbcd4d52baad8ae300e60902e88fcb0bebdec09afd232066.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157\n",
+      "I1120 11:43:55.430284 140657575868224 modeling_utils.py:453] Weights of BertForQuestionAnswering not initialized from pretrained model: ['qa_outputs.weight', 'qa_outputs.bias']\n",
+      "I1120 11:43:55.431005 140657575868224 modeling_utils.py:456] Weights from pretrained model not used in BertForQuestionAnswering: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9fb44d7dc6854474a6dc36ea50168573",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, description='Epoch', max=2, style=ProgressStyle(description_width='initial…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f50802fcf50043f1bfa008c9b911d3df",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, description='Iteration', max=4, style=ProgressStyle(description_width='ini…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c8eaa69941804829bc7c2c984487f7d2",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, description='Iteration', max=4, style=ProgressStyle(description_width='ini…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "BertQA(adam_epsilon=1e-08, bert_model='bert-base-uncased', do_lower_case=True,\n",
+       "       fp16=False, gradient_accumulation_steps=1, learning_rate=3e-05,\n",
+       "       local_rank=-1, loss_scale=0, max_answer_length=30, n_best_size=20,\n",
+       "       no_cuda=False, null_score_diff_threshold=0.0, num_train_epochs=2,\n",
+       "       output_dir='models', predict_batch_size=8, seed=42, server_ip='',\n",
+       "       server_port='', train_batch_size=12, verbose_logging=False,\n",
+       "       version_2_with_negative=False, warmup_proportion=0.1, warmup_steps=0)"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "reader = BertQA(train_batch_size=12,\n",
     "                learning_rate=3e-5,\n",

requirements.txt

+1-1
@@ -3,7 +3,7 @@ flask_cors==3.0.8
 joblib==0.13.2
 pandas==0.25.0
 prettytable==0.7.2
-transformers==2.1.1
+transformers>=2.1.1
 scikit_learn==0.21.2
 tika==1.19
 torch>=1.2.0

setup.py

+1-1
@@ -8,7 +8,7 @@ def read(file):
 
 setup(
     name="cdqa",
-    version="1.3.6",
+    version="1.3.7",
     author="Félix MIKAELIAN, André FARIAS, Matyas AMROUCHE, Olivier SANS, Théo NAZON",
     description="An End-To-End Closed Domain Question Answering System",
     long_description=read("README.md"),
