-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp_state.py
83 lines (65 loc) · 3.03 KB
/
app_state.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Literal
from datetime import datetime
from utils import GENERATE_SQL_SYSTEM_PROMPT, ANALYZE_SQL_SYSTEM_PROMPT, INFO_CLARIFICATION_SYSTEM_PROMPT
@dataclass
class Query:
    """Record of a single SQL query tracked by the application.

    ``sql`` and ``is_safe`` are required; ``output`` and ``analysis`` are
    optional and default to ``None`` until a value is assigned.
    """

    # Text of the SQL statement.
    sql: str
    # Flag stored alongside the statement; AppState filters its history on it.
    # NOTE(review): presumably the result of a safety check - confirm with the
    # code that builds Query objects.
    is_safe: bool
    # Optional text attached to the query; None when nothing has been stored.
    output: Optional[str] = None
    analysis: Optional[str] = None
@dataclass
class AppState:
    """Mutable container for the application's chat and query state.

    Holds the chat message list (whose first entry is the system message),
    a bounded history of ``Query`` records, the active response mode,
    database-connection bookkeeping fields, and a few user preferences.
    """

    # Chat messages as {"role": ..., "content": ...} dicts. Index 0 is the
    # system message whose content set_response_mode() overwrites.
    chat_context: List[Dict] = field(
        default_factory=lambda: [{"role": "system", "content": ""}]
    )
    # Query records, oldest first; trimmed to max_query_history on insert.
    # "Query" is a forward reference so the annotation is not evaluated when
    # the class body runs.
    query_history: List["Query"] = field(default_factory=list)
    # Response mode of the agent; selects the system prompt.
    response_mode: Literal["sql_generation", "sql_analysis", "info_clarification"] = "sql_generation"
    # Database connection state. No method of this class updates these fields.
    is_connected: bool = False
    last_connection_time: Optional[datetime] = None
    connection_errors: List[str] = field(default_factory=list)
    # User preferences. Only max_query_history is read by this class.
    max_query_history: int = 100
    show_execution_time: bool = True
    auto_analyze_results: bool = True

    def add_to_chat_context(self, role: str, content: str) -> None:
        """Append a ``{"role": role, "content": content}`` message to the chat."""
        self.chat_context.append({"role": role, "content": content})

    def add_query_to_history(self, query: "Query") -> None:
        """Append ``query`` and drop the oldest entries beyond the size limit.

        The history is trimmed all the way down to ``max_query_history``, so
        lowering the limit at runtime takes effect on the next insert. A limit
        of zero or less keeps no entries.
        """
        self.query_history.append(query)
        overflow = len(self.query_history) - max(self.max_query_history, 0)
        if overflow > 0:
            # Oldest entries sit at the front of the list.
            del self.query_history[:overflow]

    def clear_chat_context(self) -> None:
        """Reset the chat to a single system message with empty content.

        The previous system prompt is not kept; ``set_response_mode`` writes
        a prompt into the system message again.
        """
        self.chat_context = [{"role": "system", "content": ""}]

    def clear_query_history(self) -> None:
        """Remove every entry from the query history (in place)."""
        self.query_history.clear()

    def get_recent_queries(self, limit: int = 5) -> List["Query"]:
        """Return up to ``limit`` of the most recent queries, oldest first.

        A ``limit`` of zero or less yields an empty list.
        """
        if limit <= 0:
            # Without this guard, [-0:] would return the whole history.
            return []
        return self.query_history[-limit:]

    def get_successful_queries(self) -> List["Query"]:
        """Return the queries in the history whose ``is_safe`` flag is truthy."""
        return [q for q in self.query_history if q.is_safe]

    def get_failed_queries(self) -> List["Query"]:
        """Return the queries in the history whose ``is_safe`` flag is falsy."""
        return [q for q in self.query_history if not q.is_safe]

    def set_response_mode(self, mode: Literal["sql_generation", "sql_analysis", "info_clarification"]) -> None:
        """Store ``mode`` and write the matching system prompt into the chat.

        The prompt replaces the content of the first chat message. If the chat
        context is empty, a system message carrying the prompt is added
        instead. A mode with no associated prompt is still stored, but the
        chat context is left unchanged.
        """
        self.response_mode = mode
        prompts = {
            "sql_generation": GENERATE_SQL_SYSTEM_PROMPT,
            "sql_analysis": ANALYZE_SQL_SYSTEM_PROMPT,
            "info_clarification": INFO_CLARIFICATION_SYSTEM_PROMPT,
        }
        if mode not in prompts:
            return
        if self.chat_context:
            self.chat_context[0]["content"] = prompts[mode]
        else:
            # No message to overwrite: create the system message.
            self.chat_context.append({"role": "system", "content": prompts[mode]})
# Module-level shared AppState instance, created when this module is imported.
app_state = AppState()
# A fresh AppState starts with an empty system message; selecting the default
# mode here writes GENERATE_SQL_SYSTEM_PROMPT into chat_context[0].
app_state.set_response_mode("sql_generation")