Skip to content

Commit 72aa233

Browse files
committed
Update
1 parent a671c9a commit 72aa233

9 files changed

+66680
-80
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
/database
22
.sql
33
.sqlite
4-
eval/data/database
4+
eval/data/database
5+
.DS_Store

data/check_tokens.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import json


def estimate_token_count(text):
    """Estimate the token count of ``text``.

    Uses the rough heuristic of ~3.6 characters per token, rounded to the
    nearest integer via ``round`` (note: the original comment claimed
    "rounding up", but the code has always rounded to nearest).
    """
    return round(len(text) / 3.6)


def summarize_token_counts(data):
    """Compute token-count statistics over the ``'text'`` field of ``data``.

    Args:
        data: a non-empty list of dicts, each with a ``'text'`` string key.

    Returns:
        A tuple ``(longest_token_count, average_token_count, longest_text)``.

    Raises:
        ValueError: if ``data`` is empty (instead of a ZeroDivisionError).
    """
    if not data:
        raise ValueError("dataset is empty")

    longest_token_count = 0
    total_token_count = 0
    longest_text = ""

    for item in data:
        # NOTE(review): the original comment called this the 'train_instruct'
        # field, but the key actually read is 'text'.
        text = item['text']

        token_count = estimate_token_count(text)
        total_token_count += token_count

        # Track the single longest entry seen so far.
        if token_count > longest_token_count:
            longest_token_count = token_count
            longest_text = text

    average_token_count = total_token_count / len(data)
    return longest_token_count, average_token_count, longest_text


def main():
    """Load the skeleton training set and print token-count statistics."""
    # Load the json data
    with open('train_sql_skeleton.json') as f:
        data = json.load(f)

    longest_token_count, average_token_count, longest_text = summarize_token_counts(data)

    print(f"The longest token count for 'train_instruct' in the dataset is {longest_token_count}.")
    print(f"The average token count for 'train_instruct' in the dataset is {average_token_count}.")
    print(f"The longest text for 'train_instruct' in the dataset is {longest_text}.")


if __name__ == "__main__":
    main()

data/train_sql.json

+28,002
Large diffs are not rendered by default.

data/train_sql_skeleton.json

+28,002
Large diffs are not rendered by default.

data/validation_sql.json

+5,172
Large diffs are not rendered by default.

data/validation_sql_skeleton.json

+5,172
Large diffs are not rendered by default.

eval/generate_predict_eval.ipynb

