-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate.py
164 lines (130 loc) · 5.93 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import ast
import os.path
import pandas as pd
from langchain.schema.messages import HumanMessage
from langchain_community.chat_models import ChatOpenAI
from langchain_core.output_parsers import JsonOutputParser
from chains import create_concept_chain, create_augment_chain
from prompt import get_concept_prompt, get_aug_questions_prompt, concept_prompt_start_msg, concept_prompt_end_msg, \
sys_prompt_msg
from structure import Process, AugmentedQuestions
from utils import encode_image, dump_to_json
def generate(
        q: str = "竹籃裡有24顆蘋果,紅蘋果有6顆,其他是青蘋果,青蘋果有幾顆?",
        grade: str = "first",
        hint: str = "單元:減法的應用問題(99以內)。老師的話:其他是白湯圓,表示要求剩下來的湯圓",
        n_questions: int = 5
):
    """Extract the math concepts behind a text-only question, then augment
    each concept with extra practice questions and dump everything to JSON.

    Args:
        q: The source question text (Traditional Chinese math word problem).
        grade: School grade level passed to the concept prompt.
        hint: Unit name plus teacher's note, used as extra context.
        n_questions: How many augmented questions to request per concept.
    """
    concept_parser = JsonOutputParser(pydantic_object=Process)
    question_parser = JsonOutputParser(pydantic_object=AugmentedQuestions)
    llm = ChatOpenAI(
        temperature=0,
        model="gpt-4-vision-preview",
        max_tokens=1024,
    )
    # prompt -> model -> structured-output parser, one chain per task
    concept_chain = (
        get_concept_prompt(output_format=concept_parser.get_format_instructions())
        | llm
        | concept_parser
    )
    augment_chain = (
        get_aug_questions_prompt(output_format=question_parser.get_format_instructions())
        | llm
        | question_parser
    )
    result = concept_chain.invoke({"question": q, "grade": grade, "hint": hint})
    # Attach freshly generated practice questions to every extracted concept.
    for concept in result['concepts']:
        extra = augment_chain.invoke({"concept": concept, "n_questions": n_questions})
        concept['sample_questions'] += extra['questions']
    result['question'] = q
    dump_to_json(result)
    print(result)
def generate_w_images(
        q: str = "下圖中的虛線是對稱軸的話,請寫出編號。",
        grade: str = "fifth",
        hint: str = "單元:認識線對稱圖形和對稱軸。老師的話:沿著虛線摺摺看,摺線兩側可以使圖形,完全疊合,虛線就是對稱軸。",
        n_questions: int = 5,
        image_path: str = "data/images/ex5-1.jpg"
):
    """Same flow as :func:`generate`, but the question refers to an image.

    The image is base64-inlined into a multimodal ``HumanMessage``, so the
    concept step invokes the model on a raw message list instead of a
    prompt template.

    Args:
        q: Question text that references the attached figure.
        grade: School grade level passed to the concept prompt.
        hint: Unit name plus teacher's note, used as extra context.
        n_questions: How many augmented questions to request per concept.
        image_path: Path of the JPEG figure to inline.
    """
    concept_parser = JsonOutputParser(pydantic_object=Process)
    question_parser = JsonOutputParser(pydantic_object=AugmentedQuestions)
    llm = ChatOpenAI(
        temperature=0,
        model="gpt-4-vision-preview",
        max_tokens=1024,
    )
    # Getting the base64 string
    b64 = encode_image(image_path)
    # No prompt template here: the messages below are built by hand.
    concept_chain = llm | concept_parser
    augment_chain = (
        get_aug_questions_prompt(output_format=question_parser.get_format_instructions())
        | llm
        | question_parser
    )
    image_msg = HumanMessage(
        content=[
            {"type": "text", "text": "IMAGES:"},
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{b64}",
                    "detail": "auto",
                },
            },
        ]
    )
    messages = [
        sys_prompt_msg.format(),  # SystemMessage
        concept_prompt_start_msg.format(question=q, hint=hint),
        image_msg,
        concept_prompt_end_msg.format(
            format_instructions=concept_parser.get_format_instructions(),
            grade=grade,
        ),
    ]
    result = concept_chain.invoke(messages)
    # Attach freshly generated practice questions to every extracted concept.
    for concept in result['concepts']:
        extra = augment_chain.invoke({"concept": concept, "n_questions": n_questions})
        concept['sample_questions'] += extra['questions']
    result['question'] = q
    dump_to_json(result)
    print(result)
