Skip to content

Commit cb07bfe

Browse files
committed
First Commit
1 parent 8a9b296 commit cb07bfe

27 files changed

+708
-2
lines changed

LICENSE

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MIT License
22

3-
Copyright (c) 2022 cohere.ai
3+
Copyright (c) 2022 Cohere Inc. and its affiliates.
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal
@@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
1818
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
1919
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
2020
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21-
SOFTWARE.
21+
SOFTWARE.

README.md

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
```
2+
################################################################################
3+
# ____ _ ____ _ _ #
4+
# / ___|___ | |__ ___ _ __ ___ / ___| __ _ _ __ __| | |__ _____ __ #
5+
# | | / _ \| '_ \ / _ \ '__/ _ \ \___ \ / _` | '_ \ / _` | '_ \ / _ \ \/ / #
6+
# | |__| (_) | | | | __/ | | __/ ___) | (_| | | | | (_| | |_) | (_) > < #
7+
# \____\___/|_| |_|\___|_| \___| |____/ \__,_|_| |_|\__,_|_.__/ \___/_/\_\ #
8+
# #
9+
# This project is part of Cohere Sandbox, Cohere's Experimental Open Source #
10+
# offering. This project provides a library, tooling, or demo making use of #
11+
# the Cohere Platform. You should expect (self-)documented, high quality code #
12+
# but be warned that this is EXPERIMENTAL. Therefore, also expect rough edges, #
13+
# non-backwards compatible changes, or potential changes in functionality as #
14+
# the library, tool, or demo evolves. Please consider referencing a specific #
15+
# git commit or version if depending upon the project in any mission-critical #
16+
# code as part of your own projects. #
17+
# #
18+
# Please don't hesitate to raise issues or submit pull requests, and thanks #
19+
# for checking out this project! #
20+
# #
21+
################################################################################
22+
```
23+
24+
**Maintainer:** [nickfrosst](https://github.com/nickfrosst) \
25+
**Project maintained until at least (YYYY-MM-DD):** 2023-01-01
26+
27+
# Grounded Question Answering — A Cohere Sandbox Project
28+
29+
This is a Cohere API powered, contextualized, factual question-answering bot!

It responds to questions in Discord by understanding the conversational
context, searching Google for what it believes to be the appropriate query,
finding relevant information on the result pages, and then answering the
question based on what it found.
35+
36+
## Installation
37+
1- Clone the repository.
38+
39+
2- Install all the dependencies:
40+
41+
```pip install -r requirements.txt```
42+
43+
3- Try the demo by running the CLI tool:
45+
46+
```python3 cli_demo.py --cohere_api_key <API_KEY> --serp_api_key <API_KEY>```
47+
48+
# License
49+
COHERE-GROUNDED-QA has an MIT license, as found in the LICENSE file.

__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) 2022 Cohere Inc. and its affiliates.
2+
#
3+
# Licensed under the MIT License (the "License");
4+
# you may not use this file except in compliance with the License.
5+
#
6+
# You may obtain a copy of the License in the LICENSE file at the top
7+
# level of this repository.

cli_demo.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright (c) 2022 Cohere Inc. and its affiliates.
2+
#
3+
# Licensed under the MIT License (the "License");
4+
# you may not use this file except in compliance with the License.
5+
#
6+
# You may obtain a copy of the License in the LICENSE file at the top
7+
# level of this repository.
8+
9+
# this is a cli demo of you the bot. You can run it and ask questions directly in your terminal
10+
11+
# Interactive CLI demo: ask questions in the terminal and get grounded answers.

import argparse

from qa.bot import GroundedQaBot

parser = argparse.ArgumentParser(description="A grounded QA bot with cohere and google search")
parser.add_argument("--cohere_api_key", type=str, help="api key for cohere", required=True)
parser.add_argument("--serp_api_key", type=str, help="api key for serpAPI", required=True)
parser.add_argument("--verbosity", type=int, default=0, help="verbosity level")
args = parser.parse_args()

bot = GroundedQaBot(args.cohere_api_key, args.serp_api_key)

if __name__ == "__main__":
    # Simple REPL: read a question, answer it, repeat until the user exits.
    while True:
        try:
            question = input("question: ")
        except (EOFError, KeyboardInterrupt):
            # Exit cleanly on Ctrl-D / Ctrl-C instead of dumping a traceback.
            break
        reply = bot.answer(question, verbosity=args.verbosity, n_paragraphs=2)
        print("answer: " + reply)

discord_bot.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright (c) 2022 Cohere Inc. and its affiliates.
2+
#
3+
# Licensed under the MIT License (the "License");
4+
# you may not use this file except in compliance with the License.
5+
#
6+
# You may obtain a copy of the License in the LICENSE file at the top
7+
# level of this repository.
8+
9+
# this is a demo discord bot. You can make a discord bot token by visiting https://discord.com/developers
10+
11+
import argparse
12+
13+
import discord
14+
from discord import Embed
15+
from discord.ext import commands
16+
17+
from qa.bot import GroundedQaBot
18+
19+
# Command-line configuration for the Discord bot.
parser = argparse.ArgumentParser(description="A grounded QA bot with cohere and google search")
parser.add_argument("--cohere_api_key", type=str, help="api key for cohere", required=True)
parser.add_argument("--serp_api_key", type=str, help="api key for serpAPI", required=True)
# Fixed help-text typo: "discord boat" -> "discord bot".
parser.add_argument("--discord_key", type=str, help="api key for discord bot", required=True)
parser.add_argument("--verbosity", type=int, default=0, help="verbosity level")
args = parser.parse_args()

bot = GroundedQaBot(args.cohere_api_key, args.serp_api_key)
27+
28+
29+
class MyClient(discord.Client):
    """Discord client that answers questions grounded in web search results.

    Triggered either by a direct message or by a ❓ reaction on a message.
    """

    async def on_ready(self):
        """Initializes bot"""
        print(f"We have logged in as {self.user}")

        for guild in self.guilds:
            print(f"{self.user} is connected to the following guild:\n"
                  f"{guild.name}(id: {guild.id})")

    async def answer(self, message):
        """Answers a question based on the context of the conversation and information from the web"""
        # Collect up to 6 messages preceding the triggering message as chat context.
        history = []
        async for historic_msg in message.channel.history(limit=6, before=message):
            if historic_msg.content:
                name = "user"
                if historic_msg.author.name == self.user.name:
                    name = "bot"
                # channel.history yields messages newest-first here; prepending
                # each one restores chronological (oldest-first) order.
                history = [f"{name}: {historic_msg.clean_content}"] + history

        print(history)  # NOTE(review): debug output; consider logging instead.
        bot.set_chat_history(history)

        # Show a typing indicator while the (potentially slow) answer is produced.
        async with message.channel.typing():
            reply = bot.answer(message.clean_content, verbosity=2)
            response_msg = await message.channel.send(reply, reference=message)
            # Suppress link-preview embeds on the reply (source URLs are included).
            await response_msg.edit(suppress=True)
            return

    async def on_message(self, message):
        """Handles query messages triggered by direct messages to the bot"""
        # Only respond to DMs, and never to the bot's own messages.
        if isinstance(message.channel, discord.channel.DMChannel) and message.author != self.user:
            await self.answer(message)

    async def on_reaction_add(self, reaction, user):
        """Handles query messages triggered by emoji from user."""
        # reaction.count == 1 ensures we answer only on the first ❓ added.
        if user != self.user:
            if str(reaction.emoji) == "❓" and reaction.count == 1:
                await self.answer(reaction.message)
68+
69+
70+
if __name__ == "__main__":
    # Request all gateway intents so message content and reactions are visible,
    # then hand control to discord.py's event loop.
    client = MyClient(intents=discord.Intents.all())
    client.run(args.discord_key)

qa/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) 2022 Cohere Inc. and its affiliates.
2+
#
3+
# Licensed under the MIT License (the "License");
4+
# you may not use this file except in compliance with the License.
5+
#
6+
# You may obtain a copy of the License in the LICENSE file at the top
7+
# level of this repository.
153 Bytes
Binary file not shown.
151 Bytes
Binary file not shown.

qa/__pycache__/answer.cpython-310.pyc

3.04 KB
Binary file not shown.

qa/__pycache__/answer.cpython-39.pyc

3.11 KB
Binary file not shown.

qa/__pycache__/bot.cpython-310.pyc

1.9 KB
Binary file not shown.

qa/__pycache__/bot.cpython-39.pyc

1.89 KB
Binary file not shown.

qa/__pycache__/model.cpython-310.pyc

1.87 KB
Binary file not shown.

qa/__pycache__/model.cpython-39.pyc

1.8 KB
Binary file not shown.

qa/__pycache__/search.cpython-310.pyc

5.28 KB
Binary file not shown.

qa/__pycache__/search.cpython-39.pyc

5.25 KB
Binary file not shown.

qa/__pycache__/util.cpython-310.pyc

528 Bytes
Binary file not shown.

qa/__pycache__/util.cpython-39.pyc

526 Bytes
Binary file not shown.

qa/answer.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright (c) 2022 Cohere Inc. and its affiliates.
2+
#
3+
# Licensed under the MIT License (the "License");
4+
# you may not use this file except in compliance with the License.
5+
#
6+
# You may obtain a copy of the License in the LICENSE file at the top
7+
# level of this repository.
8+
9+
import numpy as np
10+
11+
from qa.model import get_sample_answer
12+
from qa.search import embedding_search, get_results_paragraphs_multi_process
13+
from qa.util import pretty_print
14+
15+
16+
def trim_stop_sequences(s, stop_sequences):
    """Remove a stop sequence found at the end of returned generated text.

    Args:
        s: the generated text.
        stop_sequences: candidate terminator strings to strip from the end.

    Returns:
        s without the first matching trailing stop sequence, or s unchanged
        if none match.
    """
    for stop_sequence in stop_sequences:
        if s.endswith(stop_sequence):
            # str.removesuffix avoids the s[:-0] pitfall of the slicing form,
            # which would wrongly return "" for an empty stop sequence.
            return s.removesuffix(stop_sequence)
    return s
23+
24+
25+
def answer(question, context, co, model, chat_history=""):
    """Answer a question given some context.

    Args:
        question: the user question.
        context: web-search paragraphs to ground the answer in.
        co: a cohere API client.
        model: name of the cohere generation model to use.
        chat_history: optional chat log; when present, a conversational
            prompt is used instead of the plain QA prompt.

    Returns:
        The highest-likelihood non-empty generation, stripped of whitespace,
        or "" when every generation came back empty.
    """
    prompt = ("This is an example of question answering based on a text passage:\n "
              f"Context:-{context}\nQuestion:\n-{question}\nAnswer:\n-")
    if chat_history:
        prompt = ("This is an example of factual question answering chat bot. It "
                  "takes the text context and answers related questions:\n "
                  f"Context:-{context}\nChat Log\n{chat_history}\nbot:")

    stop_sequences = ["\n"]

    num_generations = 4
    # Keep only the last 1900 tokens of the prompt to fit the model's context window.
    prompt = "".join(co.tokenize(text=prompt).token_strings[-1900:])
    prediction = co.generate(model=model,
                             prompt=prompt,
                             max_tokens=100,
                             temperature=0.3,
                             stop_sequences=stop_sequences,
                             num_generations=num_generations,
                             return_likelihoods="GENERATION")
    generations = [[
        trim_stop_sequences(prediction.generations[i].text.strip(), stop_sequences),
        prediction.generations[i].likelihood
    ] for i in range(num_generations)]
    # Drop blank generations. The original `not x[0].isspace()` filter missed
    # empty strings, because "".isspace() is False.
    generations = [g for g in generations if g[0] and not g[0].isspace()]
    # Guard: np.argmax on an empty list raises ValueError; return an empty
    # answer instead when no usable generation remains.
    if not generations:
        return ""
    response = generations[np.argmax([g[1] for g in generations])][0]
    return response.strip()
53+
54+
55+
def answer_with_search(question,
                       co,
                       serp_api_token,
                       chat_history="",
                       model="xlarge",
                       embedding_model="small",
                       url=None,
                       n_paragraphs=1,
                       verbosity=0):
    """Generates completion based on search results."""

    # Fetch candidate paragraphs (and their source URLs) from web search.
    paragraphs, sources = get_results_paragraphs_multi_process(question, serp_api_token, url=url)
    if not paragraphs:
        return ("", "", "")

    # Rank the paragraphs by embedding similarity to a model-drafted answer.
    sample_answer = get_sample_answer(question, co)
    ranked = embedding_search(paragraphs, sources, sample_answer, co, model=embedding_model)

    if verbosity > 1:
        pprint_results = "\n".join(r[0] for r in ranked)
        pretty_print("OKGREEN", f"all search result context: {pprint_results}")

    # Keep only the n most relevant paragraphs as grounding context.
    top = ranked[-n_paragraphs:]
    context = "\n".join(r[0] for r in top)

    if verbosity:
        pretty_print("OKCYAN", "relevant result context: " + context)

    response = answer(question, context, co, chat_history=chat_history, model=model)

    top_urls = [r[1] for r in top]
    top_texts = [r[0] for r in top]
    return (response, top_urls, top_texts)

qa/bot.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright (c) 2022 Cohere Inc. and its affiliates.
2+
#
3+
# Licensed under the MIT License (the "License");
4+
# you may not use this file except in compliance with the License.
5+
#
6+
# You may obtain a copy of the License in the LICENSE file at the top
7+
# level of this repository.
8+
9+
from sys import settrace
10+
11+
import cohere
12+
13+
from qa.answer import answer_with_search
14+
from qa.model import get_contextual_search_query
15+
from qa.util import pretty_print
16+
17+
18+
class GroundedQaBot():
    """A class yielding Grounded question-answering conversational agents."""

    def __init__(self, cohere_api_key, serp_api_key):
        self._cohere_api_key = cohere_api_key
        self._serp_api_key = serp_api_key
        self._chat_history = []
        self._co = cohere.Client(self._cohere_api_key)

    @property
    def chat_history(self):
        # Read-only view of the running conversation log.
        return self._chat_history

    def set_chat_history(self, chat_history):
        # Replace the conversation log wholesale (used by the discord frontend).
        self._chat_history = chat_history

    def answer(self, question, verbosity=0, n_paragraphs=1):
        """Answer a question, based on recent conversational history."""
        self._chat_history.append("user: " + question)

        # Rewrite the question as a self-contained search query using the
        # last few turns of conversation.
        recent_turns = "\n".join(self._chat_history[-6:])
        question = get_contextual_search_query(recent_turns, self._co, verbosity=verbosity)

        answer_text, source_urls, source_texts = answer_with_search(question,
                                                                    self._co,
                                                                    self._serp_api_key,
                                                                    verbosity=verbosity,
                                                                    n_paragraphs=n_paragraphs)

        self._chat_history.append("bot: " + answer_text)

        # Guard clauses for the two failure modes, then the normal reply.
        if not source_texts or "".join(source_texts).isspace():
            return ("Sorry, I could not find any relevant information for that "
                    "question.")
        if answer_text.strip() == question.strip():
            return ("I had trouble answering the question, but maybe this link on "
                    "the right will help.")
        sources_str = "\n--".join(source_urls)
        return f"{answer_text}\nSource: {sources_str}"

qa/model.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) 2022 Cohere Inc. and its affiliates.
2+
#
3+
# Licensed under the MIT License (the "License");
4+
# you may not use this file except in compliance with the License.
5+
#
6+
# You may obtain a copy of the License in the LICENSE file at the top
7+
# level of this repository.
8+
9+
import os
10+
11+
import cohere
12+
import numpy as np
13+
from cohere.classify import Example
14+
15+
from qa.util import pretty_print
16+
17+
_DATA_DIRNAME = os.path.join(os.path.dirname(__file__), "prompt_data")
18+
19+
20+
def get_contextual_search_query(history, co, model="xlarge", verbosity=0):
    """Adds message history context to user query."""

    # Few-shot prompt template lives on disk; the chat history is appended to it.
    prompt_path = os.path.join(_DATA_DIRNAME, "get_contextual_search_query.prompt")
    with open(prompt_path) as f:
        prompt = f.read() + f"{history}\n-"

    prediction = co.generate(
        model=model,
        prompt=prompt,
        max_tokens=50,
        temperature=0.75,
        k=0,
        p=0.75,
        frequency_penalty=0,
        presence_penalty=0,
        stop_sequences=["\n"],
        return_likelihoods="GENERATION",
        num_generations=4,
    )
    # Keep the candidate the model itself scored as most likely.
    best = max(prediction.generations, key=lambda g: g.likelihood)
    result = best.text

    if verbosity:
        pretty_print("OKGREEN", "contextual question prompt: " + prompt)
        pretty_print("OKCYAN", "contextual question: " + result)
    return result.strip()
45+
46+
47+
def get_sample_answer(question, co, model="xlarge"):
    """Return a sample answer to a question based on the model's training data.
    """
    # Build the prompt from the on-disk template plus the question itself.
    with open(os.path.join(_DATA_DIRNAME, "get_sample_answer.prompt")) as f:
        prompt = f.read() + f"{question}\nAnswer:"

    generation_params = dict(model=model,
                             prompt=prompt,
                             max_tokens=50,
                             temperature=0.8,
                             k=0,
                             p=0.7,
                             stop_sequences=["--"])
    response = co.generate(**generation_params)
    return response.generations[0].text

0 commit comments

Comments
 (0)