Skip to content

Commit 426105e

Browse files
authored
Add integration checks for comparing the result of calling the model API directly vs via CodeGate (#1032)
* Enable codegate enrichment tests
  Signed-off-by: Radoslav Dimitrov <[email protected]>
* Re-use call_provider for calling both codegate and the provider
  Signed-off-by: Radoslav Dimitrov <[email protected]>
---------
Signed-off-by: Radoslav Dimitrov <[email protected]>
1 parent 36a1743 commit 426105e

File tree

4 files changed

+84
-18
lines changed

4 files changed

+84
-18
lines changed

tests/integration/checks.py

+34-13
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ def load(test_data: dict) -> List[BaseCheck]:
2929
checks.append(ContainsCheck(test_name))
3030
if test_data.get(DoesNotContainCheck.KEY):
3131
checks.append(DoesNotContainCheck(test_name))
32-
32+
if test_data.get(CodeGateEnrichment.KEY) is not None:
33+
checks.append(CodeGateEnrichment(test_name))
3334
return checks
3435

3536

@@ -51,11 +52,10 @@ async def run_check(self, parsed_response: str, test_data: dict) -> bool:
5152
similarity = await self._calculate_string_similarity(
5253
parsed_response, test_data[DistanceCheck.KEY]
5354
)
55+
logger.debug(f"Similarity: {similarity}")
56+
logger.debug(f"Response: {parsed_response}")
57+
logger.debug(f"Expected Response: {test_data[DistanceCheck.KEY]}")
5458
if similarity < 0.8:
55-
logger.error(f"Test {self.test_name} failed")
56-
logger.error(f"Similarity: {similarity}")
57-
logger.error(f"Response: {parsed_response}")
58-
logger.error(f"Expected Response: {test_data[DistanceCheck.KEY]}")
5959
return False
6060
return True
6161

@@ -64,10 +64,9 @@ class ContainsCheck(BaseCheck):
6464
KEY = "contains"
6565

6666
async def run_check(self, parsed_response: str, test_data: dict) -> bool:
67+
logger.debug(f"Response: {parsed_response}")
68+
logger.debug(f"Expected Response to contain: {test_data[ContainsCheck.KEY]}")
6769
if test_data[ContainsCheck.KEY].strip() not in parsed_response:
68-
logger.error(f"Test {self.test_name} failed")
69-
logger.error(f"Response: {parsed_response}")
70-
logger.error(f"Expected Response to contain: '{test_data[ContainsCheck.KEY]}'")
7170
return False
7271
return True
7372

@@ -76,11 +75,33 @@ class DoesNotContainCheck(BaseCheck):
7675
KEY = "does_not_contain"
7776

7877
async def run_check(self, parsed_response: str, test_data: dict) -> bool:
78+
logger.debug(f"Response: {parsed_response}")
79+
logger.debug(f"Expected Response to not contain: '{test_data[DoesNotContainCheck.KEY]}'")
7980
if test_data[DoesNotContainCheck.KEY].strip() in parsed_response:
80-
logger.error(f"Test {self.test_name} failed")
81-
logger.error(f"Response: {parsed_response}")
82-
logger.error(
83-
f"Expected Response to not contain: '{test_data[DoesNotContainCheck.KEY]}'"
84-
)
8581
return False
8682
return True
83+
84+
85+
class CodeGateEnrichment(BaseCheck):
    """Check whether CodeGate changed the model's answer as the test case expects.

    The response obtained through CodeGate is compared with the response
    obtained by calling the model provider directly. The comparison itself is
    delegated to ``DistanceCheck``; this class only interprets its verdict
    according to the ``expect_difference`` flag of the test case.
    """

    KEY = "codegate_enrichment"

    async def run_check(self, parsed_response: str, test_data: dict) -> bool:
        """Compare the CodeGate response with the direct provider response.

        Args:
            parsed_response: Parsed message returned via CodeGate.
            test_data: Test case definition. Must carry the parsed direct
                provider response under ``"direct_response"``; may carry a
                mapping under ``"codegate_enrichment"`` whose
                ``expect_difference`` entry (default ``False``) states whether
                the two responses are expected to differ.

        Returns:
            ``True`` when the observed similarity matches the expectation:
            responses differ if ``expect_difference`` is true, responses are
            similar otherwise. ``False`` when the expectation is not met or
            when no direct response is available to compare against.
        """
        direct_response = test_data.get("direct_response")
        if direct_response is None:
            # Without the baseline there is nothing to compare; report a
            # failed check instead of crashing the whole run with a KeyError.
            logger.error(
                f"Test {self.test_name} failed: no direct response available for comparison"
            )
            return False

        logger.debug(f"Response (CodeGate): {parsed_response}")
        logger.debug(f"Response (Raw model): {direct_response}")

        # Use the DistanceCheck to compare the two responses
        distance_check = DistanceCheck(self.test_name)
        are_similar = await distance_check.run_check(
            parsed_response, {DistanceCheck.KEY: direct_response}
        )

        # A missing or null config block is treated as "no difference expected".
        enrichment_config = test_data.get(CodeGateEnrichment.KEY) or {}
        expect_enrichment = enrichment_config.get("expect_difference", False)

        # If the response is enriched by CodeGate, the two responses should
        # no longer be similar.
        if expect_enrichment:
            logger.info("CodeGate enrichment check: Expecting difference")
            return not are_similar

        # If the response is not enriched, the two responses should be similar.
        logger.info("CodeGate enrichment check: Not expecting difference")
        return are_similar

tests/integration/integration_tests.py

+28-5
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import requests
1010
import structlog
1111
import yaml
12-
from checks import CheckLoader
12+
from checks import CheckLoader, CodeGateEnrichment
1313
from dotenv import find_dotenv, load_dotenv
1414
from requesters import RequesterFactory
1515

@@ -21,7 +21,7 @@ def __init__(self):
2121
self.requester_factory = RequesterFactory()
2222
self.failed_tests = [] # Track failed tests
2323

24-
def call_codegate(
24+
def call_provider(
2525
self, url: str, headers: dict, data: dict, provider: str, method: str = "POST"
2626
) -> Optional[requests.Response]:
2727
logger.debug(f"Creating requester for provider: {provider}")
@@ -132,18 +132,29 @@ def replacement(match):
132132

133133
async def run_test(self, test: dict, test_headers: dict) -> bool:
134134
test_name = test["name"]
135-
url = test["url"]
136135
data = json.loads(test["data"])
136+
codegate_url = test["url"]
137137
streaming = data.get("stream", False)
138138
provider = test["provider"]
139-
140139
logger.info(f"Starting test: {test_name}")
141140

142-
response = self.call_codegate(url, test_headers, data, provider)
141+
# Call Codegate
142+
response = self.call_provider(codegate_url, test_headers, data, provider)
143143
if not response:
144144
logger.error(f"Test {test_name} failed: No response received")
145145
return False
146146

147+
# Call model directly if specified
148+
direct_response = None
149+
if test.get(CodeGateEnrichment.KEY) is not None:
150+
direct_provider_url = test.get(CodeGateEnrichment.KEY)["provider_url"]
151+
direct_response = self.call_provider(
152+
direct_provider_url, test_headers, data, "not-codegate"
153+
)
154+
if not direct_response:
155+
logger.error(f"Test {test_name} failed: No direct response received")
156+
return False
157+
147158
# Debug response info
148159
logger.debug(f"Response status: {response.status_code}")
149160
logger.debug(f"Response headers: {dict(response.headers)}")
@@ -152,13 +163,24 @@ async def run_test(self, test: dict, test_headers: dict) -> bool:
152163
parsed_response = self.parse_response_message(response, streaming=streaming)
153164
logger.debug(f"Response message: {parsed_response}")
154165

166+
if direct_response:
167+
# Dirty hack to pass direct response to checks
168+
test["direct_response"] = self.parse_response_message(
169+
direct_response, streaming=streaming
170+
)
171+
logger.debug(f"Direct response message: {test['direct_response']}")
172+
155173
# Load appropriate checks for this test
156174
checks = CheckLoader.load(test)
157175

158176
# Run all checks
159177
all_passed = True
160178
for check in checks:
179+
logger.info(f"Running check: {check.__class__.__name__}")
161180
passed_check = await check.run_check(parsed_response, test)
181+
logger.info(
182+
f"Check {check.__class__.__name__} {'passed' if passed_check else 'failed'}"
183+
)
162184
if not passed_check:
163185
all_passed = False
164186

@@ -379,6 +401,7 @@ async def main():
379401
# Exit with status code 1 if any tests failed
380402
if not all_tests_passed:
381403
sys.exit(1)
404+
logger.info("All tests passed")
382405

383406

384407
if __name__ == "__main__":

tests/integration/ollama/testcases.yaml

+12
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ testcases:
3131
name: Ollama Chat
3232
provider: ollama
3333
url: http://127.0.0.1:8989/ollama/chat/completions
34+
codegate_enrichment:
35+
provider_url: http://127.0.0.1:11434/api/chat
36+
expect_difference: false
3437
data: |
3538
{
3639
"max_tokens":4096,
@@ -55,6 +58,9 @@ testcases:
5558
name: Ollama FIM
5659
provider: ollama
5760
url: http://127.0.0.1:8989/ollama/api/generate
61+
codegate_enrichment:
62+
provider_url: http://127.0.0.1:11434/api/generate
63+
expect_difference: false
5864
data: |
5965
{
6066
"stream": true,
@@ -88,6 +94,9 @@ testcases:
8894
name: Ollama Malicious Package
8995
provider: ollama
9096
url: http://127.0.0.1:8989/ollama/chat/completions
97+
codegate_enrichment:
98+
provider_url: http://127.0.0.1:11434/api/chat
99+
expect_difference: true
91100
data: |
92101
{
93102
"max_tokens":4096,
@@ -112,6 +121,9 @@ testcases:
112121
name: Ollama secret redacting chat
113122
provider: ollama
114123
url: http://127.0.0.1:8989/ollama/chat/completions
124+
codegate_enrichment:
125+
provider_url: http://127.0.0.1:11434/api/chat
126+
expect_difference: true
115127
data: |
116128
{
117129
"max_tokens":4096,

tests/integration/vllm/testcases.yaml

+10
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ testcases:
3131
name: VLLM Chat
3232
provider: vllm
3333
url: http://127.0.0.1:8989/vllm/chat/completions
34+
codegate_enrichment:
35+
provider_url: http://127.0.0.1:8000/v1/chat/completions
36+
expect_difference: false
3437
data: |
3538
{
3639
"max_tokens":4096,
@@ -55,6 +58,10 @@ testcases:
5558
name: VLLM FIM
5659
provider: vllm
5760
url: http://127.0.0.1:8989/vllm/completions
61+
# This is commented out for now as there's some issue with parsing the streamed response from the model (on the vllm side, not codegate)
62+
# codegate_enrichment:
63+
# provider_url: http://127.0.0.1:8000/v1/completions
64+
# expect_difference: false
5865
data: |
5966
{
6067
"model": "Qwen/Qwen2.5-Coder-0.5B-Instruct",
@@ -84,6 +91,9 @@ testcases:
8491
name: VLLM Malicious Package
8592
provider: vllm
8693
url: http://127.0.0.1:8989/vllm/chat/completions
94+
codegate_enrichment:
95+
provider_url: http://127.0.0.1:8000/v1/chat/completions
96+
expect_difference: true
8797
data: |
8898
{
8999
"max_tokens":4096,

0 commit comments

Comments
 (0)