-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbert_test.py
53 lines (41 loc) · 1.78 KB
/
bert_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from transformers import pipeline
from transformers import AutoModelForMaskedLM, AutoTokenizer
def load_model_and_tokenizer(model_name="bert-base-uncased"):
    """Load a masked-language-model checkpoint and its tokenizer.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model. Defaults to ``"bert-base-uncased"``;
            pass another masked-LM checkpoint name to swap models.

    Returns:
        A ``(tokenizer, model)`` tuple, in that order.
    """
    # Both objects are loaded from the same checkpoint name so the
    # tokenizer's vocabulary matches the model.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    return tokenizer, model
def correct_spelling(misspelled_text, misspelled_words):
    """Replace flagged words in a sentence with the model's top mask guesses.

    Every whitespace-separated word of ``misspelled_text`` that is found in
    ``misspelled_words`` is swapped for ``[MASK]`` and the model is prompted
    with::

        misspelled: <original sentence> [SEP] spell corrected: <masked sentence>

    Relies on the module-level ``model`` and ``tokenizer`` objects.

    Args:
        misspelled_text: The sentence to correct.
        misspelled_words: Words to mask; matching is exact on
            whitespace-delimited tokens.

    Returns:
        The full prompt string (including the ``misspelled: ...`` prefix)
        with each ``[MASK]`` replaced by the top prediction for that
        position. When no word was masked the prompt is returned unchanged.

    Raises:
        ValueError: If the number of prediction groups returned by the
            pipeline does not match the number of ``[MASK]`` markers.
    """
    mask_token = "[MASK]"

    # Build the masked copy of the sentence.
    masked_sentence = " ".join(
        mask_token if word in misspelled_words else word
        for word in misspelled_text.split()
    )
    input_text = f"misspelled: {misspelled_text} [SEP] spell corrected: {masked_sentence}"

    # Text between the masks; N masks yield N + 1 segments.
    segments = input_text.split(mask_token)
    if len(segments) == 1:
        # Nothing was masked: the fill-mask pipeline would reject an input
        # without a mask token, so return before touching the model.
        return input_text

    fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)
    predictions = fill_mask(input_text)

    # A single mask yields a flat list of candidate dicts; several masks
    # yield one candidate list per mask. Normalise to the nested shape.
    if predictions and isinstance(predictions[0], dict):
        predictions = [predictions]

    if len(predictions) != len(segments) - 1:
        raise ValueError(
            f"expected predictions for {len(segments) - 1} mask(s), "
            f"got {len(predictions)}"
        )

    # Interleave the fixed text with the best candidate for each mask so
    # every mask receives its own prediction.
    pieces = [segments[0]]
    for candidates, tail in zip(predictions, segments[1:]):
        pieces.append(candidates[0]["token_str"])
        pieces.append(tail)
    return "".join(pieces)
# Load the tokenizer and model once; correct_spelling() reads both as
# module-level globals.
tokenizer, model = load_model_and_tokenizer()

# Example sentence together with the word to be treated as misspelled.
misspelled_text = "I went to the shop because my laptop changing cable was broken"
misspelled_words = ["changing"]

# Run the correction and show the input next to the result.
corrected_text = correct_spelling(misspelled_text, misspelled_words)

print("Model input:\n", misspelled_text, "\n")
print("Model output:\n", corrected_text)