We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 110266b commit 4e19fa9Copy full SHA for 4e19fa9
eval.py
@@ -76,10 +76,12 @@ def main(args):
76
77
checkpoint = torch.load(args.checkpoint , map_location=lambda storage, loc: storage)
78
config = checkpoint['config']
79
+ src_tok, target_tok = checkpoint['tokenizers'].values()
80
+
81
args.data_config = literal_eval(args.data_config)
82
dataset = getattr(datasets, args.dataset)
83
+ args.data_config['tokenizers'] = checkpoint['tokenizers']
84
val_data = dataset(args.dataset_dir, split='dev', **args.data_config)
- src_tok, target_tok = checkpoint['tokenizers'].values()
85
86
model = getattr(models, config.model)(**config.model_config)
87
model.load_state_dict(checkpoint['state_dict'])
0 commit comments