Skip to content

Wire in Persona matching to muxing #1244

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -9,8 +9,6 @@
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "e4c05d7591a8"
2 changes: 1 addition & 1 deletion src/codegate/cli.py
Original file line number Diff line number Diff line change
@@ -16,8 +16,8 @@
from codegate.config import Config, ConfigurationError
from codegate.db.connection import (
init_db_sync,
init_session_if_not_exists,
init_instance,
init_session_if_not_exists,
)
from codegate.pipeline.factory import PipelineFactory
from codegate.pipeline.sensitive_data.manager import SensitiveDataManager
3 changes: 3 additions & 0 deletions src/codegate/config.py
Original file line number Diff line number Diff line change
@@ -65,6 +65,9 @@ class Config:
# The value 0.3 was found through experimentation. See /tests/muxing/test_semantic_router.py
# It's the threshold value to determine if a persona description is similar to existing personas
persona_diff_desc_threshold = 0.3
# Weight factor for distances in the persona description similarity calculation. Check
# the function _weight_distances for more details. Range is [0, 1].
distances_weight_factor = 0.8

# Provider URLs with defaults
provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy())
41 changes: 19 additions & 22 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
@@ -617,7 +617,7 @@ async def init_instance(self) -> None:
await self._execute_with_no_return(sql, instance.model_dump())
except IntegrityError as e:
logger.debug(f"Exception type: {type(e)}")
raise AlreadyExistsError(f"Instance already initialized.")
raise AlreadyExistsError("Instance already initialized.")


class DbReader(DbCodeGate):
@@ -1059,6 +1059,24 @@ async def get_persona_by_name(self, persona_name: str) -> Optional[Persona]:
)
return personas[0] if personas else None

async def get_persona_embed_by_name(self, persona_name: str) -> Optional[PersonaEmbedding]:
"""
Get a persona by name.
"""
sql = text(
"""
SELECT
id, name, description, description_embedding
FROM personas
WHERE name = :name
"""
)
conditions = {"name": persona_name}
personas = await self._exec_select_conditions_to_pydantic(
PersonaEmbedding, sql, conditions, should_raise=True
)
return personas[0] if personas else None

async def get_distance_to_existing_personas(
self, query_embedding: np.ndarray, exclude_id: Optional[str]
) -> List[PersonaDistance]:
@@ -1086,27 +1104,6 @@ async def get_distance_to_existing_personas(
)
return persona_distances

async def get_distance_to_persona(
self, persona_id: str, query_embedding: np.ndarray
) -> PersonaDistance:
"""
Get the distance between a persona and a query embedding.
"""
sql = """
SELECT
id,
name,
description,
vec_distance_cosine(description_embedding, :query_embedding) as distance
FROM personas
WHERE id = :id
"""
conditions = {"id": persona_id, "query_embedding": query_embedding}
persona_distance = await self._exec_vec_db_query_to_pydantic(
sql, conditions, PersonaDistance
)
return persona_distance[0]

