iota_metrics/
metered_channel.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5#![allow(dead_code)]
6
7use std::{
8    future::Future,
9    task::{Context, Poll},
10};
11
12use async_trait::async_trait;
// TODO: complete tests - this module partially facades the `tokio::sync::mpsc::{Sender, Receiver}` API;
// without tests, it will be fragile to maintain.
15use futures::{FutureExt, Stream, TryFutureExt};
16use prometheus::{IntCounter, IntGauge};
17use tokio::sync::mpsc::{
18    self,
19    error::{SendError, TryRecvError, TrySendError},
20};
21
22#[cfg(test)]
23#[path = "tests/metered_channel_tests.rs"]
24mod metered_channel_tests;
25
26/// An [`mpsc::Sender`] with an [`IntGauge`]
27/// counting the number of currently queued items.
28#[derive(Debug)]
29pub struct Sender<T> {
30    inner: mpsc::Sender<T>,
31    gauge: IntGauge,
32}
33
34impl<T> Clone for Sender<T> {
35    fn clone(&self) -> Self {
36        Self {
37            inner: self.inner.clone(),
38            gauge: self.gauge.clone(),
39        }
40    }
41}
42
43impl<T> Sender<T> {
44    /// Downgrades the current `Sender` to a `WeakSender`, which holds a weak
45    /// reference to the underlying channel. This allows the `Sender` to be
46    /// safely dropped without affecting the channel while maintaining
47    /// the ability to upgrade back to a strong reference later if needed.
48    /// Returns a `WeakSender` that holds the weak reference and the
49    /// associated gauge for resource tracking.
50    pub fn downgrade(&self) -> WeakSender<T> {
51        let sender = self.inner.downgrade();
52        WeakSender {
53            inner: sender,
54            gauge: self.gauge.clone(),
55        }
56    }
57}
58
59/// An [`mpsc::WeakSender`] with an [`IntGauge`]
60/// counting the number of currently queued items.
61#[derive(Debug)]
62pub struct WeakSender<T> {
63    inner: mpsc::WeakSender<T>,
64    gauge: IntGauge,
65}
66
67impl<T> Clone for WeakSender<T> {
68    fn clone(&self) -> Self {
69        Self {
70            inner: self.inner.clone(),
71            gauge: self.gauge.clone(),
72        }
73    }
74}
75
76impl<T> WeakSender<T> {
77    /// Upgrades the `WeakSender` to a strong `Sender`, if the underlying
78    /// channel still exists. This allows the `WeakSender` to regain full
79    /// control over the channel, enabling it to send messages again. If the
80    /// underlying channel has been dropped, `None` is returned. Otherwise, it
81    /// returns a `Sender` with the upgraded reference and the associated
82    /// gauge for resource tracking.
83    pub fn upgrade(&self) -> Option<Sender<T>> {
84        self.inner.upgrade().map(|s| Sender {
85            inner: s,
86            gauge: self.gauge.clone(),
87        })
88    }
89}
90
91/// An [`mpsc::Receiver`] with an [`IntGauge`]
92/// counting the number of currently queued items.
93#[derive(Debug)]
94pub struct Receiver<T> {
95    inner: mpsc::Receiver<T>,
96    gauge: IntGauge,
97    total: Option<IntCounter>,
98}
99
100impl<T> Receiver<T> {
101    /// Receives the next value for this receiver.
102    /// Decrements the gauge in case of a successful `recv`.
103    pub async fn recv(&mut self) -> Option<T> {
104        self.inner
105            .recv()
106            .inspect(|opt| {
107                if opt.is_some() {
108                    self.gauge.dec();
109                    if let Some(total_gauge) = &self.total {
110                        total_gauge.inc();
111                    }
112                }
113            })
114            .await
115    }
116
117    /// Attempts to receive the next value for this receiver.
118    /// Decrements the gauge in case of a successful `try_recv`.
119    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
120        self.inner.try_recv().inspect(|_| {
121            self.gauge.dec();
122            if let Some(total_gauge) = &self.total {
123                total_gauge.inc();
124            }
125        })
126    }
127
128    /// Receives a value from the channel in a blocking manner. When a value is
129    /// received, the gauge is decremented to indicate that an item has been
130    /// processed, and the `total_gauge` (if available) is incremented to
131    /// keep track of the total number of received items. Returns the received
132    /// value if successful, or `None` if the channel is closed.
133    pub fn blocking_recv(&mut self) -> Option<T> {
134        self.inner.blocking_recv().inspect(|_| {
135            self.gauge.dec();
136            if let Some(total_gauge) = &self.total {
137                total_gauge.inc();
138            }
139        })
140    }
141
142    /// Closes the receiving half of a channel without dropping it.
143    pub fn close(&mut self) {
144        self.inner.close()
145    }
146
147    /// Polls to receive the next message on this channel.
148    /// Decrements the gauge in case of a successful `poll_recv`.
149    pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
150        match self.inner.poll_recv(cx) {
151            res @ Poll::Ready(Some(_)) => {
152                self.gauge.dec();
153                if let Some(total_gauge) = &self.total {
154                    total_gauge.inc();
155                }
156                res
157            }
158            s => s,
159        }
160    }
161}
162
163impl<T> Unpin for Receiver<T> {}
164
165/// A newtype for an `mpsc::Permit` which allows us to inject gauge accounting
166/// in the case the permit is dropped w/o sending
167pub struct Permit<'a, T> {
168    permit: Option<mpsc::Permit<'a, T>>,
169    gauge_ref: &'a IntGauge,
170}
171
172impl<'a, T> Permit<'a, T> {
173    /// Creates a new `Permit` instance using the provided `mpsc::Permit` and a
174    /// reference to an `IntGauge`. The `Permit` allows sending a message
175    /// into a channel and is used to ensure a slot is available before sending.
176    pub fn new(permit: mpsc::Permit<'a, T>, gauge_ref: &'a IntGauge) -> Permit<'a, T> {
177        Permit {
178            permit: Some(permit),
179            gauge_ref,
180        }
181    }
182
183    /// Sends a value into the channel using the held `Permit`. After sending
184    /// the value, the function uses `std::mem::forget(self)` to skip the
185    /// drop logic of the permit, avoiding double decrement of the gauge or
186    /// other unintended side effects.
187    pub fn send(mut self, value: T) {
188        let sender = self.permit.take().expect("Permit invariant violated!");
189        sender.send(value);
190        // skip the drop logic, see https://github.com/tokio-rs/tokio/blob/a66884a2fb80d1180451706f3c3e006a3fdcb036/tokio/src/sync/mpsc/bounded.rs#L1155-L1163
191        std::mem::forget(self);
192    }
193}
194
195impl<T> Drop for Permit<'_, T> {
196    /// Custom drop logic for the `Permit` to handle cases where the permit is
197    /// dropped without sending a value. This ensures that the occupancy of
198    /// the channel is correctly updated by decrementing the associated gauge
199    /// if the permit was not used. This prevents any overestimation of active
200    /// items in the channel.
201    fn drop(&mut self) {
202        // in the case the permit is dropped without sending, we still want to decrease
203        // the occupancy of the channel
204        if self.permit.is_some() {
205            self.gauge_ref.dec();
206        }
207    }
208}
209
210impl<T> Sender<T> {
211    /// Sends a value, waiting until there is capacity.
212    /// Increments the gauge in case of a successful `send`.
213    pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
214        self.inner
215            .send(value)
216            .inspect_ok(|_| self.gauge.inc())
217            .await
218    }
219
220    /// Completes when the receiver has dropped.
221    pub async fn closed(&self) {
222        self.inner.closed().await
223    }
224
225    /// Attempts to immediately send a message on this `Sender`
226    /// Increments the gauge in case of a successful `try_send`.
227    pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
228        self.inner
229            .try_send(message)
230            // remove this unsightly hack once https://github.com/rust-lang/rust/issues/91345 is resolved
231            .inspect(|_| {
232                self.gauge.inc();
233            })
234    }
235
236    // TODO: facade [`send_timeout`](tokio::mpsc::Sender::send_timeout) under the
237    // tokio feature flag "time" TODO: facade
238    // [`blocking_send`](tokio::mpsc::Sender::blocking_send) under the tokio feature
239    // flag "sync"
240
241    /// Checks if the channel has been closed. This happens when the
242    /// [`Receiver`] is dropped, or when the [`Receiver::close`] method is
243    /// called.
244    pub fn is_closed(&self) -> bool {
245        self.inner.is_closed()
246    }
247
248    /// Waits for channel capacity. Once capacity to send one message is
249    /// available, it is reserved for the caller.
250    /// Increments the gauge in case of a successful `reserve`.
251    pub async fn reserve(&self) -> Result<Permit<'_, T>, SendError<()>> {
252        self.inner
253            .reserve()
254            // remove this unsightly hack once https://github.com/rust-lang/rust/issues/91345 is resolved
255            .map(|val| {
256                val.map(|permit| {
257                    self.gauge.inc();
258                    Permit::new(permit, &self.gauge)
259                })
260            })
261            .await
262    }
263
264    /// Tries to acquire a slot in the channel without waiting for the slot to
265    /// become available.
266    /// Increments the gauge in case of a successful `try_reserve`.
267    pub fn try_reserve(&self) -> Result<Permit<'_, T>, TrySendError<()>> {
268        self.inner.try_reserve().map(|val| {
269            // remove this unsightly hack once https://github.com/rust-lang/rust/issues/91345 is resolved
270            self.gauge.inc();
271            Permit::new(val, &self.gauge)
272        })
273    }
274
275    // TODO: consider exposing the _owned methods
276
277    // Note: not exposing `same_channel`, as it is hard to implement with callers
278    // able to break the coupling between channel and gauge using `gauge`.
279
280    /// Returns the current capacity of the channel.
281    pub fn capacity(&self) -> usize {
282        self.inner.capacity()
283    }
284
285    // We're voluntarily not putting WeakSender under a facade.
286
287    /// Returns a reference to the underlying gauge.
288    pub fn gauge(&self) -> &IntGauge {
289        &self.gauge
290    }
291}
292
293////////////////////////////////
294// Stream API Wrappers!
295////////////////////////////////
296
297/// A wrapper around [`crate::metered_channel::Receiver`] that implements
298/// [`Stream`].
299#[derive(Debug)]
300pub struct ReceiverStream<T> {
301    inner: Receiver<T>,
302}
303
304impl<T> ReceiverStream<T> {
305    /// Create a new `ReceiverStream`.
306    pub fn new(recv: Receiver<T>) -> Self {
307        Self { inner: recv }
308    }
309
310    /// Get back the inner `Receiver`.
311    pub fn into_inner(self) -> Receiver<T> {
312        self.inner
313    }
314
315    /// Closes the receiving half of a channel without dropping it.
316    pub fn close(&mut self) {
317        self.inner.close()
318    }
319}
320
321impl<T> Stream for ReceiverStream<T> {
322    type Item = T;
323    /// Polls the inner `Receiver` for the next item, enabling the
324    /// `ReceiverStream` to yield values in a stream-like manner.
325    fn poll_next(
326        mut self: std::pin::Pin<&mut Self>,
327        cx: &mut Context<'_>,
328    ) -> Poll<Option<Self::Item>> {
329        self.inner.poll_recv(cx)
330    }
331}
332
333impl<T> AsRef<Receiver<T>> for ReceiverStream<T> {
334    /// Gets a reference to the inner `Receiver`.
335    fn as_ref(&self) -> &Receiver<T> {
336        &self.inner
337    }
338}
339
340impl<T> AsMut<Receiver<T>> for ReceiverStream<T> {
341    /// Gets a mutable reference to the inner `Receiver`.
342    fn as_mut(&mut self) -> &mut Receiver<T> {
343        &mut self.inner
344    }
345}
346
347impl<T> From<Receiver<T>> for ReceiverStream<T> {
348    /// Converts a `Receiver` into a `ReceiverStream`.
349    fn from(recv: Receiver<T>) -> Self {
350        Self::new(recv)
351    }
352}
353
354// TODO: facade PollSender
355// TODO: add prom metrics reporting for gauge and migrate all existing use
356// cases.
357
358////////////////////////////////////////////////////////////////
359// Constructor
360////////////////////////////////////////////////////////////////
361
362/// Similar to `mpsc::channel`, `channel` creates a pair of `Sender` and
363/// `Receiver` Deprecated: use `monitored_mpsc::channel` instead.
364#[track_caller]
365pub fn channel<T>(size: usize, gauge: &IntGauge) -> (Sender<T>, Receiver<T>) {
366    gauge.set(0);
367    let (sender, receiver) = mpsc::channel(size);
368    (
369        Sender {
370            inner: sender,
371            gauge: gauge.clone(),
372        },
373        Receiver {
374            inner: receiver,
375            gauge: gauge.clone(),
376            total: None,
377        },
378    )
379}
380
381/// Defines an asynchronous method `with_permit` for working with a permit to
382/// send a message.
383#[async_trait]
384pub trait WithPermit<T> {
385    async fn with_permit<F: Future + Send>(&self, f: F) -> Option<(Permit<T>, F::Output)>;
386}
387
388#[async_trait]
389impl<T: Send> WithPermit<T> for Sender<T> {
390    /// Asynchronously reserves a permit for sending a message and then executes
391    /// the given future (`f`). If a permit can be successfully reserved, it
392    /// returns a tuple containing the `Permit` and the result of the future.
393    /// If the permit reservation fails, `None` is returned, indicating that no
394    /// slot is available. This method ensures that the future only proceeds
395    /// if the channel has available capacity.
396    async fn with_permit<F: Future + Send>(&self, f: F) -> Option<(Permit<T>, F::Output)> {
397        let permit = self.reserve().await.ok()?;
398        Some((permit, f.await))
399    }
400}