|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "attachments": {}, |
| 5 | + "cell_type": "markdown", |
| 6 | + "metadata": {}, |
| 7 | + "source": [ |
| 8 | + "### Eval on Spider" |
| 9 | + ] |
| 10 | + }, |
| 11 | + { |
| 12 | + "cell_type": "code", |
| 13 | + "execution_count": 3, |
| 14 | + "metadata": {}, |
| 15 | + "outputs": [ |
| 16 | + { |
| 17 | + "name": "stderr", |
| 18 | + "output_type": "stream", |
| 19 | + "text": [ |
| 20 | + "Found cached dataset json (/Users/richardroberson/.cache/huggingface/datasets/richardr1126___json/richardr1126--spider-context-validation-166f6f4d7e17c532/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)\n" |
| 21 | + ] |
| 22 | + }, |
| 23 | + { |
| 24 | + "data": { |
| 25 | + "application/vnd.jupyter.widget-view+json": { |
| 26 | + "model_id": "f6540791f36b442d8dbfe5d0c77f02f2", |
| 27 | + "version_major": 2, |
| 28 | + "version_minor": 0 |
| 29 | + }, |
| 30 | + "text/plain": [ |
| 31 | + "Generating responses: 0%| | 0/418 [00:00<?, ?it/s]" |
| 32 | + ] |
| 33 | + }, |
| 34 | + "metadata": {}, |
| 35 | + "output_type": "display_data" |
| 36 | + }, |
| 37 | + { |
| 38 | + "name": "stdout", |
| 39 | + "output_type": "stream", |
| 40 | + "text": [ |
| 41 | + "Error occurred: The server is overloaded or not ready yet.\n", |
| 42 | + "Waiting for 5 seconds before retrying...\n", |
| 43 | + "Error occurred: The server is overloaded or not ready yet.\n", |
| 44 | + "Waiting for 5 seconds before retrying...\n", |
| 45 | + "Error occurred: The server is overloaded or not ready yet.\n", |
| 46 | + "Waiting for 5 seconds before retrying...\n", |
| 47 | + "Error occurred: The server is overloaded or not ready yet.\n", |
| 48 | + "Waiting for 5 seconds before retrying...\n", |
| 49 | + "Error occurred: The server is overloaded or not ready yet.\n", |
| 50 | + "Waiting for 5 seconds before retrying...\n", |
| 51 | + "Error occurred: The server is overloaded or not ready yet.\n", |
| 52 | + "Waiting for 5 seconds before retrying...\n", |
| 53 | + "Error occurred: The server is overloaded or not ready yet.\n", |
| 54 | + "Waiting for 5 seconds before retrying...\n", |
| 55 | + "Error occurred: The server is overloaded or not ready yet.\n", |
| 56 | + "Waiting for 5 seconds before retrying...\n", |
| 57 | + "Error occurred: The server is overloaded or not ready yet.\n", |
| 58 | + "Waiting for 5 seconds before retrying...\n", |
| 59 | + "Error occurred: The server is overloaded or not ready yet.\n", |
| 60 | + "Waiting for 5 seconds before retrying...\n", |
| 61 | + "Error occurred: The server is overloaded or not ready yet.\n", |
| 62 | + "Waiting for 5 seconds before retrying...\n" |
| 63 | + ] |
| 64 | + } |
| 65 | + ], |
| 66 | + "source": [ |
| 67 | + "import openai\n", |
| 68 | + "import time\n", |
| 69 | + "from datasets import load_dataset\n", |
| 70 | + "from tqdm.notebook import tqdm\n", |
| 71 | + "import os\n", |
| 72 | + "from dotenv import load_dotenv\n", |
| 73 | + "\n", |
| 74 | + "# Load environment variables from .env file\n", |
| 75 | + "load_dotenv()\n", |
| 76 | + "\n", |
| 77 | + "openai.api_key = os.getenv('OPENAI_API_KEY')\n", |
| 78 | + "\n", |
| 79 | + "dataset = load_dataset(\"richardr1126/spider-context-validation\", split=\"validation\")\n", |
| 80 | + "last_line_written = 0\n", |
| 81 | + "\n", |
| 82 | + "for i in tqdm(range(last_line_written, len(dataset)), total=len(dataset)-last_line_written, desc=\"Generating responses\"):\n", |
| 83 | + " example = dataset[i]\n", |
| 84 | + " prompt = example[\"prompt\"]\n", |
| 85 | + " \n", |
| 86 | + " while True:\n", |
| 87 | + " try:\n", |
| 88 | + " response = openai.ChatCompletion.create(\n", |
| 89 | + " model=\"gpt-3.5-turbo\", # or other models available\n", |
| 90 | + " messages=[\n", |
| 91 | + " {\n", |
| 92 | + " \"role\": \"system\",\n", |
| 93 | + " \"content\": \"You are a sophisticated AI assistant capable of converting natural language queries into SQL queries. You'll be given database schema information with tables and columns, followed by a natural language question from the user. Your task is to generate the equivalent SQL query to answer the user's question. Only generate the SQL query, do not add anything other text.\"\n", |
| 94 | + " },\n", |
| 95 | + " {\"role\": \"user\", \"content\": prompt}\n", |
| 96 | + " ],\n", |
| 97 | + " )\n", |
| 98 | + " \n", |
| 99 | + " response_text = response['choices'][0]['message']['content'].strip().replace(\"\\n\", \" \").replace(\"\\t\", \" \")\n", |
| 100 | + " if response_text[-1] == \".\":\n", |
| 101 | + " response_text = response_text[:-1]\n", |
| 102 | + " \n", |
| 103 | + " with open('predictions/chatgpt.txt', 'a') as f:\n", |
| 104 | + " f.write(response_text + \"\\n\")\n", |
| 105 | + " \n", |
| 106 | + " # If we get to this line, it means the operation was successful and we break the while loop\n", |
| 107 | + " break\n", |
| 108 | + " \n", |
| 109 | + " except Exception as e:\n", |
| 110 | + " print(f'Error occurred: {str(e)}')\n", |
| 111 | + " #traceback.print_exc()\n", |
| 112 | + " print('Waiting for 5 seconds before retrying...')\n", |
| 113 | + " time.sleep(10) # Wait for 60 seconds before trying again\n", |
| 114 | + "\n" |
| 115 | + ] |
| 116 | + } |
| 117 | + ], |
| 118 | + "metadata": { |
| 119 | + "kernelspec": { |
| 120 | + "display_name": "llama", |
| 121 | + "language": "python", |
| 122 | + "name": "python3" |
| 123 | + }, |
| 124 | + "language_info": { |
| 125 | + "codemirror_mode": { |
| 126 | + "name": "ipython", |
| 127 | + "version": 3 |
| 128 | + }, |
| 129 | + "file_extension": ".py", |
| 130 | + "mimetype": "text/x-python", |
| 131 | + "name": "python", |
| 132 | + "nbconvert_exporter": "python", |
| 133 | + "pygments_lexer": "ipython3", |
| 134 | + "version": "3.11.4" |
| 135 | + }, |
| 136 | + "orig_nbformat": 4 |
| 137 | + }, |
| 138 | + "nbformat": 4, |
| 139 | + "nbformat_minor": 2 |
| 140 | +} |
0 commit comments