Skip to content

Commit 729f002

Browse files
committed
Fixes React Agent to also work with other providers
1 parent 1cfb830 commit 729f002

File tree

5 files changed

+20
-23
lines changed

5 files changed

+20
-23
lines changed

src/llama_stack_client/lib/agents/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,11 @@ def _create_turn_streaming(
162162
while not is_turn_complete:
163163
is_turn_complete = True
164164
for chunk in turn_response:
165-
tool_calls = self._get_tool_calls(chunk)
166165
if hasattr(chunk, "error"):
167166
yield chunk
168167
return
169-
elif not tool_calls:
168+
tool_calls = self._get_tool_calls(chunk)
169+
if not tool_calls:
170170
yield chunk
171171
else:
172172
is_turn_complete = False

src/llama_stack_client/lib/agents/client_tool.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ def get_params_definition(self) -> Dict[str, Parameter]:
163163
params[name] = Parameter(
164164
name=name,
165165
description=param_doc or f"Parameter {name}",
166-
parameter_type=type_hint.__name__,
166+
# Hack: litellm/openai expects "string" for str type
167+
parameter_type=type_hint.__name__ if type_hint.__name__ != "str" else "string",
167168
default=(param.default if param.default != inspect.Parameter.empty else None),
168169
required=is_required,
169170
)

src/llama_stack_client/lib/agents/event_logger.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def yield_printable_events(self, chunk: Any) -> Iterator[TurnStreamPrintableEven
6363
for printable_event in self._yield_printable_events(chunk, self.previous_event_type, self.previous_step_type):
6464
yield printable_event
6565

66-
self.previous_event_type, self.previous_step_type = self._get_event_type_step_type(chunk)
66+
if not hasattr(chunk, "error"):
67+
self.previous_event_type, self.previous_step_type = self._get_event_type_step_type(chunk)
6768

6869
def _yield_printable_events(
6970
self, chunk: Any, previous_event_type: Optional[str] = None, previous_step_type: Optional[str] = None

src/llama_stack_client/lib/agents/react/agent.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,18 @@
33
#
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
6-
from typing import Any, Dict, Optional, Tuple
6+
from typing import Optional, Tuple
77

88
from llama_stack_client import LlamaStackClient
99
from llama_stack_client.types.agent_create_params import AgentConfig
10-
from pydantic import BaseModel
10+
1111

1212
from ..agent import Agent
1313
from ..client_tool import ClientTool
1414
from ..tool_parser import ToolParser
1515
from .prompts import DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE
1616

17-
from .tool_parser import ReActToolParser
18-
19-
20-
class Action(BaseModel):
21-
tool_name: str
22-
tool_params: Dict[str, Any]
23-
24-
25-
class ReActOutput(BaseModel):
26-
thought: str
27-
action: Optional[Action] = None
28-
answer: Optional[str] = None
17+
from .tool_parser import ReActToolParser, ReActOutput
2918

3019

3120
class ReActAgent(Agent):

src/llama_stack_client/lib/agents/react/tool_parser.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,28 @@
55
# the root directory of this source tree.
66

77
from pydantic import BaseModel, ValidationError
8-
from typing import Dict, Any, Optional, List
8+
from typing import Optional, List, Union
99
from ..tool_parser import ToolParser
1010
from llama_stack_client.types.shared.completion_message import CompletionMessage
1111
from llama_stack_client.types.shared.tool_call import ToolCall
1212

1313
import uuid
1414

1515

16+
class Param(BaseModel):
17+
name: str
18+
value: Union[str, int, float, bool]
19+
20+
1621
class Action(BaseModel):
1722
tool_name: str
18-
tool_params: Dict[str, Any]
23+
tool_params: List[Param]
1924

2025

2126
class ReActOutput(BaseModel):
2227
thought: str
23-
action: Optional[Action] = None
24-
answer: Optional[str] = None
28+
action: Optional[Action]
29+
answer: Optional[str]
2530

2631

2732
class ReActToolParser(ToolParser):
@@ -40,8 +45,9 @@ def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
4045
if react_output.action:
4146
tool_name = react_output.action.tool_name
4247
tool_params = react_output.action.tool_params
48+
params = {param.name: param.value for param in tool_params}
4349
if tool_name and tool_params:
4450
call_id = str(uuid.uuid4())
45-
tool_calls = [ToolCall(call_id=call_id, tool_name=tool_name, arguments=tool_params)]
51+
tool_calls = [ToolCall(call_id=call_id, tool_name=tool_name, arguments=params)]
4652

4753
return tool_calls

0 commit comments

Comments
 (0)