"""Client-specific helpers that strip code and file content from user messages."""

import re
from abc import ABC, abstractmethod
from typing import Dict, Type

from codegate.clients.clients import ClientType


class ClientInterface(ABC):
    """Contract for client-specific message processing."""

    @abstractmethod
    def strip_code_snippets(self, message: str) -> str:
        """Return ``message`` with code blocks and file listings removed.

        Args:
            message: Raw user message text.

        Returns:
            The message with every region the implementation recognises as
            code or file content deleted.
        """


class GenericClient(ClientInterface):
    """Default implementation: strips markdown code and generic context blocks."""

    # Fenced markdown code blocks, including multi-line ones (DOTALL).
    _MARKDOWN_CODE_REGEX = re.compile(r"```.*?```", re.DOTALL)
    # File listings delimited by "⋮" markers and followed by a blank line.
    # NOTE(review): the dots are unescaped regex wildcards, so this is slightly
    # broader than the literal marker "⋮..."; kept as-is to preserve behaviour.
    _MARKDOWN_FILE_LISTING = re.compile(r"⋮...*?⋮...\n\n", flags=re.DOTALL)
    # The explicit closing tag is required: a bare lazy ".*?" only ever matches
    # the empty string, which would make the substitution a no-op.
    _ENVIRONMENT_DETAILS = re.compile(
        r"<environment_details>.*?</environment_details>", flags=re.DOTALL
    )

    # Matches a line of the form "codegate <args>" and captures the arguments.
    # Not used by this class itself.
    _CLI_REGEX = re.compile(r"^codegate\s+(.*)$", re.IGNORECASE)

    def strip_code_snippets(self, message: str) -> str:
        """Remove fenced code, file listings and environment-details blocks.

        Args:
            message: Raw user message text.

        Returns:
            ``message`` with all three kinds of region deleted.
        """
        message = self._MARKDOWN_CODE_REGEX.sub("", message)
        message = self._MARKDOWN_FILE_LISTING.sub("", message)
        message = self._ENVIRONMENT_DETAILS.sub("", message)
        return message


class ClineClient(ClientInterface):
    """Cline-specific stripping, layered on top of :class:`GenericClient`."""

    # A complete <file_content ...>...</file_content> element, matched
    # case-insensitively and across lines. Without the closing tag the lazy
    # ".*?" would stop immediately and only the opening tag would be removed.
    _CLINE_FILE_REGEX = re.compile(
        r"(?i)<\s*file_content\s*[^>]*>.*?</\s*file_content\s*>", re.DOTALL
    )

    def __init__(self):
        # Composition rather than inheritance: the generic pass runs first.
        self.generic_client = GenericClient()

    def strip_code_snippets(self, message: str) -> str:
        """Apply the generic stripping, then remove ``<file_content>`` elements.

        Args:
            message: Raw user message text.

        Returns:
            ``message`` without generic code regions or file-content elements.
        """
        message = self.generic_client.strip_code_snippets(message)
        return self._CLINE_FILE_REGEX.sub("", message)


class ClientFactory:
    """Maps a :class:`ClientType` to its :class:`ClientInterface` implementation."""

    # KODU reuses the Cline implementation.
    _implementations: Dict[ClientType, Type[ClientInterface]] = {
        ClientType.GENERIC: GenericClient,
        ClientType.CLINE: ClineClient,
        ClientType.KODU: ClineClient,
    }

    @classmethod
    def create(cls, client_type: ClientType) -> ClientInterface:
        """Instantiate the implementation registered for ``client_type``.

        Args:
            client_type: The client to build an interface for.

        Returns:
            A new instance of the mapped class, or a :class:`GenericClient`
            when ``client_type`` has no explicit mapping.
        """
        return cls._implementations.get(client_type, GenericClient)()
a/src/codegate/pipeline/codegate_context_retriever/codegate.py b/src/codegate/pipeline/codegate_context_retriever/codegate.py index e22874a6..12a6d7d1 100644 --- a/src/codegate/pipeline/codegate_context_retriever/codegate.py +++ b/src/codegate/pipeline/codegate_context_retriever/codegate.py @@ -5,6 +5,7 @@ from litellm import ChatCompletionRequest from codegate.clients.clients import ClientType +from codegate.clients.interface import ClientFactory from codegate.db.models import AlertSeverity from codegate.extract_snippets.factory import MessageCodeExtractorFactory from codegate.pipeline.base import ( @@ -22,6 +23,9 @@ # Pre-compiled regex patterns for performance markdown_code_block = re.compile(r"```.*?```", flags=re.DOTALL) markdown_file_listing = re.compile(r"⋮...*?⋮...\n\n", flags=re.DOTALL) +cline_file_listing = re.compile( + r"(?i)<\s*file_content\s*[^>]*>.*?", flags=re.DOTALL +) environment_details = re.compile(r".*?", flags=re.DOTALL) @@ -112,12 +116,11 @@ async def process( # noqa: C901 # Remove code snippets and file listing from the user messages and search for bad packages # in the rest of the user query/messsages - user_messages = markdown_code_block.sub("", user_message) - user_messages = markdown_file_listing.sub("", user_messages) - user_messages = environment_details.sub("", user_messages) + client_if = ClientFactory.create(context.client) + non_code_user_message = client_if.strip_code_snippets(user_message) # split messages into double newlines, to avoid passing so many content in the search - split_messages = re.split(r"|\n|\\n", user_messages) + split_messages = re.split(r"|\n|\\n", non_code_user_message) collected_bad_packages = [] for item_message in filter(None, map(str.strip, split_messages)): # Vector search to find bad packages