use std::task::{Context, Poll};
use futures::{Future, TryFutureExt as _};
use prometheus::IntGauge;
use tap::Tap;
use tokio::sync::mpsc::{
self,
error::{SendError, TryRecvError, TrySendError},
};
use crate::get_metrics;
/// Sending half of a bounded channel that wraps a tokio `mpsc::Sender` and
/// updates optional Prometheus gauges as messages are sent.
#[derive(Debug)]
pub struct Sender<T> {
/// Wrapped tokio sender that performs the actual channel operations.
inner: mpsc::Sender<T>,
/// Gauge incremented on every successful send or reservation; skipped when `None`.
inflight: Option<IntGauge>,
/// Gauge incremented once per message handed to the channel; skipped when `None`.
sent: Option<IntGauge>,
}
impl<T> Sender<T> {
    /// Sends `value` through the wrapped tokio sender.
    ///
    /// When the send resolves successfully, both the `inflight` and `sent`
    /// gauges are incremented. Any error from the wrapped sender is returned
    /// unchanged and leaves the gauges untouched.
    pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
        let delivery = self
            .inner
            .send(value)
            .inspect_ok(|_| self.record_enqueued());
        delivery.await
    }

    /// Completes when the wrapped sender reports the channel as closed.
    pub async fn closed(&self) {
        self.inner.closed().await
    }

    /// Attempts to send `message` without waiting.
    ///
    /// On success the `inflight` and `sent` gauges are incremented; on failure
    /// the error from the wrapped sender is returned and no gauge changes.
    pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
        self.inner.try_send(message)?;
        self.record_enqueued();
        Ok(())
    }

    /// Returns whether the wrapped sender considers the channel closed.
    pub fn is_closed(&self) -> bool {
        self.inner.is_closed()
    }

    /// Reserves a slot in the channel and returns it as a [`Permit`].
    ///
    /// The `inflight` gauge is incremented as soon as the reservation
    /// succeeds; the `sent` gauge is left for the permit to update when it is
    /// actually used.
    pub async fn reserve(&self) -> Result<Permit<'_, T>, SendError<()>> {
        let slot = self.inner.reserve().await?;
        self.record_reserved();
        Ok(Permit::new(slot, &self.inflight, &self.sent))
    }

    /// Non-waiting variant of [`Sender::reserve`].
    ///
    /// On success the `inflight` gauge is incremented and a [`Permit`] is
    /// returned; on failure the wrapped sender's error is passed through.
    pub fn try_reserve(&self) -> Result<Permit<'_, T>, TrySendError<()>> {
        let slot = self.inner.try_reserve()?;
        self.record_reserved();
        Ok(Permit::new(slot, &self.inflight, &self.sent))
    }

    /// Returns the capacity value reported by the wrapped sender.
    pub fn capacity(&self) -> usize {
        self.inner.capacity()
    }

    /// Creates a [`WeakSender`] from the downgraded inner sender, sharing
    /// clones of this sender's gauges.
    pub fn downgrade(&self) -> WeakSender<T> {
        WeakSender {
            inner: self.inner.downgrade(),
            inflight: self.inflight.clone(),
            sent: self.sent.clone(),
        }
    }

    /// Increments the `inflight` gauge, if one is configured.
    fn record_reserved(&self) {
        if let Some(gauge) = self.inflight.as_ref() {
            gauge.inc();
        }
    }

    /// Increments the `inflight` gauge and then the `sent` gauge, if configured.
    fn record_enqueued(&self) {
        self.record_reserved();
        if let Some(gauge) = self.sent.as_ref() {
            gauge.inc();
        }
    }

    /// Test-only accessor for the `inflight` gauge; panics if metrics are absent.
    #[cfg(test)]
    fn inflight(&self) -> &IntGauge {
        self.inflight
            .as_ref()
            .expect("Metrics should have initialized")
    }

    /// Test-only accessor for the `sent` gauge; panics if metrics are absent.
    #[cfg(test)]
    fn sent(&self) -> &IntGauge {
        self.sent.as_ref().expect("Metrics should have initialized")
    }
}
// Implemented by hand so that cloning does not require `T: Clone`.
impl<T> Clone for Sender<T> {
    /// Clones the inner sender together with both optional gauges.
    fn clone(&self) -> Self {
        let Self {
            inner,
            inflight,
            sent,
        } = self;
        Self {
            inner: inner.clone(),
            inflight: inflight.clone(),
            sent: sent.clone(),
        }
    }
}
/// A reserved channel slot obtained from `Sender::reserve` or
/// `Sender::try_reserve`.
///
/// The `inflight` gauge has already been incremented when a `Permit` is
/// created. Sending through the permit increments the `sent` gauge; dropping
/// an unused permit decrements `inflight` again.
#[derive(Debug)]
pub struct Permit<'a, T> {
    /// The underlying tokio permit; taken (set to `None`) once the permit is used.
    permit: Option<mpsc::Permit<'a, T>>,
    /// Borrowed `inflight` gauge of the sender that issued this permit.
    inflight_ref: &'a Option<IntGauge>,
    /// Borrowed `sent` gauge of the sender that issued this permit.
    sent_ref: &'a Option<IntGauge>,
}
impl<'a, T> Permit<'a, T> {
    /// Wraps a tokio permit together with references to the gauges it should
    /// update. This constructor does not touch any gauge itself.
    pub fn new(
        permit: mpsc::Permit<'a, T>,
        inflight_ref: &'a Option<IntGauge>,
        sent_ref: &'a Option<IntGauge>,
    ) -> Permit<'a, T> {
        Self {
            permit: Some(permit),
            inflight_ref,
            sent_ref,
        }
    }

    /// Consumes the permit, sends `value` into the reserved slot and
    /// increments the `sent` gauge.
    ///
    /// # Panics
    ///
    /// Panics if the inner permit has already been taken.
    pub fn send(mut self, value: T) {
        self.permit
            .take()
            .expect("Permit invariant violated!")
            .send(value);
        if let Some(gauge) = self.sent_ref.as_ref() {
            gauge.inc();
        }
        // Skip `Drop`: the slot has been used, so `inflight` must not be
        // decremented for this permit.
        std::mem::forget(self);
    }
}
impl<'a, T> Drop for Permit<'a, T> {
    /// Decrements the `inflight` gauge when a permit is dropped while still
    /// holding its inner tokio permit (i.e. it was never used to send).
    fn drop(&mut self) {
        if self.permit.is_none() {
            // Nothing is reserved any more, so nothing to undo.
            return;
        }
        if let Some(gauge) = self.inflight_ref.as_ref() {
            gauge.dec();
        }
    }
}
/// Extension trait for obtaining a [`Permit`] and the output of a future in
/// one call.
#[async_trait::async_trait]
pub trait WithPermit<T> {
/// Returns a permit paired with the output of `f`, or `None` when no permit
/// could be obtained. The `Sender` implementation in this file reserves the
/// permit first and only then awaits `f`.
async fn with_permit<F: Future + Send>(&self, f: F) -> Option<(Permit<T>, F::Output)>
where
T: 'static;
}
#[async_trait::async_trait]
impl<T: Send> WithPermit<T> for Sender<T> {
    /// Reserves a slot first and then drives `f` to completion while holding
    /// the permit. Returns `None` without polling `f` if the reservation fails.
    async fn with_permit<F: Future + Send>(&self, f: F) -> Option<(Permit<T>, F::Output)> {
        let Ok(permit) = self.reserve().await else {
            return None;
        };
        let output = f.await;
        Some((permit, output))
    }
}
/// Weak counterpart of [`Sender`]: wraps a tokio `mpsc::WeakSender` and keeps
/// clones of the gauges so an upgraded sender can continue updating them.
#[derive(Debug)]
pub struct WeakSender<T> {
/// Wrapped tokio weak sender.
inner: mpsc::WeakSender<T>,
/// Gauge handed to senders produced by `upgrade`; may be `None`.
inflight: Option<IntGauge>,
/// Gauge handed to senders produced by `upgrade`; may be `None`.
sent: Option<IntGauge>,
}
impl<T> WeakSender<T> {
    /// Tries to upgrade the inner weak sender.
    ///
    /// Returns a [`Sender`] sharing this handle's gauges, or `None` when the
    /// inner upgrade fails.
    pub fn upgrade(&self) -> Option<Sender<T>> {
        let inner = self.inner.upgrade()?;
        Some(Sender {
            inner,
            inflight: self.inflight.clone(),
            sent: self.sent.clone(),
        })
    }
}
// Implemented by hand so that cloning does not require `T: Clone`.
impl<T> Clone for WeakSender<T> {
    /// Clones the inner weak sender together with both optional gauges.
    fn clone(&self) -> Self {
        let Self {
            inner,
            inflight,
            sent,
        } = self;
        Self {
            inner: inner.clone(),
            inflight: inflight.clone(),
            sent: sent.clone(),
        }
    }
}
/// Receiving half of a bounded channel that wraps a tokio `mpsc::Receiver` and
/// updates optional Prometheus gauges as messages are received.
#[derive(Debug)]
pub struct Receiver<T> {
/// Wrapped tokio receiver that performs the actual channel operations.
inner: mpsc::Receiver<T>,
/// Gauge decremented for every message received; skipped when `None`.
inflight: Option<IntGauge>,
/// Gauge incremented for every message received; skipped when `None`.
received: Option<IntGauge>,
}
impl<T> Receiver<T> {
    /// Waits for the next message from the wrapped receiver.
    ///
    /// When a message is returned, `inflight` is decremented and `received`
    /// is incremented. A `None` result leaves the gauges untouched.
    pub async fn recv(&mut self) -> Option<T> {
        let maybe_item = self.inner.recv().await;
        maybe_item.tap(|item| {
            if item.is_some() {
                self.record_delivery();
            }
        })
    }

