1use std::task::{Context, Poll};
9
10use futures::{Future, TryFutureExt as _};
11use prometheus::IntGauge;
12use tap::Tap;
13use tokio::sync::mpsc::{
14 self,
15 error::{SendError, TryRecvError, TrySendError},
16};
17
18use crate::get_metrics;
19
/// Sending half of a bounded, metered mpsc channel; wraps a
/// [`tokio::sync::mpsc::Sender`] and records channel gauges.
///
/// Both gauges are `None` when [`channel`] was called while `get_metrics()`
/// yielded no metrics; in that case nothing is recorded.
#[derive(Debug)]
pub struct Sender<T> {
    inner: mpsc::Sender<T>,
    // Raised when a slot is taken (send / try_send / reserve); lowered by the
    // receiver on delivery or by an unused `Permit` on drop.
    inflight: Option<IntGauge>,
    // Raised once per message actually handed to the channel.
    sent: Option<IntGauge>,
}
27
28impl<T> Sender<T> {
29 pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
32 self.inner
33 .send(value)
34 .inspect_ok(|_| {
35 if let Some(inflight) = &self.inflight {
36 inflight.inc();
37 }
38 if let Some(sent) = &self.sent {
39 sent.inc();
40 }
41 })
42 .await
43 }
44
45 pub async fn closed(&self) {
47 self.inner.closed().await
48 }
49
50 pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
53 self.inner
54 .try_send(message)
55 .map(|_| {
57 if let Some(inflight) = &self.inflight {
58 inflight.inc();
59 }
60 if let Some(sent) = &self.sent {
61 sent.inc();
62 }
63 })
64 }
65
66 pub fn is_closed(&self) -> bool {
75 self.inner.is_closed()
76 }
77
78 pub async fn reserve(&self) -> Result<Permit<'_, T>, SendError<()>> {
82 self.inner.reserve().await.map(|permit| {
83 if let Some(inflight) = &self.inflight {
84 inflight.inc();
85 }
86 Permit::new(permit, &self.inflight, &self.sent)
87 })
88 }
89
90 pub fn try_reserve(&self) -> Result<Permit<'_, T>, TrySendError<()>> {
94 self.inner.try_reserve().map(|val| {
95 if let Some(inflight) = &self.inflight {
96 inflight.inc();
97 }
98 Permit::new(val, &self.inflight, &self.sent)
99 })
100 }
101
102 pub fn capacity(&self) -> usize {
106 self.inner.capacity()
107 }
108
109 pub fn downgrade(&self) -> WeakSender<T> {
114 let sender = self.inner.downgrade();
115 WeakSender {
116 inner: sender,
117 inflight: self.inflight.clone(),
118 sent: self.sent.clone(),
119 }
120 }
121
122 #[cfg(test)]
124 fn inflight(&self) -> &IntGauge {
125 self.inflight
126 .as_ref()
127 .expect("Metrics should have initialized")
128 }
129
130 #[cfg(test)]
132 fn sent(&self) -> &IntGauge {
133 self.sent.as_ref().expect("Metrics should have initialized")
134 }
135}
136
137impl<T> Clone for Sender<T> {
139 fn clone(&self) -> Self {
140 Self {
141 inner: self.inner.clone(),
142 inflight: self.inflight.clone(),
143 sent: self.sent.clone(),
144 }
145 }
146}
147
/// A reserved slot in a bounded metered channel, returned by
/// [`Sender::reserve`] and [`Sender::try_reserve`].
///
/// Wraps a [`tokio::sync::mpsc::Permit`] and borrows the sender's gauges so
/// they can be updated when the permit is used (`send`) or dropped unused.
pub struct Permit<'a, T> {
    // `Some` until `send` takes it; `Drop` checks this to tell an unused
    // reservation apart from one that already sent.
    permit: Option<mpsc::Permit<'a, T>>,
    inflight_ref: &'a Option<IntGauge>,
    sent_ref: &'a Option<IntGauge>,
}
155
156impl<'a, T> Permit<'a, T> {
157 pub fn new(
163 permit: mpsc::Permit<'a, T>,
164 inflight_ref: &'a Option<IntGauge>,
165 sent_ref: &'a Option<IntGauge>,
166 ) -> Permit<'a, T> {
167 Permit {
168 permit: Some(permit),
169 inflight_ref,
170 sent_ref,
171 }
172 }
173
174 pub fn send(mut self, value: T) {
181 let sender = self.permit.take().expect("Permit invariant violated!");
182 sender.send(value);
183 if let Some(sent_ref) = self.sent_ref {
184 sent_ref.inc();
185 }
186 std::mem::forget(self);
188 }
189}
190
191impl<T> Drop for Permit<'_, T> {
192 fn drop(&mut self) {
193 if self.permit.is_some() {
197 if let Some(inflight_ref) = self.inflight_ref {
198 inflight_ref.dec();
199 }
200 }
201 }
202}
203
/// Couples obtaining a channel [`Permit`] with awaiting another future.
#[async_trait::async_trait]
pub trait WithPermit<T> {
    /// Obtains a permit and awaits `f`, returning both together.
    ///
    /// Implementations return `None` when no permit can be obtained; the
    /// [`Sender`] implementation does so when `reserve` fails, and in that case
    /// `f` is never awaited.
    async fn with_permit<F: Future + Send>(&self, f: F) -> Option<(Permit<T>, F::Output)>
    where
        T: 'static;
}
213
214#[async_trait::async_trait]
215impl<T: Send> WithPermit<T> for Sender<T> {
216 async fn with_permit<F: Future + Send>(&self, f: F) -> Option<(Permit<T>, F::Output)> {
224 let permit = self.reserve().await.ok()?;
225 Some((permit, f.await))
226 }
227}
228
/// Weak handle to a bounded metered [`Sender`]; wraps a
/// [`tokio::sync::mpsc::WeakSender`].
///
/// It carries the gauges of the sender it was downgraded from, so
/// [`WeakSender::upgrade`] yields a sender that records into the same gauges.
#[derive(Debug)]
pub struct WeakSender<T> {
    inner: mpsc::WeakSender<T>,
    inflight: Option<IntGauge>,
    sent: Option<IntGauge>,
}
236
237impl<T> WeakSender<T> {
238 pub fn upgrade(&self) -> Option<Sender<T>> {
245 self.inner.upgrade().map(|s| Sender {
246 inner: s,
247 inflight: self.inflight.clone(),
248 sent: self.sent.clone(),
249 })
250 }
251}
252
253impl<T> Clone for WeakSender<T> {
255 fn clone(&self) -> Self {
256 Self {
257 inner: self.inner.clone(),
258 inflight: self.inflight.clone(),
259 sent: self.sent.clone(),
260 }
261 }
262}
263
/// Receiving half of a bounded metered channel; wraps a
/// [`tokio::sync::mpsc::Receiver`].
///
/// Every message handed out lowers `inflight` and raises `received`.
#[derive(Debug)]
pub struct Receiver<T> {
    inner: mpsc::Receiver<T>,
    // Same labelled gauge as the senders' `inflight` (both come from
    // `channel_inflight` with the channel name in `channel`).
    inflight: Option<IntGauge>,
    received: Option<IntGauge>,
}
272
273impl<T> Receiver<T> {
274 pub async fn recv(&mut self) -> Option<T> {
277 self.inner.recv().await.tap(|opt| {
278 if opt.is_some() {
279 if let Some(inflight) = &self.inflight {
280 inflight.dec();
281 }
282 if let Some(received) = &self.received {
283 received.inc();
284 }
285 }
286 })
287 }
288
289 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
292 self.inner.try_recv().inspect(|_| {
293 if let Some(inflight) = &self.inflight {
294 inflight.dec();
295 }
296 if let Some(received) = &self.received {
297 received.inc();
298 }
299 })
300 }
301
302 pub fn blocking_recv(&mut self) -> Option<T> {
308 self.inner.blocking_recv().inspect(|_| {
309 if let Some(inflight) = &self.inflight {
310 inflight.dec();
311 }
312 if let Some(received) = &self.received {
313 received.inc();
314 }
315 })
316 }
317
318 pub fn close(&mut self) {
320 self.inner.close()
321 }
322
323 pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
326 match self.inner.poll_recv(cx) {
327 res @ Poll::Ready(Some(_)) => {
328 if let Some(inflight) = &self.inflight {
329 inflight.dec();
330 }
331 if let Some(received) = &self.received {
332 received.inc();
333 }
334 res
335 }
336 s => s,
337 }
338 }
339
340 #[cfg(test)]
342 fn received(&self) -> &IntGauge {
343 self.received
344 .as_ref()
345 .expect("Metrics should have initialized")
346 }
347}
348
349impl<T> Unpin for Receiver<T> {}
350
351pub fn channel<T>(name: &str, size: usize) -> (Sender<T>, Receiver<T>) {
353 let metrics = get_metrics();
354 let (sender, receiver) = mpsc::channel(size);
355 (
356 Sender {
357 inner: sender,
358 inflight: metrics.map(|m| m.channel_inflight.with_label_values(&[name])),
359 sent: metrics.map(|m| m.channel_sent.with_label_values(&[name])),
360 },
361 Receiver {
362 inner: receiver,
363 inflight: metrics.map(|m| m.channel_inflight.with_label_values(&[name])),
364 received: metrics.map(|m| m.channel_received.with_label_values(&[name])),
365 },
366 )
367}
368
/// Sending half of an unbounded, metered mpsc channel; wraps a
/// [`tokio::sync::mpsc::UnboundedSender`] and records channel gauges.
///
/// Both gauges are `None` when [`unbounded_channel`] was called while
/// `get_metrics()` yielded no metrics.
#[derive(Debug)]
pub struct UnboundedSender<T> {
    inner: mpsc::UnboundedSender<T>,
    // Raised on every successful send; lowered by the receiver.
    inflight: Option<IntGauge>,
    // Raised once per message successfully sent.
    sent: Option<IntGauge>,
}
377
378impl<T> UnboundedSender<T> {
379 pub fn send(&self, value: T) -> Result<(), SendError<T>> {
382 self.inner.send(value).map(|_| {
383 if let Some(inflight) = &self.inflight {
384 inflight.inc();
385 }
386 if let Some(sent) = &self.sent {
387 sent.inc();
388 }
389 })
390 }
391
392 pub async fn closed(&self) {
394 self.inner.closed().await
395 }
396
397 pub fn is_closed(&self) -> bool {
401 self.inner.is_closed()
402 }
403
404 pub fn downgrade(&self) -> WeakUnboundedSender<T> {
405 let sender = self.inner.downgrade();
406 WeakUnboundedSender {
407 inner: sender,
408 inflight: self.inflight.clone(),
409 sent: self.sent.clone(),
410 }
411 }
412
413 #[cfg(test)]
415 fn inflight(&self) -> &IntGauge {
416 self.inflight
417 .as_ref()
418 .expect("Metrics should have initialized")
419 }
420
421 #[cfg(test)]
423 fn sent(&self) -> &IntGauge {
424 self.sent.as_ref().expect("Metrics should have initialized")
425 }
426}
427
428impl<T> Clone for UnboundedSender<T> {
430 fn clone(&self) -> Self {
431 Self {
432 inner: self.inner.clone(),
433 inflight: self.inflight.clone(),
434 sent: self.sent.clone(),
435 }
436 }
437}
438
/// Weak handle to an [`UnboundedSender`]; wraps a
/// [`tokio::sync::mpsc::WeakUnboundedSender`].
///
/// It carries the gauges of the sender it was downgraded from, so
/// [`WeakUnboundedSender::upgrade`] yields a sender recording into the same gauges.
#[derive(Debug)]
pub struct WeakUnboundedSender<T> {
    inner: mpsc::WeakUnboundedSender<T>,
    inflight: Option<IntGauge>,
    sent: Option<IntGauge>,
}
447
448impl<T> WeakUnboundedSender<T> {
449 pub fn upgrade(&self) -> Option<UnboundedSender<T>> {
457 self.inner.upgrade().map(|s| UnboundedSender {
458 inner: s,
459 inflight: self.inflight.clone(),
460 sent: self.sent.clone(),
461 })
462 }
463}
464
465impl<T> Clone for WeakUnboundedSender<T> {
467 fn clone(&self) -> Self {
468 Self {
469 inner: self.inner.clone(),
470 inflight: self.inflight.clone(),
471 sent: self.sent.clone(),
472 }
473 }
474}
475
/// Receiving half of an unbounded metered channel; wraps a
/// [`tokio::sync::mpsc::UnboundedReceiver`].
///
/// Every message handed out lowers `inflight` and raises `received`.
#[derive(Debug)]
pub struct UnboundedReceiver<T> {
    inner: mpsc::UnboundedReceiver<T>,
    // Same labelled gauge as the senders' `inflight` (both come from
    // `channel_inflight` with the channel name in `unbounded_channel`).
    inflight: Option<IntGauge>,
    received: Option<IntGauge>,
}
484
485impl<T> UnboundedReceiver<T> {
486 pub async fn recv(&mut self) -> Option<T> {
489 self.inner.recv().await.tap(|opt| {
490 if opt.is_some() {
491 if let Some(inflight) = &self.inflight {
492 inflight.dec();
493 }
494 if let Some(received) = &self.received {
495 received.inc();
496 }
497 }
498 })
499 }
500
501 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
504 self.inner.try_recv().inspect(|_| {
505 if let Some(inflight) = &self.inflight {
506 inflight.dec();
507 }
508 if let Some(received) = &self.received {
509 received.inc();
510 }
511 })
512 }
513
514 pub fn blocking_recv(&mut self) -> Option<T> {
519 self.inner.blocking_recv().inspect(|_| {
520 if let Some(inflight) = &self.inflight {
521 inflight.dec();
522 }
523 if let Some(received) = &self.received {
524 received.inc();
525 }
526 })
527 }
528
529 pub fn close(&mut self) {
531 self.inner.close()
532 }
533
534 pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
537 match self.inner.poll_recv(cx) {
538 res @ Poll::Ready(Some(_)) => {
539 if let Some(inflight) = &self.inflight {
540 inflight.dec();
541 }
542 if let Some(received) = &self.received {
543 received.inc();
544 }
545 res
546 }
547 s => s,
548 }
549 }
550
551 #[cfg(test)]
553 fn received(&self) -> &IntGauge {
554 self.received
555 .as_ref()
556 .expect("Metrics should have initialized")
557 }
558}
559
560impl<T> Unpin for UnboundedReceiver<T> {}
561
562pub fn unbounded_channel<T>(name: &str) -> (UnboundedSender<T>, UnboundedReceiver<T>) {
565 let metrics = get_metrics();
566 #[expect(clippy::disallowed_methods)]
567 let (sender, receiver) = mpsc::unbounded_channel();
568 (
569 UnboundedSender {
570 inner: sender,
571 inflight: metrics.map(|m| m.channel_inflight.with_label_values(&[name])),
572 sent: metrics.map(|m| m.channel_sent.with_label_values(&[name])),
573 },
574 UnboundedReceiver {
575 inner: receiver,
576 inflight: metrics.map(|m| m.channel_inflight.with_label_values(&[name])),
577 received: metrics.map(|m| m.channel_received.with_label_values(&[name])),
578 },
579 )
580}
581
#[cfg(test)]
mod test {
    use std::task::{Context, Poll};

