
refactor: task handlers #355

Merged: 25 commits, Apr 15, 2025

Commits (25)
c4135e4  Add classification_with_gen and streaming_classification_with_gen han… (declark1, Mar 25, 2025)
e84fe44  Add placeholders (wip) (declark1, Mar 31, 2025)
e939885  Implement handlers (wip) (declark1, Mar 31, 2025)
4fe26f3  . (declark1, Apr 1, 2025)
310cb6e  Fix default thresholds (declark1, Apr 1, 2025)
614fa77  Cleanups (declark1, Apr 1, 2025)
efe29b1  Return not yet supported error for chat completions streaming (declark1, Apr 2, 2025)
66c018e  Implement stream_content_detection handler (wip) (declark1, Apr 2, 2025)
72a0ad0  Fix default threshold for text_contents_detection_streams (declark1, Apr 2, 2025)
ec1bbf3  Update input/response/detection channel buffer sizes to 128 (declark1, Apr 2, 2025)
99e5db9  Rebase and update chat_detection test (declark1, Apr 3, 2025)
68f532a  Update chat_completions_detection tests to drop options (declark1, Apr 4, 2025)
587d906  Fix (declark1, Apr 7, 2025)
ea732bb  Return None when no detections instead of Some with default ChatDetec… (declark1, Apr 7, 2025)
fa82017  Revert comment changes (declark1, Apr 7, 2025)
df84ecd  Formatting (declark1, Apr 7, 2025)
91f8d28  Update src/orchestrator/handlers/classification_with_gen.rs (declark1, Apr 8, 2025)
cb2cbca  Update field docstrings for task structs (declark1, Apr 8, 2025)
8bda45e  Instrument handle methods and task functions, inject current span int… (declark1, Apr 8, 2025)
e784fae  Fix - update GuardrailsConfig::input_detectors and output_detectors m… (declark1, Apr 9, 2025)
c4ad327  Tracing tweaks (declark1, Apr 10, 2025)
60eab19  Default to compact log format (declark1, Apr 10, 2025)
eec8c3f  Drop instrument from output_detection_response (declark1, Apr 10, 2025)
f8a32e1  Add back detector request events. Change default logging format back … (declark1, Apr 14, 2025)
327e67f  Add config field to task started log event with detectors config (declark1, Apr 15, 2025)

src/args.rs: 2 changes (1 addition, 1 deletion)
@@ -140,9 +140,9 @@ impl OtlpProtocol {
 
 #[derive(Debug, Clone, Copy, Default, PartialEq)]
 pub enum LogFormat {
-    Compact,
     #[default]
     Full,
+    Compact,
     Pretty,
     JSON,
 }
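
Side note on the hunk above: with derive(Default) on an enum, the default variant is chosen by the #[default] attribute, not by declaration order, so moving Compact below Full leaves LogFormat::default() as Full. A minimal standalone sketch:

```rust
// Standalone rendering of the enum above: the #[default] attribute, not
// variant order, determines which variant Default::default() returns.
#[derive(Debug, Default, PartialEq)]
enum LogFormat {
    #[default]
    Full,
    Compact,
    Pretty,
    JSON,
}

fn main() {
    assert_eq!(LogFormat::default(), LogFormat::Full);
}
```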

src/clients.rs: 6 changes (1 addition, 5 deletions)
@@ -32,7 +32,7 @@ use hyper_timeout::TimeoutConnector;
 use hyper_util::rt::TokioExecutor;
 use tonic::{Request, metadata::MetadataMap};
 use tower::{ServiceBuilder, timeout::TimeoutLayer};
-use tracing::{Span, debug, instrument};
+use tracing::Span;
 use tracing_opentelemetry::OpenTelemetrySpanExt;
 use url::Url;
 
@@ -205,7 +205,6 @@ impl ClientMap {
     }
 }
 
-#[instrument(skip_all, fields(hostname = service_config.hostname))]
 pub async fn create_http_client(
     default_port: u16,
     service_config: &ServiceConfig,
@@ -220,7 +219,6 @@ pub async fn create_http_client(
     base_url
         .set_port(Some(port))
         .unwrap_or_else(|_| panic!("error setting port: {}", port));
-    debug!(%base_url, "creating HTTP client");
 
     let connect_timeout = Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SEC);
     let request_timeout = Duration::from_secs(
@@ -257,7 +255,6 @@ pub async fn create_http_client(
     Ok(HttpClient::new(base_url, client))
 }
 
-#[instrument(skip_all, fields(hostname = service_config.hostname))]
 pub async fn create_grpc_client<C: Debug + Clone>(
     default_port: u16,
     service_config: &ServiceConfig,
@@ -270,7 +267,6 @@ pub async fn create_grpc_client<C: Debug + Clone>(
     };
     let mut base_url = Url::parse(&format!("{}://{}", protocol, &service_config.hostname)).unwrap();
     base_url.set_port(Some(port)).unwrap();
-    debug!(%base_url, "creating gRPC client");
     let connect_timeout = Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SEC);
     let request_timeout = Duration::from_secs(
         service_config
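
With the #[instrument] attributes and debug! events removed from the client constructors, span context comes from the caller instead (see the later commit instrumenting handle methods and task functions). A hedged sketch of that caller-side pattern; build_client and the span name are hypothetical stand-ins, not the crate's actual API:

```rust
use tracing::{Instrument, info_span};

// Hypothetical caller-side wiring: the handler owns the span, including the
// hostname field the deleted attributes used to record, and the
// client-creation future runs inside it.
async fn connect(hostname: &str) {
    let span = info_span!("create_client", hostname);
    async move {
        // build_client(hostname).await  // stand-in for create_http_client / create_grpc_client
    }
    .instrument(span)
    .await;
}
```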

src/clients/chunker.rs: 6 changes (1 addition, 5 deletions)
@@ -22,7 +22,7 @@ use axum::http::HeaderMap;
 use futures::{Future, Stream, StreamExt, TryStreamExt};
 use ginepro::LoadBalancedChannel;
 use tonic::{Code, Request, Response, Status, Streaming};
-use tracing::{Span, debug, info, instrument};
+use tracing::{Span, instrument};
 
 use super::{
     BoxStream, Client, Error, create_grpc_client, errors::grpc_to_http_code,
@@ -68,28 +68,24 @@ impl ChunkerClient {
         }
     }
 
-    #[instrument(skip_all, fields(model_id))]
     pub async fn tokenization_task_predict(
         &self,
         model_id: &str,
        request: ChunkerTokenizationTaskRequest,
     ) -> Result<TokenizationResults, Error> {
         let mut client = self.client.clone();
         let request = request_with_headers(request, model_id);
-        debug!(?request, "sending client request");
         let response = client.chunker_tokenization_task_predict(request).await?;
         let span = Span::current();
         trace_context_from_grpc_response(&span, &response);
         Ok(response.into_inner())
     }
 
-    #[instrument(skip_all, fields(model_id))]
     pub async fn bidi_streaming_tokenization_task_predict(
         &self,
         model_id: &str,
         request_stream: BoxStream<BidiStreamingChunkerTokenizationTaskRequest>,
     ) -> Result<BoxStream<Result<ChunkerTokenizationStreamResult, Error>>, Error> {
-        info!("sending client stream request");
         let mut client = self.client.clone();
         let request = request_with_headers(request_stream, model_id);
         // NOTE: this is an ugly workaround to avoid bogus higher-ranked lifetime errors.
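
The chunker methods still link their current span to the server's trace context via trace_context_from_grpc_response. Its body isn't part of this diff; a hedged sketch of what such a helper commonly looks like with tracing-opentelemetry (MetadataExtractor and link_span_to_response are hypothetical names):

```rust
use opentelemetry::{global, propagation::Extractor};
use tracing::Span;
use tracing_opentelemetry::OpenTelemetrySpanExt;

// Hypothetical extractor that reads W3C trace headers from gRPC metadata.
struct MetadataExtractor<'a>(&'a tonic::metadata::MetadataMap);

impl Extractor for MetadataExtractor<'_> {
    fn get(&self, key: &str) -> Option<&str> {
        self.0.get(key).and_then(|v| v.to_str().ok())
    }
    fn keys(&self) -> Vec<&str> {
        // The trace-context propagator only calls `get` for known header
        // names, so an empty key list suffices for this sketch.
        Vec::new()
    }
}

// Sketch: extract the remote context from response metadata and set it as
// the parent of the given span.
fn link_span_to_response<T>(span: &Span, response: &tonic::Response<T>) {
    let parent = global::get_text_map_propagator(|propagator| {
        propagator.extract(&MetadataExtractor(response.metadata()))
    });
    span.set_parent(parent);
}
```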

src/clients/detector.rs: 2 changes (0 additions, 2 deletions)
@@ -21,7 +21,6 @@ use axum::http::HeaderMap;
 use http::header::CONTENT_TYPE;
 use hyper::StatusCode;
 use serde::Deserialize;
-use tracing::instrument;
 use url::Url;
 
 use super::{
@@ -79,7 +78,6 @@ pub trait DetectorClientExt: HttpClientExt {
 }
 
 impl<C: DetectorClient + HttpClientExt> DetectorClientExt for C {
-    #[instrument(skip_all, fields(model_id, url))]
     async fn post_to_detector<U: ResponseBody>(
         &self,
         model_id: &str,

src/clients/detector/text_chat.rs: 3 changes (1 addition, 2 deletions)
@@ -18,7 +18,7 @@
 use async_trait::async_trait;
 use hyper::HeaderMap;
 use serde::Serialize;
-use tracing::{info, instrument};
+use tracing::info;
 
 use super::{DEFAULT_PORT, DetectorClient, DetectorClientExt};
 use crate::{
@@ -63,7 +63,6 @@ impl TextChatDetectorClient {
         &self.client
     }
 
-    #[instrument(skip_all, fields(model_id, ?headers))]
     pub async fn text_chat(
         &self,
         model_id: &str,

src/clients/detector/text_contents.rs: 3 changes (1 addition, 2 deletions)
@@ -20,7 +20,7 @@ use std::collections::BTreeMap;
 use async_trait::async_trait;
 use hyper::HeaderMap;
 use serde::{Deserialize, Serialize};
-use tracing::{info, instrument};
+use tracing::info;
 
 use super::{DEFAULT_PORT, DetectorClient, DetectorClientExt};
 use crate::{
@@ -61,7 +61,6 @@ impl TextContentsDetectorClient {
         &self.client
     }
 
-    #[instrument(skip_all, fields(model_id))]
     pub async fn text_contents(
         &self,
         model_id: &str,

src/clients/detector/text_context_doc.rs: 3 changes (1 addition, 2 deletions)
@@ -18,7 +18,7 @@
 use async_trait::async_trait;
 use hyper::HeaderMap;
 use serde::{Deserialize, Serialize};
-use tracing::{info, instrument};
+use tracing::info;
 
 use super::{DEFAULT_PORT, DetectorClient, DetectorClientExt};
 use crate::{
@@ -59,7 +59,6 @@ impl TextContextDocDetectorClient {
         &self.client
     }
 
-    #[instrument(skip_all, fields(model_id))]
     pub async fn text_context_doc(
         &self,
         model_id: &str,

src/clients/detector/text_generation.rs: 3 changes (1 addition, 2 deletions)
@@ -18,7 +18,7 @@
 use async_trait::async_trait;
 use hyper::HeaderMap;
 use serde::Serialize;
-use tracing::{info, instrument};
+use tracing::info;
 
 use super::{DEFAULT_PORT, DetectorClient, DetectorClientExt};
 use crate::{
@@ -59,7 +59,6 @@ impl TextGenerationDetectorClient {
         &self.client
     }
 
-    #[instrument(skip_all, fields(model_id))]
     pub async fn text_generation(
         &self,
         model_id: &str,
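
The four detector clients above all drop the same attribute. For reference, a minimal sketch of what #[instrument(skip_all, fields(model_id))] does on its own, per tracing's documented semantics (detect is a made-up function, not one of the actual client methods):

```rust
use tracing::instrument;

// skip_all: don't record any arguments automatically; fields(model_id):
// record the model_id argument as a span field anyway. Every call runs
// inside a span named after the function.
#[instrument(skip_all, fields(model_id))]
async fn detect(model_id: &str, payload: &str) -> usize {
    // Events emitted here inherit the span and its model_id field.
    tracing::info!("sending request to detector");
    payload.len()
}
```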

src/clients/generation.rs: 22 changes (0 additions, 22 deletions)
@@ -18,7 +18,6 @@
 use async_trait::async_trait;
 use futures::{StreamExt, TryStreamExt};
 use hyper::HeaderMap;
-use tracing::{debug, instrument};
 
 use super::{BoxStream, Client, Error, NlpClient, TgisClient};
 use crate::{
@@ -63,7 +62,6 @@ impl GenerationClient {
         Self(None)
     }
 
-    #[instrument(skip_all, fields(model_id))]
     pub async fn tokenize(
         &self,
         model_id: String,
@@ -79,19 +77,15 @@
                     return_offsets: false,
                     truncate_input_tokens: 0,
                 };
-                debug!(provider = "tgis", ?request, "sending tokenize request");
                 let mut response = client.tokenize(request, headers).await?;
-                debug!(provider = "tgis", ?response, "received tokenize response");
                 let response = response.responses.swap_remove(0);
                 Ok((response.token_count, response.tokens))
             }
             Some(GenerationClientInner::Nlp(client)) => {
                 let request = TokenizationTaskRequest { text };
-                debug!(provider = "nlp", ?request, "sending tokenize request");
                 let response = client
                     .tokenization_task_predict(&model_id, request, headers)
                     .await?;
-                debug!(provider = "nlp", ?response, "received tokenize response");
                 let tokens = response
                     .results
                     .into_iter()
@@ -103,7 +97,6 @@
         }
     }
 
-    #[instrument(skip_all, fields(model_id))]
     pub async fn generate(
         &self,
         model_id: String,
@@ -120,9 +113,7 @@
                     requests: vec![GenerationRequest { text }],
                     params,
                 };
-                debug!(provider = "tgis", ?request, "sending generate request");
                 let response = client.generate(request, headers).await?;
-                debug!(provider = "tgis", ?response, "received generate response");
                 Ok(response.into())
             }
             Some(GenerationClientInner::Nlp(client)) => {
@@ -157,18 +148,15 @@
                         ..Default::default()
                     }
                 };
-                debug!(provider = "nlp", ?request, "sending generate request");
                 let response = client
                     .text_generation_task_predict(&model_id, request, headers)
                     .await?;
-                debug!(provider = "nlp", ?response, "received generate response");
                 Ok(response.into())
             }
             None => Err(Error::ModelNotFound { model_id }),
         }
     }
 
-    #[instrument(skip_all, fields(model_id))]
     pub async fn generate_stream(
         &self,
         model_id: String,
@@ -185,11 +173,6 @@
                     request: Some(GenerationRequest { text }),
                     params,
                 };
-                debug!(
-                    provider = "tgis",
-                    ?request,
-                    "sending generate_stream request"
-                );
                 let response_stream = client
                     .generate_stream(request, headers)
                     .await?
@@ -229,11 +212,6 @@
                     ..Default::default()
                 }
             };
-            debug!(
-                provider = "nlp",
-                ?request,
-                "sending generate_stream request"
-            );
             let response_stream = client
                 .server_streaming_text_generation_task_predict(&model_id, request, headers)
                 .await?
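
The tokenize/generate/generate_stream methods keep their provider dispatch; only the ad-hoc logging goes away. Heavily simplified sketch of that dispatch shape (Tgis/Nlp stand in for the real gRPC clients and the bodies are placeholders, not the actual calls):

```rust
// Distilled shape of GenerationClient: an optional inner provider, with
// ModelNotFound returned when no generation backend is configured.
enum Inner {
    Tgis,
    Nlp,
}

#[derive(Debug)]
enum Error {
    ModelNotFound { model_id: String },
}

struct GenerationClient(Option<Inner>);

impl GenerationClient {
    fn generate(&self, model_id: String, text: String) -> Result<String, Error> {
        match &self.0 {
            Some(Inner::Tgis) => Ok(format!("tgis generated: {text}")), // placeholder
            Some(Inner::Nlp) => Ok(format!("nlp generated: {text}")),   // placeholder
            None => Err(Error::ModelNotFound { model_id }),
        }
    }
}
```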

src/clients/http.rs: 20 changes (2 additions, 18 deletions)
@@ -20,7 +20,7 @@ use std::{fmt::Debug, ops::Deref, time::Duration};
 use http_body_util::{BodyExt, Full, combinators::BoxBody};
 use hyper::{
     HeaderMap, Method, Request, StatusCode,
-    body::{Body, Bytes, Incoming},
+    body::{Bytes, Incoming},
 };
 use hyper_rustls::HttpsConnector;
 use hyper_timeout::TimeoutConnector;
@@ -36,7 +36,7 @@ use tower_http::{
         Trace, TraceLayer,
     },
 };
-use tracing::{Span, debug, error, info, info_span, instrument};
+use tracing::{Span, error, info, info_span};
 use tracing_opentelemetry::OpenTelemetrySpanExt;
 use url::Url;
 
@@ -137,7 +137,6 @@ impl HttpClient {
         self.base_url.join(path).unwrap()
     }
 
-    #[instrument(skip_all, fields(url))]
     pub async fn get(
         &self,
         url: Url,
@@ -147,7 +146,6 @@
         self.send(url, Method::GET, headers, body).await
     }
 
-    #[instrument(skip_all, fields(url))]
     pub async fn post(
         &self,
         url: Url,
@@ -157,7 +155,6 @@
         self.send(url, Method::POST, headers, body).await
     }
 
-    #[instrument(skip_all, fields(url))]
     pub async fn send(
         &self,
         url: Url,
@@ -172,12 +169,6 @@
             .uri(url.as_uri());
         match builder.headers_mut() {
             Some(headers_mut) => {
-                debug!(
-                    ?url,
-                    ?headers,
-                    ?body,
-                    "sending client request"
-                );
                 headers_mut.extend(headers);
                 let body =
                     Full::new(Bytes::from(serde_json::to_vec(&body).map_err(|e| {
@@ -211,13 +202,6 @@ impl HttpClient {
                 message: format!("client request timeout: {}", e),
             }),
         }?;
-
-        debug!(
-            status = ?response.status(),
-            headers = ?response.headers(),
-            size = ?response.size_hint(),
-            "incoming client response"
-        );
         let span = Span::current();
         trace::trace_context_from_http_response(&span, &response);
         Ok(response.into())

src/clients/openai.rs: 11 changes (4 additions, 7 deletions)
@@ -24,7 +24,6 @@ use http_body_util::BodyExt;
 use hyper::{HeaderMap, StatusCode};
 use serde::{Deserialize, Serialize};
 use tokio::sync::mpsc;
-use tracing::{info, instrument};
 
 use super::{
     Client, Error, HttpClient, create_http_client, detector::ContentAnalysisResponse,
@@ -70,14 +69,12 @@ impl OpenAiClient {
         &self.client
     }
 
-    #[instrument(skip_all, fields(request.model))]
     pub async fn chat_completions(
         &self,
         request: ChatCompletionsRequest,
         headers: HeaderMap,
     ) -> Result<ChatCompletionsResponse, Error> {
         let url = self.inner().endpoint(CHAT_COMPLETIONS_ENDPOINT);
-        info!("sending Open AI chat completion request to {}", url);
         if request.stream {
             let (tx, rx) = mpsc::channel(32);
             let mut event_stream = self
@@ -296,11 +293,11 @@ pub struct ChatCompletionsRequest {
 #[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
 #[serde(deny_unknown_fields)]
 pub struct DetectorConfig {
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub input: Option<HashMap<String, DetectorParams>>,
+    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
+    pub input: HashMap<String, DetectorParams>,
 
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub output: Option<HashMap<String, DetectorParams>>,
+    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
+    pub output: HashMap<String, DetectorParams>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
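
The DetectorConfig change at the bottom replaces Option<HashMap<...>> with a plain map: #[serde(default)] turns an absent field into an empty map on deserialization, and skip_serializing_if keeps empty maps off the wire. A self-contained sketch of that behavior (f64 stands in for DetectorParams):

```rust
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
struct DetectorConfig {
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    input: HashMap<String, f64>, // f64 stands in for DetectorParams
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    output: HashMap<String, f64>,
}

fn main() {
    // Absent fields now deserialize to empty maps instead of None, so call
    // sites no longer need Option handling.
    let cfg: DetectorConfig = serde_json::from_str("{}").unwrap();
    assert!(cfg.input.is_empty() && cfg.output.is_empty());

    // And empty maps are skipped on serialization, leaving payloads unchanged.
    assert_eq!(serde_json::to_string(&cfg).unwrap(), "{}");
}
```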