diff --git a/auction-server/src/api.rs b/auction-server/src/api.rs index bc2ee67e..298da87d 100644 --- a/auction-server/src/api.rs +++ b/auction-server/src/api.rs @@ -11,10 +11,7 @@ use { EXIT_CHECK_INTERVAL, SHOULD_EXIT, }, - state::{ - ServerState, - StoreNew, - }, + state::StoreNew, }, anyhow::Result, axum::{ @@ -53,8 +50,9 @@ use { TypedHeader, }, axum_prometheus::{ - EndpointLabel, - PrometheusMetricLayerBuilder, + metrics, + AXUM_HTTP_REQUESTS_DURATION_SECONDS, + AXUM_HTTP_REQUESTS_TOTAL, }, clap::crate_version, express_relay_api_types::{ @@ -81,6 +79,7 @@ use { }, }, time::OffsetDateTime, + tokio::time::Instant, tower_http::cors::CorsLayer, utoipa::{ openapi::security::{ @@ -750,11 +749,43 @@ impl WrappedRouter { } } -pub async fn start_api( - run_options: RunOptions, - store: Arc, - server_state: Arc, -) -> Result<()> { +async fn track_metrics( + auth: Auth, + req: extract::Request, + next: middleware::Next, +) -> impl IntoResponse { + let start = Instant::now(); + let endpoint = if let Some(matched_path) = req.extensions().get::() { + matched_path.as_str().to_owned() + } else { + "unknown".to_string() + }; + let profile = if let Auth::Authorized(_, profile) = auth { + profile.name + } else { + "unauthorized".to_string() + }; + let method = req.method().clone(); + + let response = next.run(req).await; + + let latency = start.elapsed().as_secs_f64(); + let status = response.status().as_u16().to_string(); + + let labels = [ + ("method", method.to_string()), + ("endpoint", endpoint), + ("status", status), + ("profile", profile), + ]; + + metrics::counter!(AXUM_HTTP_REQUESTS_TOTAL, &labels).increment(1); + metrics::histogram!(AXUM_HTTP_REQUESTS_DURATION_SECONDS, &labels).record(latency); + + response +} + +pub async fn start_api(run_options: RunOptions, store: Arc) -> Result<()> { // Make sure functions included in the paths section have distinct names, otherwise some api generators will fail #[derive(OpenApi)] #[openapi( @@ -880,13 +911,6 @@ pub async fn start_api( .merge(profile_routes) .merge(ws::get_routes(store.clone())); - let (prometheus_layer, _) = PrometheusMetricLayerBuilder::new() - .with_metrics_from_fn(|| server_state.metrics_recorder.clone()) - .with_endpoint_label_type(EndpointLabel::MatchedPathWithFallbackFn(|_| { - "unknown".to_string() - })) - .build_pair(); - let original_doc = serde_json::to_value(ApiDoc::openapi()) .expect("Failed to serialize OpenAPI document to json value"); @@ -898,7 +922,7 @@ pub async fn start_api( .route(Route::OpenApi.as_ref(), get(original_doc.to_string())) .layer(CorsLayer::permissive()) .layer(middleware::from_extractor_with_state::>(store.clone())) - .layer(prometheus_layer) + .layer(middleware::from_fn_with_state(store.clone(), track_metrics)) .with_state(store); let listener = tokio::net::TcpListener::bind(&run_options.server.listen_addr).await?; diff --git a/auction-server/src/api/ws.rs b/auction-server/src/api/ws.rs index e7dced40..5f7a8eff 100644 --- a/auction-server/src/api/ws.rs +++ b/auction-server/src/api/ws.rs @@ -392,10 +392,14 @@ impl Subscriber { #[instrument( target = "metrics", - fields(category = "ws_update", result = "success", name), + fields(category = "ws_update", result = "success", profile, name), skip_all )] async fn handle_update(&mut self, event: UpdateEvent) -> Result<()> { + if let Auth::Authorized(_, profile) = self.auth.clone() { + tracing::Span::current().record("profile", profile.name); + } + let result = match event.clone() { UpdateEvent::NewOpportunity(opportunity) => { tracing::Span::current().record("name", "new_opportunity"); @@ -543,10 +547,14 @@ impl Subscriber { #[instrument( target = "metrics", - fields(category = "ws_client_message", result = "success", name), + fields(category = "ws_client_message", result = "success", profile, name), skip_all )] async fn handle_client_message(&mut self, message: Message) -> Result<()> { + if let Auth::Authorized(_, profile) = self.auth.clone() { + tracing::Span::current().record("profile", profile.name); + } + let maybe_client_message = match message { Message::Close(_) => { // Closing the connection. We don't remove it from the subscribers diff --git a/auction-server/src/per_metrics.rs b/auction-server/src/per_metrics.rs index c1134a3a..bf9abb7e 100644 --- a/auction-server/src/per_metrics.rs +++ b/auction-server/src/per_metrics.rs @@ -59,6 +59,7 @@ pub struct MetricsLayerData { started_at: std::time::Instant, result: String, name: String, + profile: String, } pub struct MetricsLayer; @@ -76,6 +77,8 @@ impl Visit for MetricsLayerData { self.result = value.to_string(); } else if field.name() == "name" { self.name = value.to_string(); + } else if field.name() == "profile" { + self.profile = value.to_string(); } } } @@ -87,6 +90,7 @@ impl Default for MetricsLayerData { started_at: Instant::now(), result: "unknown".to_string(), name: "unknown".to_string(), + profile: "unknown".to_string(), } } } @@ -153,7 +157,11 @@ where Some(span) => match span.extensions().get::() { Some(data) => { let latency = (Instant::now() - data.started_at).as_secs_f64(); - let labels = [("name", data.name.clone()), ("result", data.result.clone())]; + let labels = [ + ("name", data.name.clone()), + ("result", data.result.clone()), + ("profile", data.profile.clone()), + ]; metrics::histogram!(format!("{}_duration_seconds", data.category), &labels) .record(latency); metrics::counter!(format!("{}_total", data.category), &labels).increment(1); diff --git a/auction-server/src/server.rs b/auction-server/src/server.rs index 9be28388..c181670d 100644 --- a/auction-server/src/server.rs +++ b/auction-server/src/server.rs @@ -514,7 +514,6 @@ pub async fn start_server(run_options: RunOptions) -> Result<()> { fault_tolerant_handler("start api".to_string(), || api::start_api( run_options.clone(), store_new.clone(), - server_state.clone(), )), fault_tolerant_handler("start metrics".to_string(), || per_metrics::start_metrics( run_options.clone(),