    use futures::{FutureExt as _, task::noop_waker};
    use prometheus::Registry;
    use tokio::sync::mpsc::error::TrySendError;

    use crate::{
        init_metrics,
        monitored_mpsc::{channel, unbounded_channel},
    };

    // Every test labels its channel with a distinct name, so the gauges it
    // reads are its own series. Snapshots are compared as
    // `(inflight, sent, received)` tuples unless noted otherwise.

    #[tokio::test]
    async fn test_bounded_send_and_receive() {
        init_metrics(&Registry::new());
        let (tx, mut rx) = channel("test_bounded_send_and_receive", 8);
        let (inflight, sent) = (tx.inflight(), tx.sent());
        // Cloned so `rx` can still be borrowed mutably by `recv`.
        let received = rx.received().clone();
        let gauges = || (inflight.get(), sent.get(), received.get());

        assert_eq!(inflight.get(), 0);
        let item = 42;
        tx.send(item).await.unwrap();
        assert_eq!(gauges(), (1, 1, 0));

        assert_eq!(rx.recv().await.unwrap(), item);
        assert_eq!(gauges(), (0, 1, 1));
    }

    #[tokio::test]
    async fn test_try_send() {
        init_metrics(&Registry::new());
        let (tx, mut rx) = channel("test_try_send", 1);
        let (inflight, sent) = (tx.inflight(), tx.sent());
        let received = rx.received().clone();
        let gauges = || (inflight.get(), sent.get(), received.get());

        assert_eq!(gauges(), (0, 0, 0));

        let item = 42;
        tx.try_send(item).unwrap();
        assert_eq!(gauges(), (1, 1, 0));

        assert_eq!(rx.recv().await.unwrap(), item);
        assert_eq!(gauges(), (0, 1, 1));
    }

