Skip to content

Commit 1fc1118

Browse files
committed
add claude support
1 parent 5ee9bdf commit 1fc1118

File tree

4 files changed

+188
-8
lines changed

4 files changed

+188
-8
lines changed

operate/config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import sys
33
from dotenv import load_dotenv
44
from openai import OpenAI
5+
import anthropic
56
from prompt_toolkit.shortcuts import input_dialog
67
import google.generativeai as genai
78

@@ -33,6 +34,10 @@ def __init__(self):
3334
self.google_api_key = (
3435
None # instance variables are backups in case saving to a `.env` fails
3536
)
37+
self.anthropic_api_key = (
38+
None # instance variables are backups in case saving to a `.env` fails
39+
)
40+
3641

3742
def initialize_openai(self):
3843
if self.verbose:
@@ -71,6 +76,14 @@ def initialize_google(self):
7176
model = genai.GenerativeModel("gemini-pro-vision")
7277

7378
return model
79+
80+
def initialize_anthropic(self):
    """Build an Anthropic API client.

    The in-memory key (``self.anthropic_api_key``) takes precedence; when it
    is unset we fall back to the ``ANTHROPIC_API_KEY`` environment variable,
    mirroring the other ``initialize_*`` helpers on this config object.
    """
    api_key = self.anthropic_api_key or os.getenv("ANTHROPIC_API_KEY")
    return anthropic.Anthropic(api_key=api_key)
86+
7487

