diff --git a/download_model.py b/download_model.py
index 54e4bb60c..0bd4f26ab 100644
--- a/download_model.py
+++ b/download_model.py
@@ -1,28 +1,42 @@
+"""Downloads GPT-2 Model.
+
+Options:
+  python download_model.py 117M|124M|345M|774M|1558M
+"""
 import os
 import sys
 import requests
 from tqdm import tqdm
 
 if len(sys.argv) != 2:
-    print('You must enter the model name as a parameter, e.g.: download_model.py 124M')
-    sys.exit(1)
+  print('You must enter the model name as a parameter, e.g.: download_model.py '
+        '124M')
+  sys.exit(1)
 
 model = sys.argv[1]
 
 subdir = os.path.join('models', model)
 if not os.path.exists(subdir):
-    os.makedirs(subdir)
-subdir = subdir.replace('\\','/') # needed for Windows
-
-for filename in ['checkpoint','encoder.json','hparams.json','model.ckpt.data-00000-of-00001', 'model.ckpt.index', 'model.ckpt.meta', 'vocab.bpe']:
-
-    r = requests.get("https://openaipublic.blob.core.windows.net/gpt-2/" + subdir + "/" + filename, stream=True)
+  os.makedirs(subdir)
+subdir = subdir.replace('\\', '/')  # needed for Windows
 
-    with open(os.path.join(subdir, filename), 'wb') as f:
-        file_size = int(r.headers["content-length"])
-        chunk_size = 1000
-        with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar:
-            # 1k for chunk_size, since Ethernet packet size is around 1500 bytes
-            for chunk in r.iter_content(chunk_size=chunk_size):
-                f.write(chunk)
-                pbar.update(chunk_size)
+for filename in [
+    'checkpoint', 'encoder.json', 'hparams.json',
+    'model.ckpt.data-00000-of-00001', 'model.ckpt.index', 'model.ckpt.meta',
+    'vocab.bpe'
+]:
+  url = 'https://openaipublic.blob.core.windows.net/gpt-2/{}/{}'.format(
+      subdir, filename)
+  r = requests.get(url, stream=True)
+  with open(os.path.join(subdir, filename), 'wb') as f:
+    file_size = int(r.headers['content-length'])
+    chunk_size = 1000
+    with tqdm(
+        ncols=100,
+        desc='Fetching ' + filename,
+        total=file_size,
+        unit_scale=True) as pbar:
+      # 1k for chunk_size, since Ethernet packet size is around 1500 bytes.
+      for chunk in r.iter_content(chunk_size=chunk_size):
+        f.write(chunk)
+        pbar.update(chunk_size)
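
Note (not part of the patch): a minimal standalone sketch of the streaming-download pattern used by the rewritten loop, assuming only requests and tqdm are installed. The fetch() helper and the example invocation are illustrative names, not taken from the change above; updating the bar by len(chunk) is a small variation that keeps the count exact when the final chunk is shorter than chunk_size.

    import requests
    from tqdm import tqdm

    def fetch(url, dest, chunk_size=1000):
        # Illustrative helper, not part of the patch: stream the response so a
        # large checkpoint file is never held fully in memory.
        r = requests.get(url, stream=True)
        total = int(r.headers['content-length'])
        with open(dest, 'wb') as f, tqdm(
                ncols=100, desc='Fetching ' + dest, total=total,
                unit_scale=True) as pbar:
            for chunk in r.iter_content(chunk_size=chunk_size):
                f.write(chunk)
                pbar.update(len(chunk))  # exact even on the final, shorter chunk

    # Example (URL follows the gpt-2/{subdir}/{filename} pattern from the script):
    # fetch('https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe', 'vocab.bpe')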