
Commit 82411d8

Merge pull request #1069 from Pythagora-io/relevant
Relevant
2 parents b33282f + 1c5ece7 commit 82411d8

23 files changed, +227 -101 lines changed

core/agents/convo.py

+4
@@ -105,3 +105,7 @@ def remove_defs(d):
             f"YOU MUST NEVER add any additional fields to your response, and NEVER add additional preamble like 'Here is your JSON'."
         )
         return self
+
+    def remove_last_x_messages(self, x: int) -> "AgentConvo":
+        self.messages = self.messages[:-x]
+        return self
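
A minimal usage sketch of the new helper (illustrative only; `MiniConvo` is a simplified stand-in for `AgentConvo`): dropping the last x messages lets a caller replace a raw LLM reply with a cleaned-up one, which is how the relevant-files loop below uses it.

```python
# Hypothetical, simplified stand-in for AgentConvo to illustrate remove_last_x_messages().
class MiniConvo:
    def __init__(self):
        self.messages = []

    def user(self, content):
        self.messages.append({"role": "user", "content": content})
        return self

    def assistant(self, content):
        self.messages.append({"role": "assistant", "content": content})
        return self

    def remove_last_x_messages(self, x: int) -> "MiniConvo":
        # Drop the last x messages, e.g. to swap a raw reply for a cleaned one.
        self.messages = self.messages[:-x]
        return self


convo = MiniConvo().user("filter files").assistant('{"add_files": ["api.py"]}')
convo.remove_last_x_messages(1).assistant("cleaned reply")
print([m["role"] for m in convo.messages])  # ['user', 'assistant']
```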

core/agents/developer.py

+48 -33
@@ -1,11 +1,12 @@
-from typing import Optional
+from enum import Enum
+from typing import Annotated, Literal, Optional, Union
 from uuid import uuid4
 
 from pydantic import BaseModel, Field
 
 from core.agents.base import BaseAgent
 from core.agents.convo import AgentConvo
-from core.agents.mixins import TaskSteps
+from core.agents.mixins import RelevantFilesMixin
 from core.agents.response import AgentResponse, ResponseType
 from core.config import TASK_BREAKDOWN_AGENT_NAME
 from core.db.models.project_state import IterationStatus, TaskStatus
@@ -17,11 +18,48 @@
 log = get_logger(__name__)
 
 
-class RelevantFiles(BaseModel):
-    relevant_files: list[str] = Field(description="List of relevant files for the current task.")
+class StepType(str, Enum):
+    COMMAND = "command"
+    SAVE_FILE = "save_file"
+    HUMAN_INTERVENTION = "human_intervention"
 
 
-class Developer(BaseAgent):
+class CommandOptions(BaseModel):
+    command: str = Field(description="Command to run")
+    timeout: int = Field(description="Timeout in seconds")
+    success_message: str = ""
+
+
+class SaveFileOptions(BaseModel):
+    path: str
+
+
+class SaveFileStep(BaseModel):
+    type: Literal[StepType.SAVE_FILE] = StepType.SAVE_FILE
+    save_file: SaveFileOptions
+
+
+class CommandStep(BaseModel):
+    type: Literal[StepType.COMMAND] = StepType.COMMAND
+    command: CommandOptions
+
+
+class HumanInterventionStep(BaseModel):
+    type: Literal[StepType.HUMAN_INTERVENTION] = StepType.HUMAN_INTERVENTION
+    human_intervention_description: str
+
+
+Step = Annotated[
+    Union[SaveFileStep, CommandStep, HumanInterventionStep],
+    Field(discriminator="type"),
+]
+
+
+class TaskSteps(BaseModel):
+    steps: list[Step]
+
+
+class Developer(RelevantFilesMixin, BaseAgent):
     agent_type = "developer"
     display_name = "Developer"
 
@@ -96,7 +134,8 @@ async def breakdown_current_iteration(self, task_review_feedback: Optional[str]
         log.debug(f"Breaking down the iteration {description}")
         await self.send_message("Breaking down the current task iteration ...")
 
-        await self.get_relevant_files(user_feedback, description)
+        if self.current_state.files and self.current_state.relevant_files is None:
+            return await self.get_relevant_files(user_feedback, description)
 
         await self.ui.send_task_progress(
             n_tasks,  # iterations and reviews can be created only one at a time, so we are always on last one
@@ -114,7 +153,6 @@ async def breakdown_current_iteration(self, task_review_feedback: Optional[str]
             AgentConvo(self)
             .template(
                 "iteration",
-                current_task=current_task,
                 user_feedback=user_feedback,
                 user_feedback_qa=None,
                 next_solution_to_try=None,
@@ -175,7 +213,7 @@ async def breakdown_current_task(self) -> AgentResponse:
         log.debug(f"Current state files: {len(self.current_state.files)}, relevant {self.current_state.relevant_files}")
         # Check which files are relevant to the current task
         if self.current_state.files and self.current_state.relevant_files is None:
-            await self.get_relevant_files()
+            return await self.get_relevant_files()
 
         current_task_index = self.current_state.tasks.index(current_task)
 
@@ -189,6 +227,8 @@ async def breakdown_current_task(self) -> AgentResponse:
         )
         response: str = await llm(convo)
 
+        await self.get_relevant_files(None, response)
+
         self.next_state.tasks[current_task_index] = {
             **current_task,
             "instructions": response,
@@ -214,31 +254,6 @@ async def breakdown_current_task(self) -> AgentResponse:
         )
         return AgentResponse.done(self)
 
-    async def get_relevant_files(
-        self, user_feedback: Optional[str] = None, solution_description: Optional[str] = None
-    ) -> AgentResponse:
-        log.debug("Getting relevant files for the current task")
-        await self.send_message("Figuring out which project files are relevant for the next task ...")
-
-        llm = self.get_llm()
-        convo = (
-            AgentConvo(self)
-            .template(
-                "filter_files",
-                current_task=self.current_state.current_task,
-                user_feedback=user_feedback,
-                solution_description=solution_description,
-            )
-            .require_schema(RelevantFiles)
-        )
-
-        llm_response: list[str] = await llm(convo, parser=JSONParser(RelevantFiles), temperature=0)
-
-        existing_files = {file.path for file in self.current_state.files}
-        self.next_state.relevant_files = [path for path in llm_response.relevant_files if path in existing_files]
-
-        return AgentResponse.done(self)
-
     def set_next_steps(self, response: TaskSteps, source: str):
         # For logging/debugging purposes, we don't want to remove the finished steps
         # until we're done with the task.
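
For reference, a self-contained sketch of how the Step discriminated union now living in developer.py parses a step list; the JSON payload here is invented for illustration.

```python
# Illustrative only: parse a step list with a discriminated union like the one in this diff.
from enum import Enum
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field


class StepType(str, Enum):
    COMMAND = "command"
    SAVE_FILE = "save_file"
    HUMAN_INTERVENTION = "human_intervention"


class CommandOptions(BaseModel):
    command: str
    timeout: int
    success_message: str = ""


class SaveFileOptions(BaseModel):
    path: str


class SaveFileStep(BaseModel):
    type: Literal[StepType.SAVE_FILE] = StepType.SAVE_FILE
    save_file: SaveFileOptions


class CommandStep(BaseModel):
    type: Literal[StepType.COMMAND] = StepType.COMMAND
    command: CommandOptions


class HumanInterventionStep(BaseModel):
    type: Literal[StepType.HUMAN_INTERVENTION] = StepType.HUMAN_INTERVENTION
    human_intervention_description: str


Step = Annotated[Union[SaveFileStep, CommandStep, HumanInterventionStep], Field(discriminator="type")]


class TaskSteps(BaseModel):
    steps: list[Step]


# The "type" field picks the right step model for each entry (example payload).
steps = TaskSteps(
    steps=[
        {"type": "save_file", "save_file": {"path": "app.py"}},
        {"type": "command", "command": {"command": "pytest", "timeout": 60}},
    ]
)
print(type(steps.steps[0]).__name__)  # SaveFileStep
```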

core/agents/mixins.py

+65 -41
@@ -1,50 +1,21 @@
-from enum import Enum
-from typing import Annotated, Literal, Optional, Union
+from typing import Optional
 
 from pydantic import BaseModel, Field
 
 from core.agents.convo import AgentConvo
+from core.agents.response import AgentResponse
+from core.config import GET_RELEVANT_FILES_AGENT_NAME
+from core.llm.parser import JSONParser
+from core.log import get_logger
 
+log = get_logger(__name__)
 
-class StepType(str, Enum):
-    COMMAND = "command"
-    SAVE_FILE = "save_file"
-    HUMAN_INTERVENTION = "human_intervention"
 
-
-class CommandOptions(BaseModel):
-    command: str = Field(description="Command to run")
-    timeout: int = Field(description="Timeout in seconds")
-    success_message: str = ""
-
-
-class SaveFileOptions(BaseModel):
-    path: str
-
-
-class SaveFileStep(BaseModel):
-    type: Literal[StepType.SAVE_FILE] = StepType.SAVE_FILE
-    save_file: SaveFileOptions
-
-
-class CommandStep(BaseModel):
-    type: Literal[StepType.COMMAND] = StepType.COMMAND
-    command: CommandOptions
-
-
-class HumanInterventionStep(BaseModel):
-    type: Literal[StepType.HUMAN_INTERVENTION] = StepType.HUMAN_INTERVENTION
-    human_intervention_description: str
-
-
-Step = Annotated[
-    Union[SaveFileStep, CommandStep, HumanInterventionStep],
-    Field(discriminator="type"),
-]
-
-
-class TaskSteps(BaseModel):
-    steps: list[Step]
+class RelevantFiles(BaseModel):
+    read_files: list[str] = Field(description="List of files you want to read.")
+    add_files: list[str] = Field(description="List of files you want to add to the list of relevant files.")
+    remove_files: list[str] = Field(description="List of files you want to remove from the list of relevant files.")
+    done: bool = Field(description="Boolean flag to indicate that you are done selecting relevant files.")
 
 
 class IterationPromptMixin:
@@ -74,11 +45,64 @@ async def find_solution(
         llm = self.get_llm()
         convo = AgentConvo(self).template(
             "iteration",
-            current_task=self.current_state.current_task,
             user_feedback=user_feedback,
             user_feedback_qa=user_feedback_qa,
             next_solution_to_try=next_solution_to_try,
             bug_hunting_cycles=bug_hunting_cycles,
         )
         llm_solution: str = await llm(convo)
         return llm_solution
+
+
+class RelevantFilesMixin:
+    """
+    Provides a method to get relevant files for the current task.
+    """
+
+    async def get_relevant_files(
+        self, user_feedback: Optional[str] = None, solution_description: Optional[str] = None
+    ) -> AgentResponse:
+        log.debug("Getting relevant files for the current task")
+        await self.send_message("Figuring out which project files are relevant for the next task ...")
+
+        done = False
+        relevant_files = set()
+        llm = self.get_llm(GET_RELEVANT_FILES_AGENT_NAME)
+        convo = (
+            AgentConvo(self)
+            .template(
+                "filter_files",
+                user_feedback=user_feedback,
+                solution_description=solution_description,
+                relevant_files=relevant_files,
+            )
+            .require_schema(RelevantFiles)
+        )
+
+        while not done and len(convo.messages) < 13:
+            llm_response: RelevantFiles = await llm(convo, parser=JSONParser(RelevantFiles), temperature=0)
+
+            # Check if there are files to add to the list
+            if llm_response.add_files:
+                # Add only the files from add_files that are not already in relevant_files
+                relevant_files.update(file for file in llm_response.add_files if file not in relevant_files)
+
+            # Check if there are files to remove from the list
+            if llm_response.remove_files:
+                # Remove files from relevant_files that are in remove_files
+                relevant_files.difference_update(llm_response.remove_files)
+
+            read_files = [file for file in self.current_state.files if file.path in llm_response.read_files]
+
+            convo.remove_last_x_messages(1)
+            convo.assistant(llm_response.original_response)
+            convo.template("filter_files_loop", read_files=read_files, relevant_files=relevant_files).require_schema(
+                RelevantFiles
+            )
+            done = llm_response.done
+
+        existing_files = {file.path for file in self.current_state.files}
+        relevant_files = [path for path in relevant_files if path in existing_files]
+        self.next_state.relevant_files = relevant_files
+
+        return AgentResponse.done(self)
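
A minimal sketch of one turn of the new selection loop, using the RelevantFiles schema above with an invented reply (defaults are added here so the snippet stands alone):

```python
# Illustrative only: one add/remove round of the relevant-files loop.
from pydantic import BaseModel, Field


class RelevantFiles(BaseModel):
    read_files: list[str] = Field(default_factory=list)
    add_files: list[str] = Field(default_factory=list)
    remove_files: list[str] = Field(default_factory=list)
    done: bool = False


relevant_files = {"main.py"}
reply = RelevantFiles(
    read_files=["routes/api.py"],
    add_files=["routes/api.py"],
    remove_files=["main.py"],
    done=False,
)

# Same set arithmetic as get_relevant_files(): add new paths, drop rejected ones.
relevant_files.update(f for f in reply.add_files if f not in relevant_files)
relevant_files.difference_update(reply.remove_files)
print(relevant_files)  # {'routes/api.py'}
```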

core/agents/spec_writer.py

+1
@@ -73,6 +73,7 @@ async def initialize_spec(self) -> AgentResponse:
             },
         )
 
+        reviewed_spec = user_description
         if len(user_description) < ANALYZE_THRESHOLD and complexity != Complexity.SIMPLE:
             initial_spec = await self.analyze_spec(user_description)
             reviewed_spec = await self.review_spec(desc=user_description, spec=initial_spec)
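
The added default guards the branch where analysis is skipped; a simplified sketch of the pattern (names are illustrative, not the agent's actual code):

```python
# Illustrative only: without a default, the skipped branch can leave the
# variable unbound and later use of it would raise UnboundLocalError.
def build_spec(user_description: str, needs_analysis: bool) -> str:
    reviewed_spec = user_description  # safe default when analysis is skipped
    if needs_analysis:
        reviewed_spec = user_description + " (analyzed and reviewed)"
    return reviewed_spec


print(build_spec("todo app", needs_analysis=False))  # "todo app"
```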

core/agents/task_reviewer.py

+6 -1
@@ -28,6 +28,11 @@ async def review_code_changes(self) -> AgentResponse:
             # Some iterations are created by the task reviewer and have no user feedback
             if iteration["user_feedback"]
         ]
+        bug_hunter_instructions = [
+            iteration["bug_hunting_cycles"][-1]["human_readable_instructions"].replace("```", "").strip()
+            for iteration in self.current_state.iterations
+            if iteration["bug_hunting_cycles"]
+        ]
 
         files_before_modification = self.current_state.modified_files
         files_after_modification = [
@@ -40,10 +45,10 @@ async def review_code_changes(self) -> AgentResponse:
         # TODO instead of sending files before and after maybe add nice way to show diff for multiple files
         convo = AgentConvo(self).template(
             "review_task",
-            current_task=self.current_state.current_task,
             all_feedbacks=all_feedbacks,
             files_before_modification=files_before_modification,
             files_after_modification=files_after_modification,
+            bug_hunter_instructions=bug_hunter_instructions,
         )
         llm_response: str = await llm(convo, temperature=0.7)
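
A small sketch of what the new bug_hunter_instructions comprehension collects, with an invented iterations structure (only the shape matters):

```python
# Illustrative only: pull the latest bug-hunting instructions from each iteration that has any.
iterations = [
    {"bug_hunting_cycles": []},
    {"bug_hunting_cycles": [{"human_readable_instructions": "```\nCheck the login route\n```"}]},
]
bug_hunter_instructions = [
    it["bug_hunting_cycles"][-1]["human_readable_instructions"].replace("```", "").strip()
    for it in iterations
    if it["bug_hunting_cycles"]
]
print(bug_hunter_instructions)  # ['Check the login route']
```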

core/agents/troubleshooter.py

+3 -2
@@ -5,7 +5,7 @@
 
 from core.agents.base import BaseAgent
 from core.agents.convo import AgentConvo
-from core.agents.mixins import IterationPromptMixin
+from core.agents.mixins import IterationPromptMixin, RelevantFilesMixin
 from core.agents.response import AgentResponse
 from core.db.models.file import File
 from core.db.models.project_state import IterationStatus, TaskStatus
@@ -28,7 +28,7 @@ class RouteFilePaths(BaseModel):
     files: list[str] = Field(description="List of paths for files that contain routes")
 
 
-class Troubleshooter(IterationPromptMixin, BaseAgent):
+class Troubleshooter(IterationPromptMixin, RelevantFilesMixin, BaseAgent):
     agent_type = "troubleshooter"
     display_name = "Troubleshooter"
 
@@ -102,6 +102,7 @@ async def create_iteration(self) -> AgentResponse:
         else:
             # should be - elif change_description is not None: - but to prevent bugs with the extension
             # this might be caused if we show the input field instead of buttons
+            await self.get_relevant_files(user_feedback)
             iteration_status = IterationStatus.NEW_FEATURE_REQUESTED
 
         self.next_state.iterations = self.current_state.iterations + [

core/config/__init__.py

+2
@@ -39,6 +39,7 @@
 CHECK_LOGS_AGENT_NAME = "BugHunter.check_logs"
 TASK_BREAKDOWN_AGENT_NAME = "Developer.breakdown_current_task"
 SPEC_WRITER_AGENT_NAME = "SpecWriter"
+GET_RELEVANT_FILES_AGENT_NAME = "get_relevant_files"
 
 # Endpoint for the external documentation
 EXTERNAL_DOCUMENTATION_API = "http://docs-pythagora-io-439719575.us-east-1.elb.amazonaws.com"
@@ -330,6 +331,7 @@ class Config(_StrictModel):
                 temperature=0.5,
             ),
             SPEC_WRITER_AGENT_NAME: AgentLLMConfig(model="gpt-4-0125-preview", temperature=0.0),
+            GET_RELEVANT_FILES_AGENT_NAME: AgentLLMConfig(model="claude-3-5-sonnet-20240620", temperature=0.0),
         }
     )
     prompt: PromptConfig = PromptConfig()

core/db/models/project_state.py

+1
@@ -303,6 +303,7 @@ def complete_iteration(self):
 
         log.debug(f"Completing iteration {self.unfinished_iterations[0]}")
         self.unfinished_iterations[0]["status"] = IterationStatus.DONE
+        self.relevant_files = None
         self.flag_iterations_as_modified()
 
     def flag_iterations_as_modified(self):

core/llm/parser.py

+15 -3
@@ -3,7 +3,7 @@
 from enum import Enum
 from typing import Optional, Union
 
-from pydantic import BaseModel, ValidationError
+from pydantic import BaseModel, ValidationError, create_model
 
 
 class MultiCodeBlockParser:
@@ -86,6 +86,7 @@ class JSONParser:
     def __init__(self, spec: Optional[BaseModel] = None, strict: bool = True):
         self.spec = spec
         self.strict = strict or (spec is not None)
+        self.original_response = None
 
     @property
     def schema(self):
@@ -102,7 +103,8 @@ def errors_to_markdown(errors: list) -> str:
         return "\n".join(error_txt)
 
     def __call__(self, text: str) -> Union[BaseModel, dict, None]:
-        text = text.strip()
+        self.original_response = text.strip()  # Store the original text
+        text = self.original_response
         if text.startswith("```"):
             try:
                 text = CodeBlockParser()(text)
@@ -130,7 +132,17 @@ def __call__(self, text: str) -> Union[BaseModel, dict, None]:
         except Exception as err:
             raise ValueError(f"Error parsing JSON: {err}") from err
 
-        return model
+        # Create a new model that includes the original model fields and the original text
+        ExtendedModel = create_model(
+            f"Extended{self.spec.__name__}",
+            original_response=(str, ...),
+            **{field_name: (field.annotation, field.default) for field_name, field in self.spec.__fields__.items()},
+        )
+
+        # Instantiate the extended model
+        extended_model = ExtendedModel(original_response=self.original_response, **model.dict())
+
+        return extended_model
 
 
 class EnumParser:
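
A rough standalone sketch of the create_model extension (assuming Pydantic v2 semantics, mirroring the diff; the spec and payload are invented): the parsed fields remain available and the raw reply rides along as original_response, which get_relevant_files feeds back through convo.assistant().

```python
# Illustrative only: extend a spec model with the raw LLM reply.
from pydantic import BaseModel, create_model


class RelevantFiles(BaseModel):
    add_files: list[str] = []
    done: bool = False


raw_reply = '{"add_files": ["api.py"], "done": true}'
parsed = RelevantFiles(add_files=["api.py"], done=True)

# Same construction as JSONParser.__call__: copy the spec's fields and add original_response.
ExtendedRelevantFiles = create_model(
    "ExtendedRelevantFiles",
    original_response=(str, ...),
    **{name: (field.annotation, field.default) for name, field in RelevantFiles.__fields__.items()},
)
extended = ExtendedRelevantFiles(original_response=raw_reply, **parsed.dict())
print(extended.done, extended.original_response)
```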
