|
1 |
| -from codeinterpreterapi.chains import remove_download_link, get_file_modifications |
| 1 | +import asyncio |
| 2 | + |
| 3 | +from codeinterpreterapi.chains import ( |
| 4 | + aget_file_modifications, |
| 5 | + aremove_download_link, |
| 6 | + get_file_modifications, |
| 7 | + remove_download_link, |
| 8 | +) |
| 9 | +from langchain_openai import ChatOpenAI |
| 10 | + |
| 11 | +llm = ChatOpenAI(model="gpt-3.5-turbo") |
| 12 | + |
| 13 | +remove_download_link_example = ( |
| 14 | + "I have created the plot to your dataset.\n\n" |
| 15 | + "Link to the file [here](sandbox:/plot.png)." |
| 16 | +) |
| 17 | + |
| 18 | +base_code = """ |
| 19 | + import matplotlib.pyplot as plt |
| 20 | +
|
| 21 | + x = list(range(1, 11)) |
| 22 | + y = [29, 39, 23, 32, 4, 43, 43, 23, 43, 77] |
| 23 | +
|
| 24 | + plt.plot(x, y, marker='o') |
| 25 | + plt.xlabel('Index') |
| 26 | + plt.ylabel('Value') |
| 27 | + plt.title('Data Plot') |
| 28 | + """ |
| 29 | +code_with_mod = base_code + "\nplt.savefig('plot.png')" |
| 30 | + |
| 31 | +code_no_mod = base_code + "\nplt.show()" |
2 | 32 |
|
3 | 33 |
|
4 | 34 | def test_remove_download_link() -> None:
|
5 |
| - example = ( |
6 |
| - "I have created the plot to your dataset.\n\n" |
7 |
| - "Link to the file [here](sandbox:/plot.png)." |
8 |
| - ) |
9 | 35 | assert (
|
10 |
| - remove_download_link(example).formatted_response.strip() |
| 36 | + remove_download_link(remove_download_link_example, llm=llm).strip() |
11 | 37 | == "I have created the plot to your dataset."
|
12 | 38 | )
|
13 | 39 |
|
14 | 40 |
|
15 |
| -def test_get_file_modifications() -> None: |
16 |
| - base_code = """ |
17 |
| - import matplotlib.pyplot as plt |
| 41 | +async def test_remove_download_link_async() -> None: |
| 42 | + assert ( |
| 43 | + await aremove_download_link(remove_download_link_example, llm=llm) |
| 44 | + ).strip() == "I have created the plot to your dataset." |
18 | 45 |
|
19 |
| - x = list(range(1, 11)) |
20 |
| - y = [29, 39, 23, 32, 4, 43, 43, 23, 43, 77] |
21 | 46 |
|
22 |
| - plt.plot(x, y, marker='o') |
23 |
| - plt.xlabel('Index') |
24 |
| - plt.ylabel('Value') |
25 |
| - plt.title('Data Plot') |
26 |
| - """ |
27 |
| - code_with_mod = base_code + "\nplt.savefig('plot.png')" |
| 47 | +def test_get_file_modifications() -> None: |
| 48 | + assert get_file_modifications(code_with_mod, llm=llm) == ["plot.png"] |
| 49 | + assert get_file_modifications(code_no_mod, llm=llm) == [] |
28 | 50 |
|
29 |
| - code_no_mod = base_code + "\nplt.show()" |
30 | 51 |
|
31 |
| - assert get_file_modifications(code_with_mod).modifications == ["plot.png"] |
32 |
| - assert get_file_modifications(code_no_mod).modifications == [] |
| 52 | +async def test_get_file_modifications_async() -> None: |
| 53 | + assert await aget_file_modifications(code_with_mod, llm=llm) == ["plot.png"] |
| 54 | + assert await aget_file_modifications(code_no_mod, llm=llm) == [] |
33 | 55 |
|
34 | 56 |
|
35 | 57 | if __name__ == "__main__":
|
36 |
| - # test_remove_download_link() |
| 58 | + test_remove_download_link() |
| 59 | + asyncio.run(test_remove_download_link_async()) |
37 | 60 | test_get_file_modifications()
|
| 61 | + asyncio.run(test_get_file_modifications_async()) |
0 commit comments