-
Notifications
You must be signed in to change notification settings - Fork 78
/
Copy pathcodegate.py
184 lines (157 loc) · 7.79 KB
/
codegate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import json
import regex as re
import structlog
from litellm import ChatCompletionRequest
from codegate.clients.clients import ClientType
from codegate.db.models import AlertSeverity
from codegate.extract_snippets.factory import MessageCodeExtractorFactory
from codegate.pipeline.base import (
PipelineContext,
PipelineResult,
PipelineStep,
)
from codegate.storage.storage_engine import StorageEngine
from codegate.utils.package_extractor import PackageExtractor
from codegate.utils.utils import generate_vector_string
logger = structlog.get_logger("codegate")
# Pre-compiled regex patterns for performance
markdown_code_block = re.compile(r"```.*?```", flags=re.DOTALL)
markdown_file_listing = re.compile(r"⋮...*?⋮...\n\n", flags=re.DOTALL)
cline_file_listing = re.compile(r"(?i)<\s*file_content\s*[^>]*>.*?</\s*file_content\s*>", flags=re.DOTALL)
environment_details = re.compile(r"<environment_details>.*?</environment_details>", flags=re.DOTALL)
class CodegateContextRetriever(PipelineStep):
"""
Pipeline step that adds a context message to the completion request when it detects
the word "codegate" in the user message.
"""
@property
def name(self) -> str:
"""
Returns the name of this pipeline step.
"""
return "codegate-context-retriever"
def generate_context_str(
self, objects: list[object], context: PipelineContext, snippet_map: dict
) -> str:
context_str = ""
matched_packages = []
for obj in objects:
# The object is already a dictionary with 'properties'
package_obj = obj["properties"] # type: ignore
matched_packages.append(f"{package_obj['name']} ({package_obj['type']})")
# Retrieve the related snippet if it exists
code_snippet = snippet_map.get(package_obj["name"])
# Add one alert for each package found
context.add_alert(
self.name,
trigger_string=json.dumps(package_obj),
severity_category=AlertSeverity.CRITICAL,
code_snippet=code_snippet,
)
package_str = generate_vector_string(package_obj)
context_str += package_str + "\n"
if matched_packages:
logger.debug(
"Found matching packages in sqlite-vec database", matched_packages=matched_packages
)
return context_str
async def process( # noqa: C901
self, request: ChatCompletionRequest, context: PipelineContext
) -> PipelineResult:
"""
Use RAG DB to add context to the user request
"""
# Get the latest user message
last_message = self.get_last_user_message_block(request, context.client)
if not last_message:
return PipelineResult(request=request)
user_message, last_user_idx = last_message
# Create storage engine object
storage_engine = StorageEngine()
# Extract any code snippets
extractor = MessageCodeExtractorFactory.create_snippet_extractor(context.client)
snippets = extractor.extract_snippets(user_message)
bad_snippet_packages = []
snippet_map = {}
if snippets and len(snippets) > 0:
snippet_language = snippets[0].language
# Collect all packages referenced in the snippets
snippet_packages = []
for snippet in snippets:
extracted_packages = PackageExtractor.extract_packages(
snippet.code, snippet.language
)
snippet_packages.extend(extracted_packages)
for package in extracted_packages:
snippet_map[package] = snippet
logger.info(
f"Found {len(snippet_packages)} packages "
f"for language {snippet_language} in code snippets."
)
# Find bad packages in the snippets
bad_snippet_packages = await storage_engine.search(
language=snippet_language, packages=snippet_packages
) # type: ignore
logger.info(f"Found {len(bad_snippet_packages)} bad packages in code snippets.")
# Remove code snippets and file listing from the user messages and search for bad packages
# in the rest of the user query/messsages
user_messages = markdown_code_block.sub("", user_message)
user_messages = markdown_file_listing.sub("", user_messages)
user_messages = cline_file_listing.sub("", user_messages)
user_messages = environment_details.sub("", user_messages)
# split messages into double newlines, to avoid passing so many content in the search
split_messages = re.split(r"</?task>|\n|\\n", user_messages)
collected_bad_packages = []
for item_message in filter(None, map(str.strip, split_messages)):
# Vector search to find bad packages
bad_packages = await storage_engine.search(query=item_message, distance=0.5, limit=100)
if bad_packages and len(bad_packages) > 0:
collected_bad_packages.extend(bad_packages)
# All bad packages
all_bad_packages = bad_snippet_packages + collected_bad_packages
logger.info(f"Adding {len(all_bad_packages)} bad packages to the context.")
# Generate context string using the searched objects
context_str = "CodeGate did not find any malicious or archived packages."
# Nothing to do if no bad packages are found
if len(all_bad_packages) == 0:
return PipelineResult(request=request, context=context)
else:
# Add context for bad packages
context_str = self.generate_context_str(all_bad_packages, context, snippet_map)
context.bad_packages_found = True
# Make a copy of the request
new_request = request.copy()
# perform replacement in all the messages starting from this index
if context.client != ClientType.OPEN_INTERPRETER:
for i in range(last_user_idx, len(new_request["messages"])):
message = new_request["messages"][i]
message_str = str(message["content"]) # type: ignore
context_msg = message_str
# Add the context to the last user message
if context.client in [ClientType.CLINE, ClientType.KODU]:
match = re.search(r"<task>\s*(.*?)\s*</task>(.*)", message_str, re.DOTALL)
if match:
task_content = match.group(1) # Content within <task>...</task>
rest_of_message = match.group(
2
).strip() # Content after </task>, if any
# Embed the context into the task block
updated_task_content = (
f"<task>Context: {context_str}"
+ f"Query: {task_content.strip()}</task>"
)
# Combine updated task content with the rest of the message
context_msg = updated_task_content + rest_of_message
else:
context_msg = f"Context: {context_str} \n\n Query: {message_str}"
new_request["messages"][i]["content"] = context_msg
logger.debug("Final context message", context_message=context_msg)
else:
# just add a message in the end
new_request["messages"].append(
{
"content": context_str,
"role": "assistant",
}
)
return PipelineResult(request=new_request, context=context)