Skip to content

Commit 47582b5

Browse files
committed
eval ChatGPT
1 parent 252668a commit 47582b5

16 files changed

+4848
-1141
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
33
.sqlite
44
eval/data/database
55
.DS_Store
6+
.env
67
models/*

eval/evaluation.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,8 @@ def evaluate(gold, predict, db_dir, etype, kmaps, plug_value, keep_distinct, pro
616616
turn_scores['exec'].append(1)
617617
else:
618618
turn_scores['exec'].append(0)
619-
incorrect_log_file.write(f"exec_score: {exec_score}\n")
619+
incorrect_log_file.write(f"index: {len(turn_scores['exec'])}\n")
620+
incorrect_log_file.write(f"db_id: {db_name}\n") # write to the log file
620621
incorrect_log_file.write("{} pred: {}\n".format(hardness, p_str)) # write to the log file
621622
incorrect_log_file.write("{} gold: {}\n\n".format(hardness, g_str)) # write to the log file
622623

eval/gen_predictions_chatgpt.ipynb

+140
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
{
2+
"cells": [
3+
{
4+
"attachments": {},
5+
"cell_type": "markdown",
6+
"metadata": {},
7+
"source": [
8+
"### Eval on Spider"
9+
]
10+
},
11+
{
12+
"cell_type": "code",
13+
"execution_count": 3,
14+
"metadata": {},
15+
"outputs": [
16+
{
17+
"name": "stderr",
18+
"output_type": "stream",
19+
"text": [
20+
"Found cached dataset json (/Users/richardroberson/.cache/huggingface/datasets/richardr1126___json/richardr1126--spider-context-validation-166f6f4d7e17c532/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)\n"
21+
]
22+
},
23+
{
24+
"data": {
25+
"application/vnd.jupyter.widget-view+json": {
26+
"model_id": "f6540791f36b442d8dbfe5d0c77f02f2",
27+
"version_major": 2,
28+
"version_minor": 0
29+
},
30+
"text/plain": [
31+
"Generating responses: 0%| | 0/418 [00:00<?, ?it/s]"
32+
]
33+
},
34+
"metadata": {},
35+
"output_type": "display_data"
36+
},
37+
{
38+
"name": "stdout",
39+
"output_type": "stream",
40+
"text": [
41+
"Error occurred: The server is overloaded or not ready yet.\n",
42+
"Waiting for 5 seconds before retrying...\n",
43+
"Error occurred: The server is overloaded or not ready yet.\n",
44+
"Waiting for 5 seconds before retrying...\n",
45+
"Error occurred: The server is overloaded or not ready yet.\n",
46+
"Waiting for 5 seconds before retrying...\n",
47+
"Error occurred: The server is overloaded or not ready yet.\n",
48+
"Waiting for 5 seconds before retrying...\n",
49+
"Error occurred: The server is overloaded or not ready yet.\n",
50+
"Waiting for 5 seconds before retrying...\n",
51+
"Error occurred: The server is overloaded or not ready yet.\n",
52+
"Waiting for 5 seconds before retrying...\n",
53+
"Error occurred: The server is overloaded or not ready yet.\n",
54+
"Waiting for 5 seconds before retrying...\n",
55+
"Error occurred: The server is overloaded or not ready yet.\n",
56+
"Waiting for 5 seconds before retrying...\n",
57+
"Error occurred: The server is overloaded or not ready yet.\n",
58+
"Waiting for 5 seconds before retrying...\n",
59+
"Error occurred: The server is overloaded or not ready yet.\n",
60+
"Waiting for 5 seconds before retrying...\n",
61+
"Error occurred: The server is overloaded or not ready yet.\n",
62+
"Waiting for 5 seconds before retrying...\n"
63+
]
64+
}
65+
],
66+
"source": [
67+
"import openai\n",
68+
"import time\n",
69+
"from datasets import load_dataset\n",
70+
"from tqdm.notebook import tqdm\n",
71+
"import os\n",
72+
"from dotenv import load_dotenv\n",
73+
"\n",
74+
"# Load environment variables from .env file\n",
75+
"load_dotenv()\n",
76+
"\n",
77+
"openai.api_key = os.getenv('OPENAI_API_KEY')\n",
78+
"\n",
79+
"dataset = load_dataset(\"richardr1126/spider-context-validation\", split=\"validation\")\n",
80+
"last_line_written = 0\n",
81+
"\n",
82+
"for i in tqdm(range(last_line_written, len(dataset)), total=len(dataset)-last_line_written, desc=\"Generating responses\"):\n",
83+
" example = dataset[i]\n",
84+
" prompt = example[\"prompt\"]\n",
85+
" \n",
86+
" while True:\n",
87+
" try:\n",
88+
" response = openai.ChatCompletion.create(\n",
89+
" model=\"gpt-3.5-turbo\", # or other models available\n",
90+
" messages=[\n",
91+
" {\n",
92+
" \"role\": \"system\",\n",
93+
" \"content\": \"You are a sophisticated AI assistant capable of converting natural language queries into SQL queries. You'll be given database schema information with tables and columns, followed by a natural language question from the user. Your task is to generate the equivalent SQL query to answer the user's question. Only generate the SQL query, do not add anything other text.\"\n",
94+
" },\n",
95+
" {\"role\": \"user\", \"content\": prompt}\n",
96+
" ],\n",
97+
" )\n",
98+
" \n",
99+
" response_text = response['choices'][0]['message']['content'].strip().replace(\"\\n\", \" \").replace(\"\\t\", \" \")\n",
100+
" if response_text[-1] == \".\":\n",
101+
" response_text = response_text[:-1]\n",
102+
" \n",
103+
" with open('predictions/chatgpt.txt', 'a') as f:\n",
104+
" f.write(response_text + \"\\n\")\n",
105+
" \n",
106+
" # If we get to this line, it means the operation was successful and we break the while loop\n",
107+
" break\n",
108+
" \n",
109+
" except Exception as e:\n",
110+
" print(f'Error occurred: {str(e)}')\n",
111+
" #traceback.print_exc()\n",
112+
" print('Waiting for 5 seconds before retrying...')\n",
113+
" time.sleep(10) # Wait for 60 seconds before trying again\n",
114+
"\n"
115+
]
116+
}
117+
],
118+
"metadata": {
119+
"kernelspec": {
120+
"display_name": "llama",
121+
"language": "python",
122+
"name": "python3"
123+
},
124+
"language_info": {
125+
"codemirror_mode": {
126+
"name": "ipython",
127+
"version": 3
128+
},
129+
"file_extension": ".py",
130+
"mimetype": "text/x-python",
131+
"name": "python",
132+
"nbconvert_exporter": "python",
133+
"pygments_lexer": "ipython3",
134+
"version": "3.11.4"
135+
},
136+
"orig_nbformat": 4
137+
},
138+
"nbformat": 4,
139+
"nbformat_minor": 2
140+
}

eval/gen_predictions_hf_spaces.ipynb

+4-4
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
},
1111
{
1212
"cell_type": "code",
13-
"execution_count": 1,
13+
"execution_count": 2,
1414
"metadata": {},
1515
"outputs": [
1616
{
@@ -30,12 +30,12 @@
3030
{
3131
"data": {
3232
"application/vnd.jupyter.widget-view+json": {
33-
"model_id": "7c01e3c8d04d495fae30a81b3fb11f52",
33+
"model_id": "29726280a79640efaadadb9123b51436",
3434
"version_major": 2,
3535
"version_minor": 0
3636
},
3737
"text/plain": [
38-
"Generating responses: 0%| | 0/631 [00:00<?, ?it/s]"
38+
"Generating responses: 0%| | 0/87 [00:00<?, ?it/s]"
3939
]
4040
},
4141
"metadata": {},
@@ -51,7 +51,7 @@
5151
"client = Client(\"https://richardr1126-sql-skeleton-wizardcoder-demo.hf.space/\")\n",
5252
"\n",
5353
"dataset = load_dataset(\"richardr1126/spider-context-validation\", split=\"validation\")\n",
54-
"last_line_written = 403\n",
54+
"last_line_written = 0\n",
5555
"\n",
5656
"\n",
5757
"for i in tqdm(range(last_line_written, len(dataset)), total=len(dataset)-last_line_written, desc=\"Generating responses\"):\n",

eval/gen_predictions_hf_spaces2.ipynb

-104
This file was deleted.

0 commit comments

Comments
 (0)