    #[tokio::test]
    async fn test_try_send_full() {
        init_metrics(&Registry::new());
        let (tx, mut rx) = channel("test_try_send_full", 2);
        let (inflight, sent) = (tx.inflight(), tx.sent());
        let received = rx.received().clone();
        let gauges = || (inflight.get(), sent.get(), received.get());

        assert_eq!(inflight.get(), 0);

        let item = 42;
        tx.try_send(item).unwrap();
        assert_eq!(gauges(), (1, 1, 0));

        tx.try_send(item).unwrap();
        assert_eq!(gauges(), (2, 2, 0));

        // Capacity is 2: a third message is rejected and no gauge moves.
        match tx.try_send(item) {
            Err(e) => assert!(matches!(e, TrySendError::Full(_))),
            Ok(()) => panic!("Expect try_send return channel being full error"),
        }
        assert_eq!(gauges(), (2, 2, 0));

        assert_eq!(rx.recv().await.unwrap(), item);
        assert_eq!(gauges(), (1, 2, 1));

        assert_eq!(rx.recv().await.unwrap(), item);
        assert_eq!(gauges(), (0, 2, 2));
    }

    #[tokio::test]
    async fn test_unbounded_send_and_receive() {
        init_metrics(&Registry::new());
        let (tx, mut rx) = unbounded_channel("test_unbounded_send_and_receive");
        let (inflight, sent) = (tx.inflight(), tx.sent());
        let received = rx.received().clone();
        let gauges = || (inflight.get(), sent.get(), received.get());

        assert_eq!(inflight.get(), 0);
        let item = 42;
        tx.send(item).unwrap();
        assert_eq!(gauges(), (1, 1, 0));

        assert_eq!(rx.recv().await.unwrap(), item);
        assert_eq!(gauges(), (0, 1, 1));
    }