7588
def validation(self, model, voice_mode):
7689
"""
@@ -87,6 +100,9 @@ def validation(self, model, voice_mode):
87100
self.require_api_key(
88101
"GOOGLE_API_KEY", "Google API key", model == "gemini-pro-vision"
89102
)
103+
self.require_api_key(
104+
"ANTHROPIC_API_KEY", "Anthropic API key", model == "claude-3-with-ocr"
105+
)
90106

91107
def require_api_key(self, key_name, key_description, is_required):
92108
key_exists = bool(os.environ.get(key_name))

operate/models/apis.py

Lines changed: 162 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,16 @@ async def get_next_action(model, messages, objective, session_id):
4848
if model == "gpt-4-with-ocr":
4949
operation = await call_gpt_4_vision_preview_ocr(messages, objective, model)
5050
return operation, None
51-
elif model == "agent-1":
51+
if model == "agent-1":
5252
return "coming soon"
53-
elif model == "gemini-pro-vision":
53+
if model == "gemini-pro-vision":
5454
return call_gemini_pro_vision(messages, objective), None
55-
elif model == "llava":
56-
operation = call_ollama_llava(messages), None
57-
return operation
58-
55+
if model == "llava":
56+
operation = call_ollama_llava(messages)
57+
return operation, None
58+
if model == "claude-3-with-ocr":
59+
operation = await call_claude_3_with_ocr(messages, objective, model)
60+
return operation, None
5961
raise ModelNotRecognizedException(model)
6062

6163

@@ -261,7 +263,7 @@ async def call_gpt_4_vision_preview_ocr(messages, objective, model):
261263
result = reader.readtext(screenshot_filename)
262264

263265
text_element_index = get_text_element(
264-
result, text_to_click, screenshot_filename
266+
result, text_to_click[:3], screenshot_filename
265267
)
266268
coordinates = get_text_coordinates(
267269
result, text_element_index, screenshot_filename
@@ -528,6 +530,159 @@ def call_ollama_llava(messages):
528530
return call_ollama_llava(messages)
529531

530532

533+
async def call_claude_3_with_ocr(messages, objective, model):
    """Ask Claude 3 Opus for the next operation(s), resolving clicks via OCR.

    Captures the screen, sends it (downscaled, base64-encoded PNG) to the
    Anthropic Messages API together with the conversation so far, parses the
    JSON list of operations the model returns, and converts each ``click``
    operation's target text into x/y screen coordinates using EasyOCR.

    Args:
        messages: Running conversation list; index 0 must be the system
            prompt (the Anthropic API takes it as a separate argument).
        objective: The user's objective, used to refresh the system prompt.
        model: Model identifier string, used for the system prompt and logs.

    Returns:
        list[dict]: The model's operations, with ``click`` entries augmented
        with ``"x"`` and ``"y"`` coordinate values.

    Raises:
        Exception: Any failure is logged and re-raised to the caller.
    """
    if config.verbose:
        print("[call_claude_3_with_ocr]")

    try:
        time.sleep(1)  # crude pacing to avoid hammering the API
        client = config.initialize_anthropic()

        confirm_system_prompt(messages, objective, model)
        screenshots_dir = "screenshots"
        if not os.path.exists(screenshots_dir):
            os.makedirs(screenshots_dir)

        screenshot_filename = os.path.join(screenshots_dir, "screenshot.png")
        capture_screen_with_cursor(screenshot_filename)

        with open(screenshot_filename, "rb") as img_file:
            img = Image.open(img_file)

            # Downscale while keeping the aspect ratio so the encoded
            # payload stays within a reasonable request size.
            original_width, original_height = img.size
            aspect_ratio = original_width / original_height
            new_width = 2560  # adjust this value to tune the file size
            new_height = int(new_width / aspect_ratio)
            img_resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

            # Serialize the resized image to an in-memory PNG and
            # base64-encode it for the Anthropic image content block.
            img_buffer = io.BytesIO()
            img_resized.save(img_buffer, format="PNG")
            img_buffer.seek(0)
            img_data = base64.b64encode(img_buffer.getvalue()).decode("utf-8")

        if len(messages) == 1:
            user_prompt = get_user_first_message_prompt()
        else:
            user_prompt = get_user_prompt()

        vision_message = {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/png",
                        "data": img_data,
                    },
                },
                {"type": "text", "text": user_prompt + "**REMEMBER** Only output json format, do not append any other text."},
            ],
        }
        messages.append(vision_message)

        # The Anthropic API expects the system prompt as a separate argument,
        # so pass messages[0] via `system=` and the rest as the chat history.
        response = client.messages.create(
            model="claude-3-opus-20240229",
            max_tokens=3000,
            system=messages[0]["content"],
            messages=messages[1:],
        )

        content = clean_json(response.content[0].text)
        content_str = content
        try:
            content = json.loads(content)
        except json.JSONDecodeError as e:
            if config.verbose:
                print(
                    f"{ANSI_GREEN}[Self-Operating Computer]{ANSI_RED}[Error] JSONDecodeError: {e} {ANSI_RESET}"
                )
            # One repair round-trip: ask the model to fix its own malformed
            # JSON. If this reply is also invalid, json.loads below raises
            # and the outer except handles it.
            response = client.messages.create(
                model="claude-3-opus-20240229",
                max_tokens=3000,
                system=f"This json string is not valid, when using with json.loads(content) \
    it throws the following error: {e}, return correct json string. **REMEMBER** Only output json format, do not append any other text.",
                messages=[{"role": "user", "content": content}],
            )
            content = clean_json(response.content[0].text)
            content_str = content
            content = json.loads(content)

        if config.verbose:
            print(
                f"{ANSI_GREEN}[Self-Operating Computer]{ANSI_BRIGHT_MAGENTA}[{model}] content: {content} {ANSI_RESET}"
            )

        processed_content = []
        # EasyOCR initialization and the OCR pass are expensive, and the
        # screenshot file does not change during this loop — compute the OCR
        # result lazily, at most once, instead of per click operation.
        ocr_result = None

        for operation in content:
            if operation.get("operation") == "click":
                text_to_click = operation.get("text")
                if config.verbose:
                    print(
                        "[call_claude_3_ocr][click] text_to_click",
                        text_to_click,
                    )
                if ocr_result is None:
                    reader = easyocr.Reader(["en"])
                    ocr_result = reader.readtext(screenshot_filename)

                # Match on a short prefix of the target text, consistent
                # with the GPT-4 OCR path in this module.
                text_element_index = get_text_element(
                    ocr_result, text_to_click[:3], screenshot_filename
                )
                coordinates = get_text_coordinates(
                    ocr_result, text_element_index, screenshot_filename
                )

                # Attach the resolved screen coordinates to the operation.
                operation["x"] = coordinates["x"]
                operation["y"] = coordinates["y"]

                if config.verbose:
                    print(
                        "[call_claude_3_ocr][click] text_element_index",
                        text_element_index,
                    )
                    print(
                        "[call_claude_3_ocr][click] coordinates",
                        coordinates,
                    )
                    print(
                        "[call_claude_3_ocr][click] final operation",
                        operation,
                    )
                processed_content.append(operation)
            else:
                processed_content.append(operation)

        assistant_message = {"role": "assistant", "content": content_str}
        messages.append(assistant_message)

        return processed_content

    except Exception as e:
        print(
            f"{ANSI_GREEN}[Self-Operating Computer]{ANSI_BRIGHT_MAGENTA}[{model}] That did not work. Trying another method {ANSI_RESET}"
        )
        if config.verbose:
            print("[Self-Operating Computer][Operate] error", e)
            traceback.print_exc()
        # Re-raise the original exception rather than wrapping it in a bare
        # Exception(e), so callers see the real type and full traceback.
        # NOTE(review): a gpt_4_fallback(messages, objective, model) call was
        # sketched here previously — confirm whether a fallback is intended.
        raise
684+
685+
531686
def get_last_assistant_message(messages):
532687
"""
533688
Retrieve the last message from the assistant in the messages array.

operate/models/prompts.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
From looking at the screen, the objective, and your previous actions, take the next best series of action.
7373
7474
You have 4 possible operation actions available to you. The `pyautogui` library will be used to execute your decision. Your output will be used in a `json.loads` loads statement.
75+
**REMEMBER** Only output json format, do not append any other text.
7576
7677
1. click - Move mouse and click - We labeled the clickable elements with red bounding boxes and IDs. Label IDs are in the following format with `x` being a number: `~x`
7778
```
@@ -238,6 +239,13 @@ def get_system_prompt(model, objective):
238239
os_search_str=os_search_str,
239240
operating_system=operating_system,
240241
)
242+
elif model == "claude-3-with-ocr":
243+
prompt = SYSTEM_PROMPT_OCR.format(
244+
objective=objective,
245+
cmd_string=cmd_string,
246+
os_search_str=os_search_str,
247+
operating_system=operating_system,
248+
)
241249
else:
242250
prompt = SYSTEM_PROMPT_STANDARD.format(
243251
objective=objective,

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,5 @@ google-generativeai==0.3.0
5151
aiohttp==3.9.1
5252
ultralytics==8.0.227
5353
easyocr==1.7.1
54-
ollama==0.1.6
54+
ollama==0.1.6
55+
anthropic  # TODO: pin an exact version (e.g. anthropic==0.x.y) for consistency with the other pinned requirements

0 commit comments

Comments
 (0)