    /// Attempts to receive a message without waiting.
    ///
    /// The gauges are updated only when a message is returned; errors from the
    /// wrapped receiver are passed through unchanged.
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        let item = self.inner.try_recv()?;
        self.record_delivery();
        Ok(item)
    }

    /// Blocking variant of [`Receiver::recv`]; the gauges are updated only
    /// when a message is returned.
    pub fn blocking_recv(&mut self) -> Option<T> {
        let item = self.inner.blocking_recv()?;
        self.record_delivery();
        Some(item)
    }

    /// Closes the wrapped receiver. No gauge is modified.
    pub fn close(&mut self) {
        self.inner.close()
    }

    /// Polls the wrapped receiver for the next message.
    ///
    /// The gauges are updated only for `Poll::Ready(Some(_))`; every other
    /// poll result is forwarded as-is.
    pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
        let polled = self.inner.poll_recv(cx);
        if matches!(polled, Poll::Ready(Some(_))) {
            self.record_delivery();
        }
        polled
    }

    /// Decrements `inflight` and then increments `received`, if configured.
    fn record_delivery(&self) {
        if let Some(gauge) = self.inflight.as_ref() {
            gauge.dec();
        }
        if let Some(gauge) = self.received.as_ref() {
            gauge.inc();
        }
    }

    /// Test-only accessor for the `received` gauge; panics if metrics are absent.
    #[cfg(test)]
    fn received(&self) -> &IntGauge {
        self.received
            .as_ref()
            .expect("Metrics should have initialized")
    }
}
// Explicitly marks `Receiver<T>` as `Unpin` for every `T`.
impl<T> Unpin for Receiver<T> {}
/// Creates a bounded channel whose halves report to the crate metrics.
///
/// * `name` - label value used to select the per-channel gauges.
/// * `size` - buffer size forwarded to `mpsc::channel`.
///
/// If `get_metrics()` returns `None`, all gauges on both halves are `None`
/// and the channel works without reporting anything.
pub fn channel<T>(name: &str, size: usize) -> (Sender<T>, Receiver<T>) {
    let metrics = get_metrics();
    let (raw_tx, raw_rx) = mpsc::channel(size);
    let label = [name];

    let tx = Sender {
        inner: raw_tx,
        inflight: metrics.map(|m| m.channel_inflight.with_label_values(&label)),
        sent: metrics.map(|m| m.channel_sent.with_label_values(&label)),
    };
    // Both halves look up the `inflight` gauge under the same label, so the
    // sender's increments and the receiver's decrements land on one series.
    let rx = Receiver {
        inner: raw_rx,
        inflight: metrics.map(|m| m.channel_inflight.with_label_values(&label)),
        received: metrics.map(|m| m.channel_received.with_label_values(&label)),
    };
    (tx, rx)
}
/// Sending half of an unbounded channel that wraps a tokio
/// `mpsc::UnboundedSender` and updates optional Prometheus gauges.
#[derive(Debug)]
pub struct UnboundedSender<T> {
/// Wrapped tokio unbounded sender.
inner: mpsc::UnboundedSender<T>,
/// Gauge incremented on every successful send; skipped when `None`.
inflight: Option<IntGauge>,
/// Gauge incremented on every successful send; skipped when `None`.
sent: Option<IntGauge>,
}
impl<T> UnboundedSender<T> {
    /// Sends `value` through the wrapped unbounded sender.
    ///
    /// On success the `inflight` and `sent` gauges are incremented; an error
    /// from the wrapped sender is returned unchanged and no gauge is touched.
    pub fn send(&self, value: T) -> Result<(), SendError<T>> {
        self.inner.send(value)?;
        self.record_enqueued();
        Ok(())
    }