    #[tokio::test]
    async fn test_empty_closed_channel() {
        init_metrics(&Registry::new());
        let (tx, mut rx) = channel("test_empty_closed_channel", 8);
        let inflight = tx.inflight();
        let received = rx.received().clone();
        // This test only tracks `(inflight, received)`.
        let gauges = || (inflight.get(), received.get());

        assert_eq!(inflight.get(), 0);
        let item = 42;
        tx.send(item).await.unwrap();
        assert_eq!(gauges(), (1, 0));

        assert_eq!(rx.recv().await.unwrap(), item);
        assert_eq!(gauges(), (0, 1));

        // Nothing is queued: `try_recv` fails and must leave the gauges alone.
        assert!(rx.try_recv().is_err());
        assert_eq!(gauges(), (0, 1));

        // Once closed and drained, `recv` resolves immediately with `None`.
        rx.close();
        assert!(rx.recv().now_or_never().unwrap().is_none());
        assert_eq!(gauges(), (0, 1));
    }

    #[tokio::test]
    async fn test_reserve() {
        init_metrics(&Registry::new());
        let (tx, mut rx) = channel("test_reserve", 8);
        let (inflight, sent) = (tx.inflight(), tx.sent());
        let received = rx.received().clone();
        let gauges = || (inflight.get(), sent.get(), received.get());

        assert_eq!(inflight.get(), 0);

        // A reservation counts as inflight but not yet as sent.
        let permit = tx.reserve().await.unwrap();
        assert_eq!(gauges(), (1, 0, 0));

        let item = 42;
        permit.send(item);
        assert_eq!(gauges(), (1, 1, 0));

        let permit_2 = tx.reserve().await.unwrap();
        assert_eq!(gauges(), (2, 1, 0));

        // Dropping an unused permit gives its inflight slot back.
        drop(permit_2);
        assert_eq!(gauges(), (1, 1, 0));

        assert_eq!(rx.recv().await.unwrap(), item);
        assert_eq!(gauges(), (0, 1, 1));
    }

