Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add warning_budget_increment parameter to trigger warnings at cost thresholds #7639

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
6 changes: 4 additions & 2 deletions frontend/__tests__/services/actions.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ describe("Actions Service", () => {
describe("handleStatusMessage", () => {
it("should dispatch info messages to status state", () => {
const message = {
type: "info",
type: "info" as const,
message: "Runtime is not available",
id: "runtime.unavailable",
status_update: true as const,
Expand All @@ -36,10 +36,12 @@ describe("Actions Service", () => {
payload: message,
}));
});

// Test for cost threshold warning messages will be added in a separate PR

it("should log error messages and display them in chat", () => {
const message = {
type: "error",
type: "error" as const,
message: "Runtime connection failed",
id: "runtime.connection.failed",
status_update: true as const,
Expand Down
86 changes: 86 additions & 0 deletions frontend/src/components/shared/cost-threshold-toast.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import React from "react";
import toast from "react-hot-toast";
import { useWsClient } from "#/context/ws-client-provider";
import { generateAgentStateChangeEvent } from "#/services/agent-state-service";
import { AgentState } from "#/types/agent-state";

/** Props accepted by the cost-threshold toast component. */
interface CostThresholdToastProps {
  /** Warning text rendered in the body of the toast. */
  message: string;
}

/**
 * Body of the custom toast shown for a cost-threshold warning.
 *
 * Renders a warning icon, a fixed "Cost Threshold Alert" title, the supplied
 * message, and two buttons:
 * - "Approve & Continue" sends an agent state change event for
 *   `AgentState.RUNNING` through the WebSocket client, then dismisses the toast.
 * - "Reject" only dismisses the toast; no event is sent.
 *
 * The container carries `role="alert"` so assistive technology announces the
 * warning when it appears, and the icon is hidden from the accessibility tree
 * because it is purely decorative.
 *
 * @param props.message - Warning text to display under the title.
 * @returns The toast content element.
 */
function CostThresholdToast({
  message,
}: CostThresholdToastProps): React.ReactElement {
  const { send } = useWsClient();

  const handleApprove = (): void => {
    // Ask the backend to move the agent to RUNNING, then close the toast.
    send(generateAgentStateChangeEvent(AgentState.RUNNING));
    toast.dismiss("cost-threshold-toast");
  };

  const handleReject = (): void => {
    // Nothing is sent on reject: the agent state is left untouched and the
    // toast is simply closed.
    toast.dismiss("cost-threshold-toast");
  };

  return (
    <div
      role="alert"
      className="max-w-md w-full bg-gray-800 shadow-lg rounded-lg pointer-events-auto flex flex-col ring-1 ring-black ring-opacity-5"
    >
      <div className="p-4">
        <div className="flex items-start">
          <div className="flex-shrink-0 pt-0.5">
            {/* Decorative warning triangle; the title text conveys the meaning. */}
            <svg
              aria-hidden="true"
              className="h-6 w-6 text-yellow-500"
              xmlns="http://www.w3.org/2000/svg"
              fill="none"
              viewBox="0 0 24 24"
              stroke="currentColor"
            >
              <path
                strokeLinecap="round"
                strokeLinejoin="round"
                strokeWidth={2}
                d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
              />
            </svg>
          </div>
          <div className="ml-3 flex-1">
            <p className="text-sm font-medium text-white">
              Cost Threshold Alert
            </p>
            <p className="mt-1 text-sm text-gray-300">{message}</p>
          </div>
        </div>
      </div>
      <div className="flex border-t border-gray-700">
        <button
          type="button"
          onClick={handleApprove}
          className="flex-1 px-4 py-2 text-sm font-medium text-white bg-green-600 hover:bg-green-700 rounded-bl-lg transition-colors"
        >
          Approve & Continue
        </button>
        <button
          type="button"
          onClick={handleReject}
          className="flex-1 px-4 py-2 text-sm font-medium text-white bg-red-600 hover:bg-red-700 rounded-br-lg transition-colors border-l border-gray-700"
        >
          Reject
        </button>
      </div>
    </div>
  );
}

/**
 * Shows the cost-threshold approval toast with the given message.
 *
 * Any toast already registered under the same id is dismissed first, then a
 * custom toast rendering `CostThresholdToast` is shown at the top centre with
 * an infinite duration.
 *
 * @param message - Warning text passed through to the toast component.
 */
export const showCostThresholdToast = (message: string): void => {
  const toastId = "cost-threshold-toast";

  // Dismiss whatever is currently registered under this id.
  toast.dismiss(toastId);

  const renderToast = (): React.ReactElement => (
    <CostThresholdToast message={message} />
  );

  // `duration: Infinity` means the toast is never auto-dismissed.
  toast.custom(renderToast, {
    id: toastId,
    duration: Infinity,
    position: "top-center",
  });
};
7 changes: 7 additions & 0 deletions frontend/src/services/actions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
} from "#/types/message";
import { handleObservationMessage } from "./observations";
import { appendInput } from "#/state/command-slice";
import { showCostThresholdToast } from "#/components/shared/cost-threshold-toast";

const messageActions = {
[ActionType.BROWSE]: (message: ActionMessage) => {
Expand Down Expand Up @@ -126,6 +127,12 @@ export function handleStatusMessage(message: StatusMessage) {
...message,
}),
);
} else if (
message.type === "warning" &&
message.id === "STATUS$COST_THRESHOLD_REACHED"
) {
// Show the cost threshold toast for user approval
showCostThresholdToast(message.message);
} else if (message.type === "error") {
trackError({
message: message.message,
Expand Down
2 changes: 1 addition & 1 deletion frontend/src/types/message.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ export interface ObservationMessage {

export interface StatusMessage {
status_update: true;
type: string;
type: "info" | "warning" | "error";
id?: string;
message: string;
}
66 changes: 56 additions & 10 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def __init__(
event_stream: EventStream,
max_iterations: int,
max_budget_per_task: float | None = None,
warning_budget_increment: float | None = None,
agent_to_llm_config: dict[str, LLMConfig] | None = None,
agent_configs: dict[str, AgentConfig] | None = None,
sid: str | None = None,
Expand All @@ -118,6 +119,8 @@ def __init__(
event_stream: The event stream to publish events to.
max_iterations: The maximum number of iterations the agent can run.
max_budget_per_task: The maximum budget (in USD) allowed per task, beyond which the agent will stop.
warning_budget_increment: The budget increment (in USD) at which to warn the user and ask for approval to continue.
For example, if set to 5.0, the agent will pause and ask for approval at $5, $10, $15, etc.
agent_to_llm_config: A dictionary mapping agent names to LLM configurations in the case that
we delegate to a different agent.
agent_configs: A dictionary mapping agent names to agent configurations in the case that
Expand Down Expand Up @@ -151,6 +154,7 @@ def __init__(
confirmation_mode=confirmation_mode,
)
self.max_budget_per_task = max_budget_per_task
self.warning_budget_increment = warning_budget_increment
self.agent_to_llm_config = agent_to_llm_config if agent_to_llm_config else {}
self.agent_configs = agent_configs if agent_configs else {}
self._initial_max_iterations = max_iterations
Expand Down Expand Up @@ -552,14 +556,15 @@ async def set_agent_state_to(self, new_state: AgentState) -> None:
await self.update_state_after_step()
self.state.metrics.merge(self.state.local_metrics)
self._reset()
elif (
new_state == AgentState.RUNNING
and self.state.agent_state == AgentState.PAUSED
# TODO: do we really need both THROTTLING and PAUSED states, or can we clean up one of them completely?
and self.state.traffic_control_state == TrafficControlState.THROTTLING
):
# user intends to interrupt traffic control and let the task resume temporarily
self.state.traffic_control_state = TrafficControlState.PAUSED
elif new_state == AgentState.RUNNING:
# When the agent starts running from a paused state
if (
self.state.agent_state == AgentState.PAUSED
# TODO: do we really need both THROTTLING and PAUSED states, or can we clean up one of them completely?
and self.state.traffic_control_state == TrafficControlState.THROTTLING
):
# user intends to interrupt traffic control and let the task resume temporarily
self.state.traffic_control_state = TrafficControlState.PAUSED
# User has chosen to deliberately continue - let's double the max iterations
if (
self.state.iteration is not None
Expand Down Expand Up @@ -655,6 +660,7 @@ async def start_delegate(self, action: AgentDelegateAction) -> None:
event_stream=self.event_stream,
max_iterations=self.state.max_iterations,
max_budget_per_task=self.max_budget_per_task,
warning_budget_increment=self.warning_budget_increment,
agent_to_llm_config=self.agent_to_llm_config,
agent_configs=self.agent_configs,
initial_state=state,
Expand Down Expand Up @@ -733,6 +739,31 @@ async def _step(self) -> None:
stop_step = await self._handle_traffic_control(
'iteration', self.state.iteration, self.state.max_iterations
)

# Check for warning budget increment
if (
self.state.metrics.accumulated_cost is not None
and self.warning_budget_increment is not None
and self.warning_budget_increment > 0
and self.get_agent_state() == AgentState.RUNNING
):
# Calculate the current threshold we're at and the next threshold
current_cost = self.state.metrics.accumulated_cost
current_threshold = (
int(current_cost / self.warning_budget_increment)
* self.warning_budget_increment
)
next_threshold = current_threshold + self.warning_budget_increment

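# NOTE(review): this only fires when the cost is within 0.01 USD *below* the next threshold; current_cost < next_threshold by construction, so a step that jumps past a threshold never triggers the warning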
if (
current_cost >= next_threshold - 0.01
and current_cost < next_threshold + self.warning_budget_increment - 0.01
):
stop_step = await self._handle_traffic_control(
'warning_budget', current_cost, next_threshold
)

if self.max_budget_per_task is not None:
current_cost = self.state.metrics.accumulated_cost
if current_cost > self.max_budget_per_task:
Expand Down Expand Up @@ -848,15 +879,30 @@ async def _handle_traffic_control(
self.state.traffic_control_state = TrafficControlState.NORMAL
else:
self.state.traffic_control_state = TrafficControlState.THROTTLING
# Format values as integers for iterations, keep decimals for budget
# Format values as integers for iterations, keep decimals for budget and cost
if limit_type == 'iteration':
current_str = str(int(current_value))
max_str = str(int(max_value))
else:
current_str = f'{current_value:.2f}'
max_str = f'{max_value:.2f}'

if self.headless_mode:
if limit_type == 'warning_budget':
# Special handling for warning budget increment
await self.set_agent_state_to(AgentState.PAUSED)
if self.status_callback is not None:
next_threshold = (
max_value + self.warning_budget_increment
if self.warning_budget_increment
else 0
)
self.status_callback(
'warning',
'STATUS$COST_THRESHOLD_REACHED',
f'Cost threshold of ${max_str} USD reached. Current cost: ${current_str} USD. '
+ f'Next warning at ${next_threshold:.2f} USD. Please approve to continue.',
)
elif self.headless_mode:
e = RuntimeError(
f'Agent reached maximum {limit_type} in headless mode. '
f'Current {limit_type}: {current_str}, max {limit_type}: {max_str}'
Expand Down
5 changes: 5 additions & 0 deletions openhands/server/session/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ async def start(
config.security.confirmation_mode,
max_iterations,
max_budget_per_task=max_budget_per_task,
warning_budget_increment=5.0, # Default to 5.0 USD increments
agent_to_llm_config=agent_to_llm_config,
agent_configs=agent_configs,
)
Expand Down Expand Up @@ -232,6 +233,7 @@ def _run_replay(
config.security.confirmation_mode,
max_iterations,
max_budget_per_task=max_budget_per_task,
warning_budget_increment=5.0, # Default to 5.0 USD increments
agent_to_llm_config=agent_to_llm_config,
agent_configs=agent_configs,
replay_events=replay_events[1:],
Expand Down Expand Up @@ -340,6 +342,7 @@ def _create_controller(
confirmation_mode: bool,
max_iterations: int,
max_budget_per_task: float | None = None,
warning_budget_increment: float | None = None,
agent_to_llm_config: dict[str, LLMConfig] | None = None,
agent_configs: dict[str, AgentConfig] | None = None,
replay_events: list[Event] | None = None,
Expand All @@ -351,6 +354,7 @@ def _create_controller(
- confirmation_mode: Whether to use confirmation mode
- max_iterations:
- max_budget_per_task:
- warning_budget_increment: Budget increment at which to warn the user and ask for approval
- agent_to_llm_config:
- agent_configs:
"""
Expand Down Expand Up @@ -382,6 +386,7 @@ def _create_controller(
agent=agent,
max_iterations=int(max_iterations),
max_budget_per_task=max_budget_per_task,
warning_budget_increment=warning_budget_increment,
agent_to_llm_config=agent_to_llm_config,
agent_configs=agent_configs,
confirmation_mode=confirmation_mode,
Expand Down
Loading