    /// Completes when the wrapped sender reports the channel as closed.
    pub async fn closed(&self) {
        self.inner.closed().await
    }

    /// Returns whether the wrapped sender considers the channel closed.
    pub fn is_closed(&self) -> bool {
        self.inner.is_closed()
    }

    /// Creates a [`WeakUnboundedSender`] from the downgraded inner sender,
    /// sharing clones of this sender's gauges.
    pub fn downgrade(&self) -> WeakUnboundedSender<T> {
        WeakUnboundedSender {
            inner: self.inner.downgrade(),
            inflight: self.inflight.clone(),
            sent: self.sent.clone(),
        }
    }

    /// Increments the `inflight` gauge and then the `sent` gauge, if configured.
    fn record_enqueued(&self) {
        if let Some(gauge) = self.inflight.as_ref() {
            gauge.inc();
        }
        if let Some(gauge) = self.sent.as_ref() {
            gauge.inc();
        }
    }

    /// Test-only accessor for the `inflight` gauge; panics if metrics are absent.
    #[cfg(test)]
    fn inflight(&self) -> &IntGauge {
        self.inflight
            .as_ref()
            .expect("Metrics should have initialized")
    }

    /// Test-only accessor for the `sent` gauge; panics if metrics are absent.
    #[cfg(test)]
    fn sent(&self) -> &IntGauge {
        self.sent.as_ref().expect("Metrics should have initialized")
    }
}
// Implemented by hand so that cloning does not require `T: Clone`.
impl<T> Clone for UnboundedSender<T> {
    /// Clones the inner sender together with both optional gauges.
    fn clone(&self) -> Self {
        let Self {
            inner,
            inflight,
            sent,
        } = self;
        Self {
            inner: inner.clone(),
            inflight: inflight.clone(),
            sent: sent.clone(),
        }
    }
}
/// Weak counterpart of [`UnboundedSender`]: wraps a tokio
/// `mpsc::WeakUnboundedSender` and keeps clones of the gauges so an upgraded
/// sender can continue updating them.
#[derive(Debug)]
pub struct WeakUnboundedSender<T> {
/// Wrapped tokio weak unbounded sender.
inner: mpsc::WeakUnboundedSender<T>,
/// Gauge handed to senders produced by `upgrade`; may be `None`.
inflight: Option<IntGauge>,
/// Gauge handed to senders produced by `upgrade`; may be `None`.
sent: Option<IntGauge>,
}
impl<T> WeakUnboundedSender<T> {
    /// Tries to upgrade the inner weak sender.
    ///
    /// Returns an [`UnboundedSender`] sharing this handle's gauges, or `None`
    /// when the inner upgrade fails.
    pub fn upgrade(&self) -> Option<UnboundedSender<T>> {
        let inner = self.inner.upgrade()?;
        Some(UnboundedSender {
            inner,
            inflight: self.inflight.clone(),
            sent: self.sent.clone(),
        })
    }
}
// Implemented by hand so that cloning does not require `T: Clone`.
impl<T> Clone for WeakUnboundedSender<T> {
    /// Clones the inner weak sender together with both optional gauges.
    fn clone(&self) -> Self {
        let Self {
            inner,
            inflight,
            sent,
        } = self;
        Self {
            inner: inner.clone(),
            inflight: inflight.clone(),
            sent: sent.clone(),
        }
    }
}
/// Receiving half of an unbounded channel that wraps a tokio
/// `mpsc::UnboundedReceiver` and updates optional Prometheus gauges.
#[derive(Debug)]
pub struct UnboundedReceiver<T> {
/// Wrapped tokio unbounded receiver.
inner: mpsc::UnboundedReceiver<T>,
/// Gauge decremented for every message received; skipped when `None`.
inflight: Option<IntGauge>,
/// Gauge incremented for every message received; skipped when `None`.
received: Option<IntGauge>,
}
impl<T> UnboundedReceiver<T> {
    /// Waits for the next message from the wrapped receiver.
    ///
    /// When a message is returned, `inflight` is decremented and `received`
    /// is incremented. A `None` result leaves the gauges untouched.
    pub async fn recv(&mut self) -> Option<T> {
        let maybe_item = self.inner.recv().await;
        maybe_item.tap(|item| {
            if item.is_some() {
                self.record_delivery();
            }
        })
    }