    #[tokio::test]
    async fn test_reserve_and_drop() {
        init_metrics(&Registry::new());
        let (tx, _rx) = channel::<usize>("test_reserve_and_drop", 8);
        let inflight = tx.inflight();

        assert_eq!(inflight.get(), 0);

        let permit = tx.reserve().await.unwrap();
        assert_eq!(inflight.get(), 1);

        // An unused permit must undo its inflight increment on drop.
        drop(permit);
        assert_eq!(inflight.get(), 0);
    }

    #[tokio::test]
    async fn test_send_backpressure() {
        init_metrics(&Registry::new());
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let (tx, mut rx) = channel("test_send_backpressure", 1);
        let (inflight, sent) = (tx.inflight(), tx.sent());
        let received = rx.received().clone();
        let gauges = || (inflight.get(), sent.get(), received.get());

        assert_eq!(inflight.get(), 0);

        tx.send(1).await.unwrap();
        assert_eq!(gauges(), (1, 1, 0));

        // The only slot is occupied: a second send stays pending and is not
        // counted yet.
        let mut task = Box::pin(tx.send(2));
        assert!(matches!(task.poll_unpin(&mut cx), Poll::Pending));
        assert_eq!(gauges(), (1, 1, 0));

        assert_eq!(rx.recv().await.unwrap(), 1);
        assert_eq!(gauges(), (0, 1, 1));

        // With the slot freed, the pending send completes on the next poll.
        assert!(task.now_or_never().is_some());
        assert_eq!(gauges(), (1, 2, 1));
    }