async def get_all_personas(self) -> List[Persona]:
"""
Get all the personas.
2 changes: 1 addition & 1 deletion src/codegate/db/models.py
Original file line number Diff line number Diff line change
@@ -276,7 +276,7 @@ def nd_array_custom_serializer(x):
NdArray = Annotated[
np.ndarray,
BeforeValidator(nd_array_custom_before_validator),
PlainSerializer(nd_array_custom_serializer, return_type=str),
PlainSerializer(nd_array_custom_serializer, return_type=np.ndarray),
]

VALID_PERSONA_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_ -]+$")
4 changes: 4 additions & 0 deletions src/codegate/muxing/models.py
Original file line number Diff line number Diff line change
@@ -32,6 +32,10 @@ class MuxMatcherType(str, Enum):
fim_filename = "fim_filename"
# Match based on chat request type. It will match if the request type is chat
chat_filename = "chat_filename"
# Match the user messages to the persona description
persona_description = "persona_description"
# Match the system prompt to the persona description
sys_prompt_persona_desc = "sys_prompt_persona_desc"


class MuxRule(pydantic.BaseModel):
103 changes: 85 additions & 18 deletions src/codegate/muxing/persona.py
Original file line number Diff line number Diff line change
@@ -43,6 +43,7 @@ def __init__(self):
self._n_gpu = conf.chat_model_n_gpu_layers
self._persona_threshold = conf.persona_threshold
self._persona_diff_desc_threshold = conf.persona_diff_desc_threshold
self._distances_weight_factor = conf.distances_weight_factor
self._db_recorder = DbRecorder()
self._db_reader = DbReader()

@@ -99,18 +100,17 @@ def _clean_text_for_embedding(self, text: str) -> str:

return text

async def _embed_text(self, text: str) -> np.ndarray:
async def _embed_texts(self, texts: List[str]) -> np.ndarray:
"""
Helper function to embed text using the inference engine.
"""
cleaned_text = self._clean_text_for_embedding(text)
cleaned_texts = [self._clean_text_for_embedding(text) for text in texts]
# .embed returns a list of embeddings
embed_list = await self._inference_engine.embed(
self._embeddings_model, [cleaned_text], n_gpu_layers=self._n_gpu
self._embeddings_model, cleaned_texts, n_gpu_layers=self._n_gpu
)
# Use only the first entry in the list and make sure we have the appropriate type
logger.debug("Text embedded in semantic routing", text=cleaned_text[:50])
return np.array(embed_list[0], dtype=np.float32)
logger.debug("Text embedded in semantic routing", num_texts=len(texts))
return np.array(embed_list, dtype=np.float32)

async def _is_persona_description_diff(
self, emb_persona_desc: np.ndarray, exclude_id: Optional[str]
@@ -142,7 +142,8 @@ async def _validate_persona_description(
Validate the persona description by embedding the text and checking if it is
different enough from existing personas.
"""
emb_persona_desc = await self._embed_text(persona_desc)
emb_persona_desc_list = await self._embed_texts([persona_desc])
emb_persona_desc = emb_persona_desc_list[0]
if not await self._is_persona_description_diff(emb_persona_desc, exclude_id):
raise PersonaSimilarDescriptionError(
"The persona description is too similar to existing personas."
@@ -217,21 +218,87 @@ async def delete_persona(self, persona_name: str) -> None:
await self._db_recorder.delete_persona(persona.id)
logger.info(f"Deleted persona {persona_name} from the database.")

async def check_persona_match(self, persona_name: str, query: str) -> bool:
async def _get_cosine_distance(self, emb_queries: np.ndarray, emb_persona: np.ndarray) -> float:
"""
Check if the query matches the persona description. A vector similarity
search is performed between the query and the persona description.
Calculate the cosine distance between the queries embeddings and persona embedding.
Persona embedding is a single vector of length M
Queries embeddings is a matrix of shape (N, M)
N is the number of queries. User messages in this case.
M is the number of dimensions in the embedding

Defintion of cosine distance: 1 - cosine similarity
[Cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity)

NOTE: Experimented by individually querying SQLite for each query, but as the number
of queries increases, the performance is better with NumPy. If the number of queries
is small the performance is onpar. Hence the decision to use NumPy.
"""
# Handle the case where we have a single query (single user message)
if emb_queries.ndim == 1:
emb_queries = emb_queries.reshape(1, -1)

emb_queries_norm = np.linalg.norm(emb_queries, axis=1)
persona_embed_norm = np.linalg.norm(emb_persona)
cosine_similarities = np.dot(emb_queries, emb_persona.T) / (
emb_queries_norm * persona_embed_norm
)
# We could also use directly cosine_similarities but we get the distance to match
# the behavior of SQLite function vec_distance_cosine
cosine_distances = 1 - cosine_similarities
return cosine_distances

async def _weight_distances(self, distances: np.ndarray) -> np.ndarray:
"""
Weights the received distances, with later positions being more important and the
last position unchanged. The reasoning is that the distances correspond to user
messages, with the last message being the most recent and therefore the most
important.

Args:
distances: NumPy array of float values between 0 and 2
weight_factor: Factor that determines how quickly weights increase (0-1)
Lower values create a steeper importance curve. 1 makes
all weights equal.

Returns:
Weighted distances as a NumPy array
"""
# Get array length
n = len(distances)

# Create positions array in reverse order (n-1, n-2, ..., 1, 0)
# This makes the last element have position 0
positions = np.arange(start=n - 1, stop=-1, step=-1, dtype=np.float32)

# Create weights - now the last element (position 0) gets weight 1
weights = self._distances_weight_factor**positions

# Apply weights by dividing distances
# Smaller weight -> larger effective distance
weighted_distances = distances / weights
return weighted_distances

async def check_persona_match(self, persona_name: str, queries: List[str]) -> bool:
"""
Check if the queries match the persona description. A vector similarity
search is performed between the queries and the persona description.
0 means the vectors are identical, 2 means they are orthogonal.
See
[sqlite docs](https://alexgarcia.xyz/sqlite-vec/api-reference.html#vec_distance_cosine)

The vectors are compared using cosine similarity implemented in _get_cosine_distance.
"""
persona = await self._db_reader.get_persona_by_name(persona_name)
if not persona:
persona_embed = await self._db_reader.get_persona_embed_by_name(persona_name)
if not persona_embed:
raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.")

emb_query = await self._embed_text(query)
persona_distance = await self._db_reader.get_distance_to_persona(persona.id, emb_query)
logger.info(f"Persona distance to {persona_name}", distance=persona_distance.distance)
if persona_distance.distance < self._persona_threshold:
emb_queries = await self._embed_texts(queries)
cosine_distances = await self._get_cosine_distance(
emb_queries, persona_embed.description_embedding
)
logger.debug("Cosine distances calculated", cosine_distances=cosine_distances)

weighted_distances = await self._weight_distances(cosine_distances)
logger.info("Weighted distances to persona", weighted_distances=weighted_distances)

if np.any(weighted_distances < self._persona_threshold):
return True
return False
94 changes: 89 additions & 5 deletions src/codegate/muxing/rulematcher.py
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@
from codegate.extract_snippets.body_extractor import BodyCodeSnippetExtractorError
from codegate.extract_snippets.factory import BodyCodeExtractorFactory
from codegate.muxing import models as mux_models
from codegate.muxing.persona import PersonaManager

logger = structlog.get_logger("codegate")

@@ -60,7 +61,7 @@ def __init__(self, route: ModelRoute, mux_rule: mux_models.MuxRule):
self._mux_rule = mux_rule

@abstractmethod
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
async def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
"""Return True if the rule matches the thing_to_match."""
pass

@@ -82,6 +83,8 @@ def create(db_mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatch
mux_models.MuxMatcherType.filename_match: FileMuxingRuleMatcher,
mux_models.MuxMatcherType.fim_filename: RequestTypeAndFileMuxingRuleMatcher,
mux_models.MuxMatcherType.chat_filename: RequestTypeAndFileMuxingRuleMatcher,
mux_models.MuxMatcherType.persona_description: UserMsgsPersonaDescMuxMatcher,
mux_models.MuxMatcherType.sys_prompt_persona_desc: SysPromptPersonaDescMuxMatcher,
}

try:
@@ -95,7 +98,7 @@ def create(db_mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatch
class CatchAllMuxingRuleMatcher(MuxingRuleMatcher):
"""A catch all muxing rule matcher."""

def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
async def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
logger.info("Catch all rule matched")
return True

@@ -130,7 +133,7 @@ def _is_matcher_in_filenames(self, detected_client: ClientType, data: dict) -> b
)
return is_filename_match

def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
async def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
"""
Return True if the matcher is in one of the request filenames.
"""
@@ -154,7 +157,7 @@ def _is_request_type_match(self, is_fim_request: bool) -> bool:
return True
return False

def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
async def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
"""
Return True if the matcher is in one of the request filenames and
if the request type matches the MuxMatcherType.
@@ -171,6 +174,87 @@ def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
return is_rule_matched


class PersonaDescMuxMatcher(MuxingRuleMatcher):
"""Muxing rule to match the request content to a persona description."""

@abstractmethod
def _get_queries_for_persona_match(self, body: Dict) -> List[str]:
"""
Get the queries to use for persona matching.
"""
pass

async def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
"""
Return True if the matcher is the persona description matched with the queries.
The queries are extracted from the body and will depend on the type of matcher.
1. UserMessagesPersonaDescMuxMatcher: Extracts queries from the user messages in the body.
2. SysPromptPersonaDescMuxMatcher: Extracts queries from the system messages in the body.
"""
queries = self._get_queries_for_persona_match(thing_to_match.body)
if not queries:
return False

persona_manager = PersonaManager()
is_persona_matched = await persona_manager.check_persona_match(
persona_name=self._mux_rule.matcher, queries=queries
)
if is_persona_matched:
logger.info("Persona rule matched", persona=self._mux_rule.matcher)
return is_persona_matched


class UserMsgsPersonaDescMuxMatcher(PersonaDescMuxMatcher):

def _get_queries_for_persona_match(self, body: Dict) -> List[str]:
"""
Get the queries from the user messages in the body.
"""
user_messages = []
for msg in body.get("messages", []):
if msg.get("role", "") == "user":
msgs_content = msg.get("content")
if not msgs_content:
continue
if isinstance(msgs_content, list):
for msg_content in msgs_content:
if msg_content.get("type", "") == "text":
user_messages.append(msg_content.get("text", ""))
elif isinstance(msgs_content, str):
user_messages.append(msgs_content)
return user_messages


class SysPromptPersonaDescMuxMatcher(PersonaDescMuxMatcher):

def _get_queries_for_persona_match(self, body: Dict) -> List[str]:
"""
Get the queries from the system messages in the body.
"""
system_messages = []
for msg in body.get("messages", []):
if msg.get("role", "") in ["system", "developer"]:
msgs_content = msg.get("content")
if not msgs_content:
continue
if isinstance(msgs_content, list):
for msg_content in msgs_content:
if msg_content.get("type", "") == "text":
system_messages.append(msg_content.get("text", ""))
elif isinstance(msgs_content, str):
system_messages.append(msgs_content)

# Handling the anthropic system prompt
anthropic_sys_prompt = body.get("system")
if anthropic_sys_prompt:
system_messages.append(anthropic_sys_prompt)

# In an ideal world, the length of system_messages should be 1. Returnin the list
# to handle any edge cases and to not break parent function's signature.
return system_messages


class MuxingRulesinWorkspaces:
"""A thread safe dictionary to store the muxing rules in workspaces."""

@@ -214,7 +298,7 @@ async def get_match_for_active_workspace(
try:
rules = await self.get_ws_rules(self._active_workspace)
for rule in rules:
if rule.match(thing_to_match):
if await rule.match(thing_to_match):
return rule.destination()
return None
except KeyError:
62 changes: 59 additions & 3 deletions tests/muxing/test_persona.py
Original file line number Diff line number Diff line change
@@ -90,7 +90,7 @@ async def test_persona_not_exist_match(semantic_router_mocked_db: PersonaManager
persona_name = "test_persona"
query = "test_query"
with pytest.raises(PersonaDoesNotExistError):
await semantic_router_mocked_db.check_persona_match(persona_name, query)
await semantic_router_mocked_db.check_persona_match(persona_name, [query])


class PersonaMatchTest(BaseModel):
@@ -333,11 +333,39 @@ async def test_check_persona_pass_match(
# Check for the queries that should pass
for query in persona_match_test.pass_queries:
match = await semantic_router_mocked_db.check_persona_match(
persona_match_test.persona_name, query
persona_match_test.persona_name, [query]
)
assert match is True


@pytest.mark.asyncio
@pytest.mark.parametrize(
"persona_match_test",
[
simple_persona,
architect,
coder,
devops_sre,
],
)
async def test_check_persona_pass_match_vector(
semantic_router_mocked_db: PersonaManager, persona_match_test: PersonaMatchTest
):
"""Test checking persona match."""
await semantic_router_mocked_db.add_persona(
persona_match_test.persona_name, persona_match_test.persona_desc
)

# We disable the weighting between distances since these are no user messages that
# need to be weighted differently, they all are weighted the same.
semantic_router_mocked_db._distances_weight_factor = 1.0
# Check for match passing the entire list
match = await semantic_router_mocked_db.check_persona_match(
persona_match_test.persona_name, persona_match_test.pass_queries
)
assert match is True


@pytest.mark.asyncio
@pytest.mark.parametrize(
"persona_match_test",
@@ -359,11 +387,39 @@ async def test_check_persona_fail_match(
# Check for the queries that should fail
for query in persona_match_test.fail_queries:
match = await semantic_router_mocked_db.check_persona_match(
persona_match_test.persona_name, query
persona_match_test.persona_name, [query]
)
assert match is False


@pytest.mark.asyncio
@pytest.mark.parametrize(
"persona_match_test",
[
simple_persona,
architect,
coder,
devops_sre,
],
)
async def test_check_persona_fail_match_vector(
semantic_router_mocked_db: PersonaManager, persona_match_test: PersonaMatchTest
):
"""Test checking persona match."""
await semantic_router_mocked_db.add_persona(
persona_match_test.persona_name, persona_match_test.persona_desc
)

# We disable the weighting between distances since these are no user messages that
# need to be weighted differently, they all are weighted the same.
semantic_router_mocked_db._distances_weight_factor = 1.0
# Check for match passing the entire list
match = await semantic_router_mocked_db.check_persona_match(
persona_match_test.persona_name, persona_match_test.fail_queries
)
assert match is False


@pytest.mark.asyncio
@pytest.mark.parametrize(
"personas",
120 changes: 113 additions & 7 deletions tests/muxing/test_rulematcher.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from unittest.mock import MagicMock
from typing import Dict, List
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

@@ -24,6 +25,7 @@
)


@pytest.mark.asyncio
@pytest.mark.parametrize(
"matcher_blob, thing_to_match",
[
@@ -40,12 +42,13 @@
),
],
)
def test_catch_all(matcher_blob, thing_to_match):
async def test_catch_all(matcher_blob, thing_to_match):
muxing_rule_matcher = rulematcher.CatchAllMuxingRuleMatcher(mocked_route_openai, matcher_blob)
# It should always match
assert muxing_rule_matcher.match(thing_to_match) is True
assert await muxing_rule_matcher.match(thing_to_match) is True


@pytest.mark.asyncio
@pytest.mark.parametrize(
"matcher, filenames_to_match, expected_bool",
[
@@ -60,7 +63,7 @@ def test_catch_all(matcher_blob, thing_to_match):
("*.ts", ["main.tsx", "test.tsx"], False), # Extension no match
],
)
def test_file_matcher(
async def test_file_matcher(
matcher,
filenames_to_match,
expected_bool,
@@ -81,9 +84,10 @@ def test_file_matcher(
is_fim_request=False,
client_type="generic",
)
assert muxing_rule_matcher.match(mocked_thing_to_match) is expected_bool
assert await muxing_rule_matcher.match(mocked_thing_to_match) is expected_bool


@pytest.mark.asyncio
@pytest.mark.parametrize(
"matcher, filenames_to_match, expected_bool_filenames",
[
@@ -107,7 +111,7 @@ def test_file_matcher(
(True, "chat_filename", False), # No match
],
)
def test_request_file_matcher(
async def test_request_file_matcher(
matcher,
filenames_to_match,
expected_bool_filenames,
@@ -143,18 +147,120 @@ def test_request_file_matcher(
)
is expected_bool_filenames
)
assert muxing_rule_matcher.match(mocked_thing_to_match) is (
assert await muxing_rule_matcher.match(mocked_thing_to_match) is (
expected_bool_request and expected_bool_filenames
)


# We mock PersonaManager because it's tested in /tests/persona/test_manager.py
MOCK_PERSONA_MANAGER = AsyncMock()
MOCK_PERSONA_MANAGER.check_persona_match.return_value = True


@pytest.mark.asyncio
@pytest.mark.parametrize(
"body, expected_queries",
[
({"messages": [{"role": "system", "content": "Youre helpful"}]}, []),
({"messages": [{"role": "user", "content": "hello"}]}, ["hello"]),
(
{"messages": [{"role": "user", "content": [{"type": "text", "text": "hello_dict"}]}]},
["hello_dict"],
),
],
)
async def test_user_msgs_persona_desc_matcher(body: Dict, expected_queries: List[str]):
mux_rule = mux_models.MuxRule(
provider_id="1",
model="fake-gpt",
matcher_type="persona_description",
matcher="foo_persona",
)
muxing_rule_matcher = rulematcher.UserMsgsPersonaDescMuxMatcher(mocked_route_openai, mux_rule)

mocked_thing_to_match = mux_models.ThingToMatchMux(
body=body,
url_request_path="/chat/completions",
is_fim_request=False,
client_type="generic",
)

resulting_queries = muxing_rule_matcher._get_queries_for_persona_match(body)
assert set(resulting_queries) == set(expected_queries)

with patch("codegate.muxing.rulematcher.PersonaManager", return_value=MOCK_PERSONA_MANAGER):
result = await muxing_rule_matcher.match(mocked_thing_to_match)

if expected_queries:
assert result is True
else:
assert result is False


@pytest.mark.asyncio
@pytest.mark.parametrize(
"body, expected_queries",
[
({"messages": [{"role": "system", "content": "Youre helpful"}]}, ["Youre helpful"]),
({"messages": [{"role": "user", "content": "hello"}]}, []),
(
{
"messages": [
{"role": "system", "content": "Youre helpful"},
{"role": "user", "content": "hello"},
]
},
["Youre helpful"],
),
(
{"messages": [{"role": "user", "content": "hello"}], "system": "Anthropic system"},
["Anthropic system"],
),
],
)
async def test_sys_prompt_persona_desc_matcher(body: Dict, expected_queries: List[str]):
mux_rule = mux_models.MuxRule(
provider_id="1",
model="fake-gpt",
matcher_type="sys_prompt_persona_desc",
matcher="foo_persona",
)
muxing_rule_matcher = rulematcher.SysPromptPersonaDescMuxMatcher(mocked_route_openai, mux_rule)

mocked_thing_to_match = mux_models.ThingToMatchMux(
body=body,
url_request_path="/chat/completions",
is_fim_request=False,
client_type="generic",
)

resulting_queries = muxing_rule_matcher._get_queries_for_persona_match(body)
assert set(resulting_queries) == set(expected_queries)

with patch("codegate.muxing.rulematcher.PersonaManager", return_value=MOCK_PERSONA_MANAGER):
result = await muxing_rule_matcher.match(mocked_thing_to_match)

if expected_queries:
assert result is True
else:
assert result is False


@pytest.mark.parametrize(
"matcher_type, expected_class",
[
(mux_models.MuxMatcherType.catch_all, rulematcher.CatchAllMuxingRuleMatcher),
(mux_models.MuxMatcherType.filename_match, rulematcher.FileMuxingRuleMatcher),
(mux_models.MuxMatcherType.fim_filename, rulematcher.RequestTypeAndFileMuxingRuleMatcher),
(mux_models.MuxMatcherType.chat_filename, rulematcher.RequestTypeAndFileMuxingRuleMatcher),
(
mux_models.MuxMatcherType.persona_description,
rulematcher.UserMsgsPersonaDescMuxMatcher,
),
(
mux_models.MuxMatcherType.sys_prompt_persona_desc,
rulematcher.SysPromptPersonaDescMuxMatcher,
),
("invalid_matcher", None),
],
)