use std::time::Duration;
use tonic::{
Code, Status,
codegen::http::{HeaderValue, Request, Response, header::HeaderName},
};
use tower_http::{
classify::GrpcFailureClass,
trace::{OnFailure, OnRequest, OnResponse},
};
use tracing::Span;
// NOTE(review): the code that inserts this header into responses is not in this file's visible
// section — confirm which layer sets it and that it is present on every response.
/// Name of the `grpc-path-req` header from which `MetricsHandler`'s `OnResponse` impl reads the
/// endpoint path when reporting response metrics (a `Response` does not carry the request URI).
pub(crate) static GRPC_ENDPOINT_PATH_HEADER: HeaderName = HeaderName::from_static("grpc-path-req");
/// Callbacks through which the gRPC metrics middleware reports request/response events.
///
/// Implementors must be `Send + Sync + Clone + 'static`. `on_request` and `on_response` are
/// required; `on_start` and `on_drop` are optional hooks whose default implementations do
/// nothing (they are not invoked anywhere in this file).
pub trait MetricsCallbackProvider
where
    Self: Send + Sync + Clone + 'static,
{
    /// Called for each incoming request. `MetricsHandler` passes the request URI path.
    fn on_request(&self, path: String);

    /// Called for each response with the endpoint path, the elapsed latency, the HTTP status
    /// code, and the gRPC status code.
    fn on_response(&self, path: String, latency: Duration, status: u16, grpc_status_code: Code);

    /// Optional hook; the default implementation ignores `path`.
    fn on_start(&self, path: &str) {
        let _ = path;
    }

    /// Optional hook; the default implementation ignores `path`.
    fn on_drop(&self, path: &str) {
        let _ = path;
    }
}
/// A [`MetricsCallbackProvider`] that discards every callback.
///
/// Use it when metrics collection is not wanted; all of its methods are no-ops.
#[derive(Default, Clone)]
pub struct DefaultMetricsCallbackProvider {}

impl MetricsCallbackProvider for DefaultMetricsCallbackProvider {
    fn on_request(&self, _: String) {}

    fn on_response(&self, _: String, _: Duration, _: u16, _: Code) {}
}
/// tower-http trace callback handler that forwards request/response events to a
/// [`MetricsCallbackProvider`].
#[derive(Clone)]
pub(crate) struct MetricsHandler<M: MetricsCallbackProvider> {
    /// Receives the forwarded request/response callbacks.
    metrics_provider: M,
}

impl<M> MetricsHandler<M>
where
    M: MetricsCallbackProvider,
{
    /// Creates a handler that reports to `metrics_provider`.
    pub(crate) fn new(metrics_provider: M) -> Self {
        MetricsHandler { metrics_provider }
    }
}
/// Reports each completed response to the wrapped [`MetricsCallbackProvider`].
///
/// The gRPC status code is parsed from the response headers via `Status::from_header_map`;
/// when no status can be parsed, `Code::Ok` is reported. The endpoint path is taken from the
/// `GRPC_ENDPOINT_PATH_HEADER` response header, because a `Response` does not carry the
/// request URI.
impl<B, M: MetricsCallbackProvider> OnResponse<B> for MetricsHandler<M> {
    fn on_response(self, response: &Response<B>, latency: Duration, _span: &Span) {
        // Label reported when the path header is absent. Panicking inside a tracing callback
        // would tear down request handling just to record a metric.
        const UNKNOWN_PATH: &str = "unknown";

        let headers = response.headers();
        let grpc_status_code = Status::from_header_map(headers).map_or(Code::Ok, |s| s.code());

        // Borrow the header value instead of cloning it. Use a lossy UTF-8 conversion because
        // `HeaderValue::to_str` rejects bytes that are not visible ASCII, and such a value
        // should not abort metric reporting.
        let path = headers.get(&GRPC_ENDPOINT_PATH_HEADER).map_or_else(
            || UNKNOWN_PATH.to_owned(),
            |value: &HeaderValue| String::from_utf8_lossy(value.as_bytes()).into_owned(),
        );

        self.metrics_provider.on_response(
            path,
            latency,
            response.status().as_u16(),
            grpc_status_code,
        );
    }
}
/// Forwards the URI path of each incoming request to the provider's `on_request` callback.
impl<B, M> OnRequest<B> for MetricsHandler<M>
where
    M: MetricsCallbackProvider,
{
    fn on_request(&mut self, request: &Request<B>, _span: &Span) {
        let path = request.uri().path().to_owned();
        self.metrics_provider.on_request(path);
    }
}
/// Failure events are not forwarded to the metrics provider. Only request and response
/// events are reported by this handler.
impl<M> OnFailure<GrpcFailureClass> for MetricsHandler<M>
where
    M: MetricsCallbackProvider,
{
    fn on_failure(&mut self, _: GrpcFailureClass, _: Duration, _: &Span) {
        // Intentionally a no-op.
    }
}