def generate_from_csv(
        csv_file: str = "data/question_list.csv",
        image_folder: str = "data/images",
        n_questions: int = 5,
):
    """Batch version: read questions from a CSV and process each row.

    Rows with images go through the multimodal (hand-built message) chain;
    text-only rows go through the prompt-template chain. Every concept is
    augmented with ``n_questions`` extra questions, and all row results are
    dumped to JSON as ``{'results': [...]}``.

    Expected CSV columns: ``unit``, ``grade``, ``question``, ``hint``,
    ``images`` (a Python-literal list of filenames relative to
    *image_folder*, e.g. ``"['ex5-1.jpg']"``).

    Args:
        csv_file: Path of the input CSV.
        image_folder: Directory that image filenames are relative to.
        n_questions: How many augmented questions to request per concept.
    """
    dtypes = {'unit': str, 'grade': str, 'question': str, 'hint': str, 'images': object}
    df = pd.read_csv(csv_file, dtype=dtypes)
    # Fill every optional column, not just hint/unit:
    #  - a NaN 'question' is a float and float('nan') is truthy, so the
    #    `if not q` guard below would NOT skip it and NaN would be sent to
    #    the chain;
    #  - ast.literal_eval raises on a NaN 'images' cell, so fill with '[]'
    #    (parses to an empty list -> text-only path, same as before).
    df.fillna({'question': '', 'hint': '', 'unit': '', 'images': '[]'}, inplace=True)
    question_list = df['question'].to_list()
    grade_list = df['grade'].to_list()
    images_list = df['images'].apply(ast.literal_eval).to_list()
    hint_list = [f"{u} {h}" for u, h in zip(df['unit'].to_list(), df['hint'].to_list())]
    chain_q = create_concept_chain()
    chain_img_q = create_concept_chain(prompt_template=False)  # raw-message variant
    aug_chain = create_augment_chain()
    parser = JsonOutputParser(pydantic_object=Process)
    results = []
    for q, image_files, grade, hint in zip(question_list, images_list, grade_list, hint_list):
        if not q:
            # Empty/missing question: record a stub row and move on.
            results.append({'question': q})
            continue
        if image_files:
            # Multimodal path: inline every image as a base64 data URL.
            msg = HumanMessage(content=[{"type": "text", "text": "IMAGES:"}])
            for img_f in image_files:
                # Getting the base64 string
                base64_image = encode_image(os.path.join(image_folder, img_f))
                msg.content.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image}",
                        "detail": "auto",
                    },
                })
            inputs = [
                sys_prompt_msg.format(),  # SystemMessage
                concept_prompt_start_msg.format(question=q, hint=hint),
                msg,
                concept_prompt_end_msg.format(
                    format_instructions=parser.get_format_instructions(),
                    grade=grade,
                ),
            ]
            result = chain_img_q.invoke(inputs)
        else:
            result = chain_q.invoke({"question": q, "grade": grade, "hint": hint})
        # Attach freshly generated practice questions to every concept.
        for c in result['concepts']:
            q_results = aug_chain.invoke({"concept": c, "n_questions": n_questions})
            c['sample_questions'] += q_results['questions']
        results.append({'question': q, 'concepts': result['concepts']})
    dump_to_json({'results': results})
if __name__ == '__main__':
    # Demo run of all three flows: text-only question, single-image
    # question, and the CSV batch mode (each uses its default arguments
    # and requires OpenAI API access plus the data/ files).
    generate()
    generate_w_images()
    generate_from_csv()