Skip to content

Commit dfed34f

Browse files
committed
Rebase and update integration tests
Signed-off-by: declark1 <[email protected]>
1 parent 0188a73 commit dfed34f

File tree

1 file changed

+180
-32
lines changed

1 file changed

+180
-32
lines changed

tests/chat_completions_detection.rs

Lines changed: 180 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@
1515
1616
*/
1717

18-
use std::{collections::HashMap, vec};
19-
20-
use anyhow::Ok;
2118
use common::{
2219
chat_completions::CHAT_COMPLETIONS_ENDPOINT,
2320
chunker::CHUNKER_UNARY_ENDPOINT,
@@ -38,8 +35,7 @@ use fms_guardrails_orchestr8::{
3835
detector::{ContentAnalysisRequest, ContentAnalysisResponse},
3936
openai::{
4037
ChatCompletion, ChatCompletionChoice, ChatCompletionMessage, ChatDetections, Content,
41-
DetectorConfig, InputDetectionResult, Message, OrchestratorWarning,
42-
OutputDetectionResult, Role,
38+
InputDetectionResult, Message, OrchestratorWarning, OutputDetectionResult, Role,
4339
},
4440
},
4541
models::{
@@ -164,9 +160,13 @@ async fn no_detections() -> Result<(), anyhow::Error> {
164160
.post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT)
165161
.json(&json!({
166162
"model": MODEL_ID,
167-
"detectors": DetectorConfig {
168-
input: HashMap::from([(detector_name.into(), DetectorParams::new())]),
169-
output: HashMap::from([(detector_name.into(), DetectorParams::new())]),
163+
"detectors": {
164+
"input": {
165+
detector_name: {},
166+
},
167+
"output": {
168+
detector_name: {},
169+
},
170170
},
171171
"messages": messages,
172172
}))
@@ -288,9 +288,11 @@ async fn input_detections() -> Result<(), anyhow::Error> {
288288
.post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT)
289289
.json(&json!({
290290
"model": MODEL_ID,
291-
"detectors": DetectorConfig {
292-
input: HashMap::from([(detector_name.into(), DetectorParams::new())]),
293-
output: HashMap::new(),
291+
"detectors": {
292+
"input": {
293+
detector_name: {},
294+
},
295+
"output": {}
294296
},
295297
"messages": messages,
296298
}))
@@ -445,9 +447,11 @@ async fn input_client_error() -> Result<(), anyhow::Error> {
445447
.post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT)
446448
.json(&json!({
447449
"model": MODEL_ID,
448-
"detectors": DetectorConfig {
449-
input: HashMap::from([(detector_name.into(), DetectorParams::new())]),
450-
output: HashMap::new(),
450+
"detectors": {
451+
"input": {
452+
detector_name: {},
453+
},
454+
"output": {}
451455
},
452456
"messages": messages_chunker_error.clone(),
453457
}))
@@ -463,9 +467,11 @@ async fn input_client_error() -> Result<(), anyhow::Error> {
463467
.post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT)
464468
.json(&json!({
465469
"model": MODEL_ID,
466-
"detectors": DetectorConfig {
467-
input: HashMap::from([(detector_name.into(), DetectorParams::new())]),
468-
output: HashMap::new(),
470+
"detectors": {
471+
"input": {
472+
detector_name: {},
473+
},
474+
"output": {}
469475
},
470476
"messages": messages_detector_error.clone(),
471477
}))
@@ -481,9 +487,11 @@ async fn input_client_error() -> Result<(), anyhow::Error> {
481487
.post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT)
482488
.json(&json!({
483489
"model": MODEL_ID,
484-
"detectors": DetectorConfig {
485-
input: HashMap::from([(detector_name.into(), DetectorParams::new())]),
486-
output: HashMap::new(),
490+
"detectors": {
491+
"input": {
492+
detector_name: {},
493+
},
494+
"output": {}
487495
},
488496
"messages": messages_chat_completions_error.clone(),
489497
}))
@@ -664,9 +672,11 @@ async fn output_detections() -> Result<(), anyhow::Error> {
664672
.post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT)
665673
.json(&json!({
666674
"model": MODEL_ID,
667-
"detectors": DetectorConfig {
668-
input: HashMap::new(),
669-
output: HashMap::from([(detector_name.into(), DetectorParams::new())]),
675+
"detectors": {
676+
"input": {},
677+
"output": {
678+
detector_name: {},
679+
},
670680
},
671681
"messages": messages,
672682
}))
@@ -840,9 +850,11 @@ async fn output_client_error() -> Result<(), anyhow::Error> {
840850
.post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT)
841851
.json(&json!({
842852
"model": MODEL_ID,
843-
"detectors": DetectorConfig {
844-
input: HashMap::new(),
845-
output: HashMap::from([(detector_name.into(), DetectorParams::new())]),
853+
"detectors": {
854+
"input": {},
855+
"output": {
856+
detector_name: {},
857+
},
846858
},
847859
"messages": messages_chunker_error.clone(),
848860
}))
@@ -858,9 +870,11 @@ async fn output_client_error() -> Result<(), anyhow::Error> {
858870
.post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT)
859871
.json(&json!({
860872
"model": MODEL_ID,
861-
"detectors": DetectorConfig {
862-
input: HashMap::new(),
863-
output: HashMap::from([(detector_name.into(), DetectorParams::new())]),
873+
"detectors": {
874+
"input": {},
875+
"output": {
876+
detector_name: {},
877+
},
864878
},
865879
"messages": messages_detector_error.clone(),
866880
}))
@@ -876,9 +890,11 @@ async fn output_client_error() -> Result<(), anyhow::Error> {
876890
.post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT)
877891
.json(&json!({
878892
"model": MODEL_ID,
879-
"detectors": DetectorConfig {
880-
input: HashMap::new(),
881-
output: HashMap::from([(detector_name.into(), DetectorParams::new())]),
893+
"detectors": {
894+
"input": {},
895+
"output": {
896+
detector_name: {},
897+
},
882898
},
883899
"messages": messages_chat_completions_error.clone(),
884900
}))
@@ -891,3 +907,135 @@ async fn output_client_error() -> Result<(), anyhow::Error> {
891907

892908
Ok(())
893909
}
910+
911+
// Validate that invalid orchestrator requests returns 422 error
912+
#[test(tokio::test)]
913+
async fn orchestrator_validation_error() -> Result<(), anyhow::Error> {
914+
// Start orchestrator server and its dependencies
915+
let orchestrator_server = TestOrchestratorServer::builder()
916+
.config_path(ORCHESTRATOR_CONFIG_FILE_PATH)
917+
.build()
918+
.await?;
919+
920+
let messages = vec![Message {
921+
content: Some(Content::Text("Hi there!".to_string())),
922+
role: Role::User,
923+
..Default::default()
924+
}];
925+
926+
// Invalid input detector scenario
927+
let response = orchestrator_server
928+
.post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT)
929+
.json(&json!({
930+
"model": MODEL_ID,
931+
"detectors": {
932+
"input": {
933+
ANSWER_RELEVANCE_DETECTOR: {},
934+
},
935+
"output": {}
936+
},
937+
"messages": messages.clone(),
938+
}))
939+
.send()
940+
.await?;
941+
942+
let results = response.json::<OrchestratorError>().await?;
943+
debug!("{results:#?}");
944+
assert_eq!(
945+
results,
946+
OrchestratorError {
947+
code: 422,
948+
details: format!(
949+
"detector `{}` is not supported by this endpoint",
950+
ANSWER_RELEVANCE_DETECTOR
951+
)
952+
},
953+
"failed on invalid input detector scenario"
954+
);
955+
956+
// Non-existing input detector scenario
957+
let response = orchestrator_server
958+
.post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT)
959+
.json(&json!({
960+
"model": MODEL_ID,
961+
"detectors": {
962+
"input": {
963+
NON_EXISTING_DETECTOR: {},
964+
},
965+
"output": {}
966+
},
967+
"messages": messages.clone(),
968+
}))
969+
.send()
970+
.await?;
971+
972+
let results = response.json::<OrchestratorError>().await?;
973+
debug!("{results:#?}");
974+
assert_eq!(
975+
results,
976+
OrchestratorError {
977+
code: 404,
978+
details: format!("detector `{}` not found", NON_EXISTING_DETECTOR)
979+
},
980+
"failed on non-existing input detector scenario"
981+
);
982+
983+
// Invalid output detector scenario
984+
let response = orchestrator_server
985+
.post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT)
986+
.json(&json!({
987+
"model": MODEL_ID,
988+
"detectors": {
989+
"input": {},
990+
"output": {
991+
ANSWER_RELEVANCE_DETECTOR: {},
992+
},
993+
},
994+
"messages": messages.clone(),
995+
}))
996+
.send()
997+
.await?;
998+
999+
let results = response.json::<OrchestratorError>().await?;
1000+
debug!("{results:#?}");
1001+
assert_eq!(
1002+
results,
1003+
OrchestratorError {
1004+
code: 422,
1005+
details: format!(
1006+
"detector `{}` is not supported by this endpoint",
1007+
ANSWER_RELEVANCE_DETECTOR
1008+
)
1009+
},
1010+
"failed on invalid output detector scenario"
1011+
);
1012+
1013+
// Non-existing output detector scenario
1014+
let response = orchestrator_server
1015+
.post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT)
1016+
.json(&json!({
1017+
"model": MODEL_ID,
1018+
"detectors": {
1019+
"input": {},
1020+
"output": {
1021+
NON_EXISTING_DETECTOR: {},
1022+
}
1023+
},
1024+
"messages": messages.clone(),
1025+
}))
1026+
.send()
1027+
.await?;
1028+
1029+
let results = response.json::<OrchestratorError>().await?;
1030+
debug!("{results:#?}");
1031+
assert_eq!(
1032+
results,
1033+
OrchestratorError {
1034+
code: 404,
1035+
details: format!("detector `{}` not found", NON_EXISTING_DETECTOR)
1036+
},
1037+
"failed on non-existing input detector scenario"
1038+
);
1039+
1040+
Ok(())
1041+
}

0 commit comments

Comments
 (0)