diff --git a/5-2.BERT/BERT.py b/5-2.BERT/BERT.py index 1a7f625..b4c0c9c 100644 --- a/5-2.BERT/BERT.py +++ b/5-2.BERT/BERT.py @@ -209,7 +209,10 @@ def forward(self, input_ids, segment_ids, masked_pos): optimizer = optim.Adam(model.parameters(), lr=0.001) batch = make_batch() - input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch)) + if torch.cuda.is_available(): + input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.cuda.LongTensor, zip(*batch)) + else: + input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch)) for epoch in range(100): optimizer.zero_grad() @@ -224,7 +227,10 @@ def forward(self, input_ids, segment_ids, masked_pos): optimizer.step() # Predict mask tokens ans isNext - input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(batch[0])) + if torch.cuda.is_available(): + input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.cuda.LongTensor, zip(batch[0])) + else: + input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(batch[0])) print(text) print([number_dict[w.item()] for w in input_ids[0] if number_dict[w.item()] != '[PAD]'])