+9-17
Original file line numberDiff line numberDiff line change
@@ -10,47 +10,39 @@
1010
},
1111
{
1212
"cell_type": "code",
13-
"execution_count": 10,
13+
"execution_count": 3,
1414
"metadata": {},
1515
"outputs": [
1616
{
1717
"name": "stderr",
1818
"output_type": "stream",
1919
"text": [
20-
"Found cached dataset json (/Users/richardroberson/.cache/huggingface/datasets/richardr1126___json/richardr1126--spider-natsql-context-validation-b246ae2fc7e9e5cb/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)\n"
20+
"Found cached dataset json (/Users/richardroberson/.cache/huggingface/datasets/richardr1126___json/richardr1126--spider-context-validation-8fba68e4e3727374/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)\n"
2121
]
2222
},
2323
{
2424
"data": {
2525
"application/vnd.jupyter.widget-view+json": {
26-
"model_id": "642df2eb63a3414782afd0b0074db10a",
26+
"model_id": "3f604c3cd41a4f3ba02e0405f22bb98d",
2727
"version_major": 2,
2828
"version_minor": 0
2929
},
3030
"text/plain": [
31-
"Generating responses: 0%| | 0/498 [00:00<?, ?it/s]"
31+
"Generating responses: 0%| | 0/1027 [00:00<?, ?it/s]"
3232
]
3333
},
3434
"metadata": {},
3535
"output_type": "display_data"
3636
},
3737
{
38-
"ename": "JSONDecodeError",
39-
"evalue": "Expecting value: line 1 column 1 (char 0)",
38+
"ename": "KeyError",
39+
"evalue": "'results'",
4040
"output_type": "error",
4141
"traceback": [
4242
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
43-
"\u001b[0;31mJSONDecodeError\u001b[0m Traceback (most recent call last)",
44-
"File \u001b[0;32m~/miniforge3/envs/llm/lib/python3.11/site-packages/requests/models.py:971\u001b[0m, in \u001b[0;36mResponse.json\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 970\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 971\u001b[0m \u001b[39mreturn\u001b[39;00m complexjson\u001b[39m.\u001b[39;49mloads(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtext, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 972\u001b[0m \u001b[39mexcept\u001b[39;00m JSONDecodeError \u001b[39mas\u001b[39;00m e:\n\u001b[1;32m 973\u001b[0m \u001b[39m# Catch JSON-related errors and raise as requests.JSONDecodeError\u001b[39;00m\n\u001b[1;32m 974\u001b[0m \u001b[39m# This aliases json.JSONDecodeError and simplejson.JSONDecodeError\u001b[39;00m\n",
45-
"File \u001b[0;32m~/miniforge3/envs/llm/lib/python3.11/json/__init__.py:346\u001b[0m, in \u001b[0;36mloads\u001b[0;34m(s, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)\u001b[0m\n\u001b[1;32m 343\u001b[0m \u001b[39mif\u001b[39;00m (\u001b[39mcls\u001b[39m \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m object_hook \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m\n\u001b[1;32m 344\u001b[0m parse_int \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m parse_float \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m\n\u001b[1;32m 345\u001b[0m parse_constant \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m object_pairs_hook \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \u001b[39mnot\u001b[39;00m kw):\n\u001b[0;32m--> 346\u001b[0m \u001b[39mreturn\u001b[39;00m _default_decoder\u001b[39m.\u001b[39;49mdecode(s)\n\u001b[1;32m 347\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mcls\u001b[39m \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n",
46-
"File \u001b[0;32m~/miniforge3/envs/llm/lib/python3.11/json/decoder.py:337\u001b[0m, in \u001b[0;36mJSONDecoder.decode\u001b[0;34m(self, s, _w)\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"Return the Python representation of ``s`` (a ``str`` instance\u001b[39;00m\n\u001b[1;32m 334\u001b[0m \u001b[39mcontaining a JSON document).\u001b[39;00m\n\u001b[1;32m 335\u001b[0m \n\u001b[1;32m 336\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[0;32m--> 337\u001b[0m obj, end \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mraw_decode(s, idx\u001b[39m=\u001b[39;49m_w(s, \u001b[39m0\u001b[39;49m)\u001b[39m.\u001b[39;49mend())\n\u001b[1;32m 338\u001b[0m end \u001b[39m=\u001b[39m _w(s, end)\u001b[39m.\u001b[39mend()\n",
47-
"File \u001b[0;32m~/miniforge3/envs/llm/lib/python3.11/json/decoder.py:355\u001b[0m, in \u001b[0;36mJSONDecoder.raw_decode\u001b[0;34m(self, s, idx)\u001b[0m\n\u001b[1;32m 354\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mStopIteration\u001b[39;00m \u001b[39mas\u001b[39;00m err:\n\u001b[0;32m--> 355\u001b[0m \u001b[39mraise\u001b[39;00m JSONDecodeError(\u001b[39m\"\u001b[39m\u001b[39mExpecting value\u001b[39m\u001b[39m\"\u001b[39m, s, err\u001b[39m.\u001b[39mvalue) \u001b[39mfrom\u001b[39;00m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m 356\u001b[0m \u001b[39mreturn\u001b[39;00m obj, end\n",
48-
"\u001b[0;31mJSONDecodeError\u001b[0m: Expecting value: line 1 column 1 (char 0)",
49-
"\nDuring handling of the above exception, another exception occurred:\n",
50-
"\u001b[0;31mJSONDecodeError\u001b[0m Traceback (most recent call last)",
51-
"Cell \u001b[0;32mIn[10], line 30\u001b[0m\n\u001b[1;32m 27\u001b[0m headers \u001b[39m=\u001b[39m {\u001b[39m\"\u001b[39m\u001b[39mContent-Type\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m\"\u001b[39m\u001b[39mapplication/json\u001b[39m\u001b[39m\"\u001b[39m}\n\u001b[1;32m 29\u001b[0m response \u001b[39m=\u001b[39m requests\u001b[39m.\u001b[39mpost(url, json\u001b[39m=\u001b[39mpayload, headers\u001b[39m=\u001b[39mheaders)\n\u001b[0;32m---> 30\u001b[0m response_text \u001b[39m=\u001b[39m response\u001b[39m.\u001b[39;49mjson()[\u001b[39m\"\u001b[39m\u001b[39mresults\u001b[39m\u001b[39m\"\u001b[39m][\u001b[39m0\u001b[39m][\u001b[39m\"\u001b[39m\u001b[39mtext\u001b[39m\u001b[39m\"\u001b[39m]\n\u001b[1;32m 31\u001b[0m response_text \u001b[39m=\u001b[39m response_text\u001b[39m.\u001b[39mreplace(\u001b[39m\"\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 33\u001b[0m \u001b[39m# append the result to 'results.txt'\u001b[39;00m\n",
52-
"File \u001b[0;32m~/miniforge3/envs/llm/lib/python3.11/site-packages/requests/models.py:975\u001b[0m, in \u001b[0;36mResponse.json\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 971\u001b[0m \u001b[39mreturn\u001b[39;00m complexjson\u001b[39m.\u001b[39mloads(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtext, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[1;32m 972\u001b[0m \u001b[39mexcept\u001b[39;00m JSONDecodeError \u001b[39mas\u001b[39;00m e:\n\u001b[1;32m 973\u001b[0m \u001b[39m# Catch JSON-related errors and raise as requests.JSONDecodeError\u001b[39;00m\n\u001b[1;32m 974\u001b[0m \u001b[39m# This aliases json.JSONDecodeError and simplejson.JSONDecodeError\u001b[39;00m\n\u001b[0;32m--> 975\u001b[0m \u001b[39mraise\u001b[39;00m RequestsJSONDecodeError(e\u001b[39m.\u001b[39mmsg, e\u001b[39m.\u001b[39mdoc, e\u001b[39m.\u001b[39mpos)\n",
53-
"\u001b[0;31mJSONDecodeError\u001b[0m: Expecting value: line 1 column 1 (char 0)"
43+
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
44+
"Cell \u001b[0;32mIn[3], line 30\u001b[0m\n\u001b[1;32m 27\u001b[0m headers \u001b[39m=\u001b[39m {\u001b[39m\"\u001b[39m\u001b[39mContent-Type\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m\"\u001b[39m\u001b[39mapplication/json\u001b[39m\u001b[39m\"\u001b[39m}\n\u001b[1;32m 29\u001b[0m response \u001b[39m=\u001b[39m requests\u001b[39m.\u001b[39mpost(url, json\u001b[39m=\u001b[39mpayload, headers\u001b[39m=\u001b[39mheaders)\n\u001b[0;32m---> 30\u001b[0m response_text \u001b[39m=\u001b[39m response\u001b[39m.\u001b[39;49mjson()[\u001b[39m\"\u001b[39;49m\u001b[39mresults\u001b[39;49m\u001b[39m\"\u001b[39;49m][\u001b[39m0\u001b[39m][\u001b[39m\"\u001b[39m\u001b[39mtext\u001b[39m\u001b[39m\"\u001b[39m]\n\u001b[1;32m 31\u001b[0m response_text \u001b[39m=\u001b[39m response_text\u001b[39m.\u001b[39mreplace(\u001b[39m\"\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 33\u001b[0m \u001b[39m# append the result to 'results.txt'\u001b[39;00m\n",
45+
"\u001b[0;31mKeyError\u001b[0m: 'results'"
5446
]
5547
}
5648
],
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
import json
2-
31
import json
42
import argparse
53

6-
def process_dataset(input_dataset_path, output_dataset_path, mode):
4+
def process_dataset(input_dataset_path, output_dataset_path, mode, sql_type, use_skeleton):
75
# Load the input dataset
86
dataset = json.load(open(input_dataset_path, "r"))
97
output_dataset = []
@@ -21,45 +19,47 @@ def process_dataset(input_dataset_path, output_dataset_path, mode):
2119
for fk in data["fk"]:
2220
input_sequence += fk["source_table_name_original"]+"."+fk["source_column_name_original"]+" = "+fk["target_table_name_original"]+"."+fk["target_column_name_original"] + " | "
2321

24-
output_sequence = data["natsql"]
22+
if sql_type == "natsql":
23+
output_sequence = data["natsql_skeleton"] + " | " + data["natsql"] if use_skeleton else data["natsql"]
24+
else: # regular sql
25+
output_sequence = data["sql_skeleton"] + " | " + data["norm_sql"] if use_skeleton else data["norm_sql"]
2526

2627
# Generate text for training mode, prompt and ground_truth for validation mode
2728
if mode == "train":
28-
text = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n\nConvert text to NatSQL: " + input_sequence + "\n\n" + "### Response:\n\n" + output_sequence
29+
text = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n\nConvert text to {sql_type}: " + input_sequence + "\n\n" + "### Response:\n\n" + output_sequence
2930
output_dataset.append({
3031
"db_id": db_id,
3132
"text": text,
32-
#"tc_original": tc_original
3333
})
3434
else: # validation mode
35-
prompt = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n\nConvert text to NatSQL: " + input_sequence + "\n\n### Response:\n\n"
35+
prompt = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n\nConvert text to {sql_type}: " + input_sequence + "\n\n### Response:\n\n"
3636
ground_truth = output_sequence
3737
output_dataset.append({
3838
"db_id": db_id,
3939
"prompt": prompt,
4040
"ground_truth": ground_truth,
41-
#"tc_original": tc_original
4241
})
4342

4443
# Save the output dataset
4544
with open(output_dataset_path, "w") as f:
4645
json.dump(output_dataset, f, indent=2, ensure_ascii=False)
4746

48-
def main(mode):
47+
def main(mode, sql_type, use_skeleton):
4948
if mode == "train":
50-
process_dataset("./data/preprocessed/preprocessed_train_spider_natsql.json", "./data/train.json", "train")
49+
process_dataset("./data/preprocessed/preprocessed_train_spider_natsql.json", f"./data/train_{sql_type}{'_skeleton' if use_skeleton else ''}.json", mode, sql_type, use_skeleton)
5150
elif mode == "validation":
52-
process_dataset("./data/preprocessed/preprocessed_dev_natsql.json", "./data/validation.json", "validation")
51+
process_dataset("./data/preprocessed/preprocessed_dev_natsql.json", f"./data/validation_{sql_type}{'_skeleton' if use_skeleton else ''}.json", mode, sql_type, use_skeleton)
5352
elif mode == "both":
54-
process_dataset("./data/preprocessed/preprocessed_train_spider_natsql.json", "./data/train.json", "train")
55-
process_dataset("./data/preprocessed/preprocessed_dev_natsql.json", "./data/validation.json", "validation")
53+
process_dataset("./data/preprocessed/preprocessed_train_spider_natsql.json", f"./data/train_{sql_type}{'_skeleton' if use_skeleton else ''}.json", "train", sql_type, use_skeleton)
54+
process_dataset("./data/preprocessed/preprocessed_dev_natsql.json", f"./data/validation_{sql_type}{'_skeleton' if use_skeleton else ''}.json", "validation", sql_type, use_skeleton)
5655
else:
5756
print("Specify mode flag with `--mode [train / validation / both].")
5857

5958
if __name__ == "__main__":
6059
parser = argparse.ArgumentParser()
61-
parser.add_argument('--mode', type=str, required=True, help="Specify mode flag with `--mode [train / validation / both].")
60+
parser.add_argument('--mode', type=str, default="both", help="Specify mode flag with `--mode [train / validation / both].")
61+
parser.add_argument('--sql_type', type=str, required=True, help="Specify SQL type with `--sql_type [natsql / sql].")
62+
parser.add_argument('--skeleton', action='store_true', default=False, help="Use SQL skeleton in the output sequence.")
6263
args = parser.parse_args()
6364

64-
main(args.mode)
65-
65+
main(args.mode, args.sql_type, args.skeleton)

0 commit comments

Comments
 (0)