    /// Attempts to receive a message without waiting.
    ///
    /// The gauges are updated only when a message is returned; errors from the
    /// wrapped receiver are passed through unchanged.
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        let item = self.inner.try_recv()?;
        self.record_delivery();
        Ok(item)
    }

    /// Blocking variant of [`UnboundedReceiver::recv`]; the gauges are updated
    /// only when a message is returned.
    pub fn blocking_recv(&mut self) -> Option<T> {
        let item = self.inner.blocking_recv()?;
        self.record_delivery();
        Some(item)
    }

    /// Closes the wrapped receiver. No gauge is modified.
    pub fn close(&mut self) {
        self.inner.close()
    }

    /// Polls the wrapped receiver for the next message.
    ///
    /// The gauges are updated only for `Poll::Ready(Some(_))`; every other
    /// poll result is forwarded as-is.
    pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
        let polled = self.inner.poll_recv(cx);
        if matches!(polled, Poll::Ready(Some(_))) {
            self.record_delivery();
        }
        polled
    }

    /// Decrements `inflight` and then increments `received`, if configured.
    fn record_delivery(&self) {
        if let Some(gauge) = self.inflight.as_ref() {
            gauge.dec();
        }
        if let Some(gauge) = self.received.as_ref() {
            gauge.inc();
        }
    }

    /// Test-only accessor for the `received` gauge; panics if metrics are absent.
    #[cfg(test)]
    fn received(&self) -> &IntGauge {
        self.received
            .as_ref()
            .expect("Metrics should have initialized")
    }
}
// Explicitly marks `UnboundedReceiver<T>` as `Unpin` for every `T`.
impl<T> Unpin for UnboundedReceiver<T> {}
/// Creates an unbounded channel whose halves report to the crate metrics.
///
/// * `name` - label value used to select the per-channel gauges.
///
/// If `get_metrics()` returns `None`, all gauges on both halves are `None`
/// and the channel works without reporting anything.
pub fn unbounded_channel<T>(name: &str) -> (UnboundedSender<T>, UnboundedReceiver<T>) {
    let metrics = get_metrics();
    // NOTE(review): presumably the project lints against calling tokio's
    // `unbounded_channel` directly so callers go through this wrapper — confirm.
    #[allow(clippy::disallowed_methods)]
    let (raw_tx, raw_rx) = mpsc::unbounded_channel();
    let label = [name];

    let tx = UnboundedSender {
        inner: raw_tx,
        inflight: metrics.map(|m| m.channel_inflight.with_label_values(&label)),
        sent: metrics.map(|m| m.channel_sent.with_label_values(&label)),
    };
    // Both halves look up the `inflight` gauge under the same label, so the
    // sender's increments and the receiver's decrements land on one series.
    let rx = UnboundedReceiver {
        inner: raw_rx,
        inflight: metrics.map(|m| m.channel_inflight.with_label_values(&label)),
        received: metrics.map(|m| m.channel_received.with_label_values(&label)),
    };
    (tx, rx)
}
#[cfg(test)]
mod test {
use std::task::{Context, Poll};
use futures::{FutureExt as _, task::noop_waker};
use prometheus::Registry;
use tokio::sync::mpsc::error::TrySendError;
use crate::{
init_metrics,
monitored_mpsc::{channel, unbounded_channel},
};
// Sends one item through a bounded channel and checks the gauges after the
// send and after the receive.
#[tokio::test]
async fn test_bounded_send_and_receive() {
init_metrics(&Registry::new());
let (tx, mut rx) = channel("test_bounded_send_and_receive", 8);
let inflight = tx.inflight();
let sent = tx.sent();
// Cloned so the gauge can be read while `rx` is mutably borrowed by `recv`.
let received = rx.received().clone();
assert_eq!(inflight.get(), 0);
let item = 42;
tx.send(item).await.unwrap();
// After the send: one message in flight, one sent, none received.
assert_eq!(inflight.get(), 1);
assert_eq!(sent.get(), 1);
assert_eq!(received.get(), 0);
let received_item = rx.recv().await.unwrap();
assert_eq!(received_item, item);
// After the receive: nothing in flight, `sent` unchanged, one received.
assert_eq!(inflight.get(), 0);
assert_eq!(sent.get(), 1);
assert_eq!(received.get(), 1);
}
// Checks that a successful `try_send` updates the gauges the same way a
// regular send does.
#[tokio::test]
async fn test_try_send() {
init_metrics(&Registry::new());
let (tx, mut rx) = channel("test_try_send", 1);
let inflight = tx.inflight();
let sent = tx.sent();
// Cloned so the gauge can be read while `rx` is mutably borrowed by `recv`.
let received = rx.received().clone();
// All gauges start at zero.
assert_eq!(inflight.get(), 0);
assert_eq!(sent.get(), 0);
assert_eq!(received.get(), 0);
let item = 42;
tx.try_send(item).unwrap();
assert_eq!(inflight.get(), 1);
assert_eq!(sent.get(), 1);
assert_eq!(received.get(), 0);
let received_item = rx.recv().await.unwrap();
assert_eq!(received_item, item);
assert_eq!(inflight.get(), 0);
assert_eq!(sent.get(), 1);
assert_eq!(received.get(), 1);
}
// Fills a channel of size 2 and checks that a rejected `try_send` leaves all
// gauges unchanged, then drains the channel.
#[tokio::test]
async fn test_try_send_full() {
init_metrics(&Registry::new());
let (tx, mut rx) = channel("test_try_send_full", 2);
let inflight = tx.inflight();
let sent = tx.sent();
// Cloned so the gauge can be read while `rx` is mutably borrowed by `recv`.
let received = rx.received().clone();
assert_eq!(inflight.get(), 0);
let item = 42;
tx.try_send(item).unwrap();
assert_eq!(inflight.get(), 1);
assert_eq!(sent.get(), 1);
assert_eq!(received.get(), 0);
tx.try_send(item).unwrap();
assert_eq!(inflight.get(), 2);
assert_eq!(sent.get(), 2);
assert_eq!(received.get(), 0);
// The third send must fail with `Full` because both slots are occupied.
if let Err(e) = tx.try_send(item) {
assert!(matches!(e, TrySendError::Full(_)));
} else {
panic!("Expect try_send return channel being full error");
}
// The failed send must not have changed any gauge.
assert_eq!(inflight.get(), 2);
assert_eq!(sent.get(), 2);
assert_eq!(received.get(), 0);
let received_item = rx.recv().await.unwrap();
assert_eq!(received_item, item);
assert_eq!(inflight.get(), 1);
assert_eq!(sent.get(), 2);
assert_eq!(received.get(), 1);
let received_item = rx.recv().await.unwrap();
assert_eq!(received_item, item);
assert_eq!(inflight.get(), 0);
assert_eq!(sent.get(), 2);
assert_eq!(received.get(), 2);
}
// Sends one item through an unbounded channel and checks the gauges after
// the send and after the receive.
#[tokio::test]
async fn test_unbounded_send_and_receive() {
init_metrics(&Registry::new());
let (tx, mut rx) = unbounded_channel("test_unbounded_send_and_receive");
let inflight = tx.inflight();
let sent = tx.sent();
// Cloned so the gauge can be read while `rx` is mutably borrowed by `recv`.
let received = rx.received().clone();
assert_eq!(inflight.get(), 0);
let item = 42;
// The unbounded `send` is synchronous, so there is no `.await` here.
tx.send(item).unwrap();
assert_eq!(inflight.get(), 1);
assert_eq!(sent.get(), 1);
assert_eq!(received.get(), 0);
let received_item = rx.recv().await.unwrap();
assert_eq!(received_item, item);
assert_eq!(inflight.get(), 0);
assert_eq!(sent.get(), 1);
assert_eq!(received.get(), 1);
}
// Checks that receive attempts which yield no message (an empty channel, then
// a closed one) do not change the gauges.
#[tokio::test]
async fn test_empty_closed_channel() {
init_metrics(&Registry::new());
let (tx, mut rx) = channel("test_empty_closed_channel", 8);
let inflight = tx.inflight();
// Cloned so the gauge can be read while `rx` is mutably borrowed.
let received = rx.received().clone();
assert_eq!(inflight.get(), 0);
let item = 42;
tx.send(item).await.unwrap();
assert_eq!(inflight.get(), 1);
assert_eq!(received.get(), 0);
let received_item = rx.recv().await.unwrap();
assert_eq!(received_item, item);
assert_eq!(inflight.get(), 0);
assert_eq!(received.get(), 1);
// The channel is now empty: `try_recv` fails and the gauges stay put.
let res = rx.try_recv();
assert!(res.is_err());
assert_eq!(inflight.get(), 0);
assert_eq!(received.get(), 1);
// After closing, `recv` resolves immediately to `None`; gauges stay put.
rx.close();
let res2 = rx.recv().now_or_never().unwrap();
assert!(res2.is_none());
assert_eq!(inflight.get(), 0);
assert_eq!(received.get(), 1);
}
// Exercises the permit lifecycle: reserving bumps `inflight`, sending through
// a permit bumps `sent`, and dropping an unused permit undoes the `inflight`
// increment.
#[tokio::test]
async fn test_reserve() {
init_metrics(&Registry::new());
let (tx, mut rx) = channel("test_reserve", 8);
let inflight = tx.inflight();
let sent = tx.sent();
// Cloned so the gauge can be read while `rx` is mutably borrowed by `recv`.
let received = rx.received().clone();
assert_eq!(inflight.get(), 0);
let permit = tx.reserve().await.unwrap();
// Reservation alone counts as in flight but not yet as sent.
assert_eq!(inflight.get(), 1);
assert_eq!(sent.get(), 0);
assert_eq!(received.get(), 0);
let item = 42;
permit.send(item);
// Using the permit bumps `sent` but does not bump `inflight` a second time.
assert_eq!(inflight.get(), 1);
assert_eq!(sent.get(), 1);
assert_eq!(received.get(), 0);
let permit_2 = tx.reserve().await.unwrap();
assert_eq!(inflight.get(), 2);
assert_eq!(sent.get(), 1);
assert_eq!(received.get(), 0);
// Dropping the unused permit releases its `inflight` count.
drop(permit_2);
assert_eq!(inflight.get(), 1);
assert_eq!(sent.get(), 1);
assert_eq!(received.get(), 0);
let received_item = rx.recv().await.unwrap();
assert_eq!(received_item, item);
assert_eq!(inflight.get(), 0);
assert_eq!(sent.get(), 1);
assert_eq!(received.get(), 1);
}
// Checks that reserving and then dropping a permit returns `inflight` to zero.
#[tokio::test]
async fn test_reserve_and_drop() {
init_metrics(&Registry::new());
// `_rx` is kept alive so the channel stays open for the reservation.
let (tx, _rx) = channel::<usize>("test_reserve_and_drop", 8);
let inflight = tx.inflight();
assert_eq!(inflight.get(), 0);
let permit = tx.reserve().await.unwrap();
assert_eq!(inflight.get(), 1);
drop(permit);
assert_eq!(inflight.get(), 0);
}
// Checks that a send blocked on a full channel of size 1 does not touch the
// gauges until it actually completes.
#[tokio::test]
async fn test_send_backpressure() {
init_metrics(&Registry::new());
// A no-op waker lets the test poll the pending send future by hand.
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let (tx, mut rx) = channel("test_send_backpressure", 1);
let inflight = tx.inflight();
let sent = tx.sent();
// Cloned so the gauge can be read while `rx` is mutably borrowed by `recv`.
let received = rx.received().clone();
assert_eq!(inflight.get(), 0);
tx.send(1).await.unwrap();
assert_eq!(inflight.get(), 1);
assert_eq!(sent.get(), 1);
assert_eq!(received.get(), 0);
// The only slot is taken, so the second send stays pending.
let mut task = Box::pin(tx.send(2));
assert!(matches!(task.poll_unpin(&mut cx), Poll::Pending));
// A pending send must not be counted yet.
assert_eq!(inflight.get(), 1);
assert_eq!(sent.get(), 1);
assert_eq!(received.get(), 0);
let item = rx.recv().await.unwrap();
assert_eq!(item, 1);
assert_eq!(inflight.get(), 0);
assert_eq!(sent.get(), 1);
assert_eq!(received.get(), 1);
// With the slot freed, the blocked send completes and is counted.
assert!(task.now_or_never().is_some());
assert_eq!(inflight.get(), 1);
assert_eq!(sent.get(), 2);
assert_eq!(received.get(), 1);
}
// Checks gauge accounting when a permit holds the only slot of a size-1
// channel while another send is waiting.
#[tokio::test]
async fn test_reserve_backpressure() {
init_metrics(&Registry::new());
// A no-op waker lets the test poll the pending send future by hand.
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let (tx, mut rx) = channel("test_reserve_backpressure", 1);
let inflight = tx.inflight();
let sent = tx.sent();
// Cloned so the gauge can be read while `rx` is mutably borrowed by `recv`.
let received = rx.received().clone();
assert_eq!(inflight.get(), 0);
let permit = tx.reserve().await.unwrap();
assert_eq!(inflight.get(), 1);
assert_eq!(sent.get(), 0);
assert_eq!(received.get(), 0);
// The permit occupies the only slot, so this send stays pending.
let mut task = Box::pin(tx.send(2));
assert!(matches!(task.poll_unpin(&mut cx), Poll::Pending));
assert_eq!(inflight.get(), 1);
assert_eq!(sent.get(), 0);
assert_eq!(received.get(), 0);
// Using the permit bumps `sent` only; `inflight` was counted at reserve time.
permit.send(1);
assert_eq!(inflight.get(), 1);
assert_eq!(sent.get(), 1);
assert_eq!(received.get(), 0);
let item = rx.recv().await.unwrap();
assert_eq!(item, 1);
assert_eq!(inflight.get(), 0);
assert_eq!(sent.get(), 1);
assert_eq!(received.get(), 1);
// With the slot freed, the blocked send completes and is counted.
assert!(task.now_or_never().is_some());
assert_eq!(inflight.get(), 1);
assert_eq!(sent.get(), 2);
assert_eq!(received.get(), 1);
}
// Same as the single-sender backpressure test, but the blocked send is issued
// from a cloned sender; the gauges read through `tx1` must reflect it.
#[tokio::test]
async fn test_send_backpressure_multi_senders() {
init_metrics(&Registry::new());
// A no-op waker lets the test poll the pending send future by hand.
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let (tx1, mut rx) = channel("test_send_backpressure_multi_senders", 1);
let inflight = tx1.inflight();
let sent = tx1.sent();
// Cloned so the gauge can be read while `rx` is mutably borrowed by `recv`.
let received = rx.received().clone();
assert_eq!(inflight.get(), 0);
tx1.send(1).await.unwrap();
assert_eq!(inflight.get(), 1);
assert_eq!(sent.get(), 1);
assert_eq!(received.get(), 0);
let tx2 = tx1.clone();
// The only slot is taken, so the clone's send stays pending.
let mut task = Box::pin(tx2.send(2));
assert!(matches!(task.poll_unpin(&mut cx), Poll::Pending));
assert_eq!(inflight.get(), 1);
assert_eq!(sent.get(), 1);
assert_eq!(received.get(), 0);
let item = rx.recv().await.unwrap();
assert_eq!(item, 1);
assert_eq!(inflight.get(), 0);
assert_eq!(sent.get(), 1);
assert_eq!(received.get(), 1);
// The clone's completed send shows up in the gauges obtained from `tx1`.
assert!(task.now_or_never().is_some());
assert_eq!(inflight.get(), 1);
assert_eq!(sent.get(), 2);
assert_eq!(received.get(), 1);
}
}