Skip to content

Commit b9a9985

Browse files
committed
wip
1 parent 7ccded6 commit b9a9985

File tree

7 files changed

+169
-123
lines changed

7 files changed

+169
-123
lines changed

examples/olsconfig.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ ols_config:
6060
# product_docs_index_path: "./vector_db/ocp_product_docs/4.15"
6161
# product_docs_index_id: ocp-product-docs-4_15
6262
# embeddings_model_path: "./embeddings_model"
63-
introspection_enabled: true # Default is false, OLS tool calling
6463
conversation_cache:
6564
type: memory
6665
memory:
@@ -117,4 +116,4 @@ dev_config:
117116
# uvicorn_port_number: 8081
118117
# llm_params:
119118
# temperature_override: 0
120-
# k8s_auth_token: optional_token_when_no_available_kube_config
119+
# k8s_auth_token: optional_token_when_no_available_kube_config

ols/app/models/config.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -913,7 +913,6 @@ class OLSConfig(BaseModel):
913913
"""OLS configuration."""
914914

915915
conversation_cache: Optional[ConversationCacheConfig] = None
916-
introspection_enabled: Optional[bool] = False
917916
logging_config: Optional[LoggingConfig] = None
918917
reference_content: Optional[ReferenceContent] = None
919918
authentication_config: AuthenticationConfig = AuthenticationConfig()
@@ -944,7 +943,6 @@ def __init__(
944943
if data is None:
945944
return
946945

947-
self.introspection_enabled = data.get("introspection_enabled", False)
948946
self.conversation_cache = ConversationCacheConfig(
949947
data.get("conversation_cache", None)
950948
)
@@ -994,7 +992,6 @@ def __eq__(self, other: object) -> bool:
994992
if isinstance(other, OLSConfig):
995993
return (
996994
self.conversation_cache == other.conversation_cache
997-
and self.introspection_enabled == other.introspection_enabled
998995
and self.logging_config == other.logging_config
999996
and self.reference_content == other.reference_content
1000997
and self.default_provider == other.default_provider
@@ -1111,13 +1108,21 @@ def validate_token_is_set_when_needed(self) -> Self:
11111108
return self
11121109

11131110

1111+
class Tool(BaseModel):
1112+
"""Tool definition."""
1113+
1114+
name: str
1115+
type: Literal["tool-set"]
1116+
1117+
11141118
class Config(BaseModel):
11151119
"""Global service configuration."""
11161120

11171121
llm_providers: LLMProviders = LLMProviders()
11181122
ols_config: OLSConfig = OLSConfig()
11191123
dev_config: DevConfig = DevConfig()
11201124
user_data_collector_config: Optional[UserDataCollectorConfig] = None
1125+
tools: list[Tool] = []
11211126

11221127
def __init__(
11231128
self,
@@ -1148,7 +1153,9 @@ def __init__(
11481153
self.user_data_collector_config = UserDataCollectorConfig(
11491154
**data.get("user_data_collector_config", {})
11501155
)
1156+
self.tools = [Tool(**tool) for tool in data.get("tools", [])]
11511157

1158+
# TODO: these comparisons are completely unnecessary - delete them
11521159
def __eq__(self, other: object) -> bool:
11531160
"""Compare two objects for equality."""
11541161
if isinstance(other, Config):

ols/plugins/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@ def import_modules_from_dir(dir_name: str) -> None:
2525

2626

2727
# import_modules_from_dir("providers")
28-
# import_modules_from_dir("tools")
28+
import_modules_from_dir("tools")

ols/src/tools/oc_cli.py renamed to ols/plugins/tools/openshift.py

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from langchain.tools import tool
1818
from langchain_core.tools import InjectedToolArg
1919

20+
from ols.src.tools.tools import ToolSetProvider, register_tool_provider_as
21+
2022
logger = logging.getLogger(__name__)
2123

2224

@@ -60,41 +62,12 @@ def run_oc(args: list[str]) -> subprocess.CompletedProcess:
6062
["oc", *args], # noqa: S607
6163
capture_output=True,
6264
text=True,
63-
check=False,
65+
check=True,
6466
shell=False,
6567
)
6668
return res
6769

6870

69-
def token_works_for_oc(token: str) -> bool:
70-
"""Check if the token can be used with `oc` CLI.
71-
72-
Args:
73-
token: OpenShift user token.
74-
75-
Returns:
76-
True if user token works, False otherwise.
77-
"""
78-
r = run_oc(["version", f"--token={token}"])
79-
80-
if r.returncode == 0:
81-
logger.info("Token is usable for oc CLI")
82-
return True
83-
84-
logger.error(
85-
"Unable to use the token for oc CLI; stdout: %s, stderr: %s",
86-
r.stdout,
87-
r.stderr,
88-
)
89-
return False
90-
91-
92-
def stdout_or_stderr(result: subprocess.CompletedProcess) -> str:
93-
"""Return stdout if return code is 0, otherwise return stderr."""
94-
return result.stdout if result.returncode == 0 else result.stderr
95-
96-
97-
# NOTE: tools description comes from oc cli --help for each subcommand (shortened)
9871
@tool
9972
def oc_get(oc_get_args: list[str], token: Annotated[str, InjectedToolArg]) -> str:
10073
"""Display one or many resources from OpenShift cluster.
@@ -130,7 +103,7 @@ def oc_get(oc_get_args: list[str], token: Annotated[str, InjectedToolArg]) -> st
130103
oc get rc,services
131104
"""
132105
result = run_oc(["get", *sanitize_oc_args(oc_get_args), "--token", token])
133-
return stdout_or_stderr(result)
106+
return result.stdout
134107

135108

136109
@tool
@@ -165,7 +138,7 @@ def oc_describe(
165138
oc describe pods frontend
166139
""" # noqa: E501
167140
result = run_oc(["describe", *sanitize_oc_args(oc_describe_args), "--token", token])
168-
return stdout_or_stderr(result)
141+
return result.stdout
169142

170143

171144
@tool
@@ -193,7 +166,7 @@ def oc_logs(oc_logs_args: list[str], token: Annotated[str, InjectedToolArg]) ->
193166
oc logs -f pod/backend -c ruby-container
194167
""" # noqa: E501
195168
result = run_oc(["logs", *sanitize_oc_args(oc_logs_args), "--token", token])
196-
return stdout_or_stderr(result)
169+
return result.stdout
197170

198171

199172
@tool
@@ -219,7 +192,7 @@ def oc_status(oc_status_args: list[str], token: Annotated[str, InjectedToolArg])
219192
oc --suggest
220193
"""
221194
result = run_oc(["status", *sanitize_oc_args(oc_status_args), "--token", token])
222-
return stdout_or_stderr(result)
195+
return result.stdout
223196

224197

225198
@tool
@@ -237,7 +210,7 @@ def show_pods(token: Annotated[str, InjectedToolArg]) -> str:
237210
kube-system kube-apiserver-proxy-ip-10-0-130-91.ec2.internal 2m 13Mi
238211
"""
239212
result = run_oc([*["adm", "top", "pods", "-A"], "--token", token])
240-
return stdout_or_stderr(result)
213+
return result.stdout
241214

242215

243216
@tool
@@ -270,4 +243,36 @@ def oc_adm_top(
270243
result = run_oc(
271244
["adm", "top", *sanitize_oc_args(oc_adm_top_args), "--token", token]
272245
)
273-
return stdout_or_stderr(result)
246+
return result.stdout
247+
248+
249+
@register_tool_provider_as("openshift")
250+
class OCToolProvider(ToolSetProvider):
251+
"""Provider for OpenShift CLI tools."""
252+
253+
@property
254+
def tools(self):
255+
"""Get all OC tools."""
256+
return {
257+
"oc_get": oc_get,
258+
"oc_describe": oc_describe,
259+
"oc_logs": oc_logs,
260+
"oc_adm_top": oc_adm_top,
261+
"oc_status": oc_status,
262+
"show_pods": show_pods,
263+
}
264+
265+
# TODO: needs rebase for #2391
266+
def execute_tool(self, tool_name, tool_args, context) -> tuple[str, str]:
267+
"""Execute an OC tool with the given arguments and context."""
268+
tool = self.tools[tool_name]
269+
if not context.user_token:
270+
return "Error: No user token provided", "error"
271+
272+
# add token to arguments
273+
args_with_token = {**tool_args, "token": context.user_token}
274+
275+
try:
276+
return tool.invoke(args_with_token), "success"
277+
except Exception as e:
278+
return f"Error: {e}", "error"

ols/src/query_helpers/docs_summarizer.py

Lines changed: 50 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
from typing import Any, AsyncGenerator, Optional
55

66
from langchain.globals import set_debug
7-
from langchain_core.messages import AIMessage, BaseMessage
7+
from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
88
from langchain_core.prompts import ChatPromptTemplate
9-
from langchain_core.tools.base import BaseTool
109
from llama_index.core import VectorStoreIndex
1110

1211
from ols import config
@@ -16,13 +15,34 @@
1615
from ols.customize import reranker
1716
from ols.src.prompts.prompt_generator import GeneratePrompt
1817
from ols.src.query_helpers.query_helper import QueryHelper
19-
from ols.src.tools.oc_cli import token_works_for_oc
20-
from ols.src.tools.tools import execute_oc_tool_calls, oc_tools
18+
from ols.src.tools.tools import ToolProvidersRegistry, ToolsContext
2119
from ols.utils.token_handler import TokenHandler
2220

2321
logger = logging.getLogger(__name__)
2422

2523

24+
def execute_tools(tool_calls, tools_provider_map, tools_context):
25+
"""Execute tools based on the tool calls and context."""
26+
tool_messages = []
27+
for tool_call in tool_calls:
28+
tool_name = tool_call.get("name", "").lower()
29+
tool_args = tool_call.get("args", {})
30+
tool_id = tool_call.get("id")
31+
32+
tool_provider = tools_provider_map.get(tool_name)
33+
if tool_provider is None:
34+
logger.error("Error: unknown tool '%s'", tool_name)
35+
continue
36+
37+
tool_result, status = tool_provider.execute_tool(
38+
tool_name, tool_args, tools_context
39+
)
40+
tool_messages.append(
41+
ToolMessage(tool_result, status=status, tool_call_id=tool_id)
42+
)
43+
return tool_messages
44+
45+
2646
class DocsSummarizer(QueryHelper):
2747
"""A class for summarizing documentation context."""
2848

@@ -31,7 +51,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
3151
super().__init__(*args, **kwargs)
3252
self._prepare_llm()
3353
self.verbose = config.ols_config.logging_config.app_log_level == logging.DEBUG
34-
self._introspection_enabled = config.ols_config.introspection_enabled
54+
55+
self.tool_providers = [
56+
ToolProvidersRegistry.tool_providers[tool_provider.name]
57+
for tool_provider in config.config.tools
58+
]
3559

3660
# disabled - leaks token to logs when set to True
3761
set_debug(False)
@@ -87,7 +111,7 @@ def _prepare_prompt(
87111
["sample"],
88112
[AIMessage("sample")],
89113
self._system_prompt,
90-
self._introspection_enabled,
114+
True if self.tool_providers else False,
91115
).generate_prompt(self.model)
92116
available_tokens = token_handler.calculate_and_check_available_tokens(
93117
temp_prompt.format(**temp_prompt_input),
@@ -120,7 +144,7 @@ def _prepare_prompt(
120144
rag_context,
121145
history,
122146
self._system_prompt,
123-
self._introspection_enabled,
147+
True if self.tool_providers else False,
124148
).generate_prompt(self.model)
125149

126150
# Tokens-check: We trigger the computation of the token count
@@ -161,19 +185,6 @@ def _invoke_llm(
161185
)
162186
return out, generic_token_counter.token_counter
163187

164-
def _get_available_tools(self, user_token: Optional[str]) -> dict[str, BaseTool]:
165-
"""Get available tools based on introspection and user token."""
166-
if not self._introspection_enabled:
167-
return {}
168-
169-
logger.info("Introspection enabled - using default tools selection")
170-
171-
if user_token and user_token.strip() and token_works_for_oc(user_token):
172-
logger.info("Authenticated to 'oc' CLI; adding 'oc' tools")
173-
return oc_tools
174-
175-
return {}
176-
177188
def create_response(
178189
self,
179190
query: str,
@@ -189,22 +200,30 @@ def create_response(
189200
messages = final_prompt.model_copy()
190201
tool_calls = []
191202

192-
# TODO: for the specific tools type (oc) we need specific additional
193-
# context (user_token) to get the tools, we need to think how to make
194-
# it more generic to avoid low-level code changes with new tools type
195-
tools_map = self._get_available_tools(user_token)
203+
tools_context = ToolsContext(user_token=user_token)
204+
205+
# map to hold all tools for registering to llm
206+
tools_map = {}
207+
208+
# map to tell what tool belongs to what tool provider
209+
tools_provider_map = {}
210+
211+
for tool_provider in self.tool_providers:
212+
tools = tool_provider.tools
213+
tools_map.update(tools)
214+
for tool in tools.keys():
215+
# TODO: raise if key already exists
216+
tools_provider_map[tool] = tool_provider
196217

197218
# TODO: Tune system prompt
198219
# TODO: Handle context for each iteration
199220
# TODO: Handle tokens for tool response
200221
# TODO: Improvement for granite
201222
for i in range(MAX_ITERATIONS):
202223

203-
# Force llm to give final response when introspection is disabled
224+
# Force llm to give final response when tools are not provided
204225
# or max iteration is reached
205-
is_final_round = (not self._introspection_enabled) or (
206-
i == MAX_ITERATIONS - 1
207-
)
226+
is_final_round = (not tools_map) or (i == MAX_ITERATIONS - 1)
208227
out, token_counter = self._invoke_llm(
209228
messages, llm_input_values, tools_map, is_final_round
210229
)
@@ -228,8 +247,9 @@ def create_response(
228247
tool_calls.append(
229248
[ToolCall.from_langchain_tool_call(t) for t in out.tool_calls]
230249
)
231-
tool_calls_messages = execute_oc_tool_calls(
232-
tools_map, out.tool_calls, user_token
250+
251+
tool_calls_messages = execute_tools(
252+
out.tool_calls, tools_provider_map, tools_context
233253
)
234254
messages.extend(tool_calls_messages)
235255

0 commit comments

Comments
 (0)