    #[tokio::test]
    async fn test_reserve_backpressure() {
        init_metrics(&Registry::new());
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let (tx, mut rx) = channel("test_reserve_backpressure", 1);
        let (inflight, sent) = (tx.inflight(), tx.sent());
        let received = rx.received().clone();
        let gauges = || (inflight.get(), sent.get(), received.get());

        assert_eq!(inflight.get(), 0);

        let permit = tx.reserve().await.unwrap();
        assert_eq!(gauges(), (1, 0, 0));

        // The reservation holds the only slot, so this send must wait.
        let mut task = Box::pin(tx.send(2));
        assert!(matches!(task.poll_unpin(&mut cx), Poll::Pending));
        assert_eq!(gauges(), (1, 0, 0));

        permit.send(1);
        assert_eq!(gauges(), (1, 1, 0));

        assert_eq!(rx.recv().await.unwrap(), 1);
        assert_eq!(gauges(), (0, 1, 1));

        assert!(task.now_or_never().is_some());
        assert_eq!(gauges(), (1, 2, 1));
    }

    #[tokio::test]
    async fn test_send_backpressure_multi_senders() {
        init_metrics(&Registry::new());
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let (tx1, mut rx) = channel("test_send_backpressure_multi_senders", 1);
        let (inflight, sent) = (tx1.inflight(), tx1.sent());
        let received = rx.received().clone();
        let gauges = || (inflight.get(), sent.get(), received.get());

        assert_eq!(inflight.get(), 0);

        tx1.send(1).await.unwrap();
        assert_eq!(gauges(), (1, 1, 0));

        // A cloned sender shares the same gauges and the same capacity.
        let tx2 = tx1.clone();
        let mut task = Box::pin(tx2.send(2));
        assert!(matches!(task.poll_unpin(&mut cx), Poll::Pending));
        assert_eq!(gauges(), (1, 1, 0));

        assert_eq!(rx.recv().await.unwrap(), 1);
        assert_eq!(gauges(), (0, 1, 1));

        assert!(task.now_or_never().is_some());
        assert_eq!(gauges(), (1, 2, 1));
    }
}