iota_metrics/metered_channel.rs
1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5#![allow(dead_code)]
6
7use std::{
8 future::Future,
9 task::{Context, Poll},
10};
11
12use async_trait::async_trait;
// TODO: complete the tests. This module is a facade over most of
// `tokio::sync::mpsc::{Sender, Receiver}`; without tests it will be fragile to
// maintain.
15use futures::{FutureExt, Stream, TryFutureExt};
16use prometheus::{IntCounter, IntGauge};
17use tokio::sync::mpsc::{
18 self,
19 error::{SendError, TryRecvError, TrySendError},
20};
21
22#[cfg(test)]
23#[path = "tests/metered_channel_tests.rs"]
24mod metered_channel_tests;
25
26/// An [`mpsc::Sender`] with an [`IntGauge`]
27/// counting the number of currently queued items.
28#[derive(Debug)]
29pub struct Sender<T> {
30 inner: mpsc::Sender<T>,
31 gauge: IntGauge,
32}
33
34impl<T> Clone for Sender<T> {
35 fn clone(&self) -> Self {
36 Self {
37 inner: self.inner.clone(),
38 gauge: self.gauge.clone(),
39 }
40 }
41}
42
43impl<T> Sender<T> {
44 /// Downgrades the current `Sender` to a `WeakSender`, which holds a weak
45 /// reference to the underlying channel. This allows the `Sender` to be
46 /// safely dropped without affecting the channel while maintaining
47 /// the ability to upgrade back to a strong reference later if needed.
48 /// Returns a `WeakSender` that holds the weak reference and the
49 /// associated gauge for resource tracking.
50 pub fn downgrade(&self) -> WeakSender<T> {
51 let sender = self.inner.downgrade();
52 WeakSender {
53 inner: sender,
54 gauge: self.gauge.clone(),
55 }
56 }
57}
58
59/// An [`mpsc::WeakSender`] with an [`IntGauge`]
60/// counting the number of currently queued items.
61#[derive(Debug)]
62pub struct WeakSender<T> {
63 inner: mpsc::WeakSender<T>,
64 gauge: IntGauge,
65}
66
67impl<T> Clone for WeakSender<T> {
68 fn clone(&self) -> Self {
69 Self {
70 inner: self.inner.clone(),
71 gauge: self.gauge.clone(),
72 }
73 }
74}
75
76impl<T> WeakSender<T> {
77 /// Upgrades the `WeakSender` to a strong `Sender`, if the underlying
78 /// channel still exists. This allows the `WeakSender` to regain full
79 /// control over the channel, enabling it to send messages again. If the
80 /// underlying channel has been dropped, `None` is returned. Otherwise, it
81 /// returns a `Sender` with the upgraded reference and the associated
82 /// gauge for resource tracking.
83 pub fn upgrade(&self) -> Option<Sender<T>> {
84 self.inner.upgrade().map(|s| Sender {
85 inner: s,
86 gauge: self.gauge.clone(),
87 })
88 }
89}
90
91/// An [`mpsc::Receiver`] with an [`IntGauge`]
92/// counting the number of currently queued items.
93#[derive(Debug)]
94pub struct Receiver<T> {
95 inner: mpsc::Receiver<T>,
96 gauge: IntGauge,
97 total: Option<IntCounter>,
98}
99
100impl<T> Receiver<T> {
101 /// Receives the next value for this receiver.
102 /// Decrements the gauge in case of a successful `recv`.
103 pub async fn recv(&mut self) -> Option<T> {
104 self.inner
105 .recv()
106 .inspect(|opt| {
107 if opt.is_some() {
108 self.gauge.dec();
109 if let Some(total_gauge) = &self.total {
110 total_gauge.inc();
111 }
112 }
113 })
114 .await
115 }
116
117 /// Attempts to receive the next value for this receiver.
118 /// Decrements the gauge in case of a successful `try_recv`.
119 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
120 self.inner.try_recv().inspect(|_| {
121 self.gauge.dec();
122 if let Some(total_gauge) = &self.total {
123 total_gauge.inc();
124 }
125 })
126 }
127
128 /// Receives a value from the channel in a blocking manner. When a value is
129 /// received, the gauge is decremented to indicate that an item has been
130 /// processed, and the `total_gauge` (if available) is incremented to
131 /// keep track of the total number of received items. Returns the received
132 /// value if successful, or `None` if the channel is closed.
133 pub fn blocking_recv(&mut self) -> Option<T> {
134 self.inner.blocking_recv().inspect(|_| {
135 self.gauge.dec();
136 if let Some(total_gauge) = &self.total {
137 total_gauge.inc();
138 }
139 })
140 }
141
142 /// Closes the receiving half of a channel without dropping it.
143 pub fn close(&mut self) {
144 self.inner.close()
145 }
146
147 /// Polls to receive the next message on this channel.
148 /// Decrements the gauge in case of a successful `poll_recv`.
149 pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
150 match self.inner.poll_recv(cx) {
151 res @ Poll::Ready(Some(_)) => {
152 self.gauge.dec();
153 if let Some(total_gauge) = &self.total {
154 total_gauge.inc();
155 }
156 res
157 }
158 s => s,
159 }
160 }
161}
162
163impl<T> Unpin for Receiver<T> {}
164
165/// A newtype for an `mpsc::Permit` which allows us to inject gauge accounting
166/// in the case the permit is dropped w/o sending
167pub struct Permit<'a, T> {
168 permit: Option<mpsc::Permit<'a, T>>,
169 gauge_ref: &'a IntGauge,
170}
171
172impl<'a, T> Permit<'a, T> {
173 /// Creates a new `Permit` instance using the provided `mpsc::Permit` and a
174 /// reference to an `IntGauge`. The `Permit` allows sending a message
175 /// into a channel and is used to ensure a slot is available before sending.
176 pub fn new(permit: mpsc::Permit<'a, T>, gauge_ref: &'a IntGauge) -> Permit<'a, T> {
177 Permit {
178 permit: Some(permit),
179 gauge_ref,
180 }
181 }
182
183 /// Sends a value into the channel using the held `Permit`. After sending
184 /// the value, the function uses `std::mem::forget(self)` to skip the
185 /// drop logic of the permit, avoiding double decrement of the gauge or
186 /// other unintended side effects.
187 pub fn send(mut self, value: T) {
188 let sender = self.permit.take().expect("Permit invariant violated!");
189 sender.send(value);
190 // skip the drop logic, see https://github.com/tokio-rs/tokio/blob/a66884a2fb80d1180451706f3c3e006a3fdcb036/tokio/src/sync/mpsc/bounded.rs#L1155-L1163
191 std::mem::forget(self);
192 }
193}
194
195impl<T> Drop for Permit<'_, T> {
196 /// Custom drop logic for the `Permit` to handle cases where the permit is
197 /// dropped without sending a value. This ensures that the occupancy of
198 /// the channel is correctly updated by decrementing the associated gauge
199 /// if the permit was not used. This prevents any overestimation of active
200 /// items in the channel.
201 fn drop(&mut self) {
202 // in the case the permit is dropped without sending, we still want to decrease
203 // the occupancy of the channel
204 if self.permit.is_some() {
205 self.gauge_ref.dec();
206 }
207 }
208}
209
210impl<T> Sender<T> {
211 /// Sends a value, waiting until there is capacity.
212 /// Increments the gauge in case of a successful `send`.
213 pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
214 self.inner
215 .send(value)
216 .inspect_ok(|_| self.gauge.inc())
217 .await
218 }
219
220 /// Completes when the receiver has dropped.
221 pub async fn closed(&self) {
222 self.inner.closed().await
223 }
224
225 /// Attempts to immediately send a message on this `Sender`
226 /// Increments the gauge in case of a successful `try_send`.
227 pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
228 self.inner
229 .try_send(message)
230 // remove this unsightly hack once https://github.com/rust-lang/rust/issues/91345 is resolved
231 .inspect(|_| {
232 self.gauge.inc();
233 })
234 }
235
236 // TODO: facade [`send_timeout`](tokio::mpsc::Sender::send_timeout) under the
237 // tokio feature flag "time" TODO: facade
238 // [`blocking_send`](tokio::mpsc::Sender::blocking_send) under the tokio feature
239 // flag "sync"
240
241 /// Checks if the channel has been closed. This happens when the
242 /// [`Receiver`] is dropped, or when the [`Receiver::close`] method is
243 /// called.
244 pub fn is_closed(&self) -> bool {
245 self.inner.is_closed()
246 }
247
248 /// Waits for channel capacity. Once capacity to send one message is
249 /// available, it is reserved for the caller.
250 /// Increments the gauge in case of a successful `reserve`.
251 pub async fn reserve(&self) -> Result<Permit<'_, T>, SendError<()>> {
252 self.inner
253 .reserve()
254 // remove this unsightly hack once https://github.com/rust-lang/rust/issues/91345 is resolved
255 .map(|val| {
256 val.map(|permit| {
257 self.gauge.inc();
258 Permit::new(permit, &self.gauge)
259 })
260 })
261 .await
262 }
263
264 /// Tries to acquire a slot in the channel without waiting for the slot to
265 /// become available.
266 /// Increments the gauge in case of a successful `try_reserve`.
267 pub fn try_reserve(&self) -> Result<Permit<'_, T>, TrySendError<()>> {
268 self.inner.try_reserve().map(|val| {
269 // remove this unsightly hack once https://github.com/rust-lang/rust/issues/91345 is resolved
270 self.gauge.inc();
271 Permit::new(val, &self.gauge)
272 })
273 }
274
275 // TODO: consider exposing the _owned methods
276
277 // Note: not exposing `same_channel`, as it is hard to implement with callers
278 // able to break the coupling between channel and gauge using `gauge`.
279
280 /// Returns the current capacity of the channel.
281 pub fn capacity(&self) -> usize {
282 self.inner.capacity()
283 }
284
285 // We're voluntarily not putting WeakSender under a facade.
286
287 /// Returns a reference to the underlying gauge.
288 pub fn gauge(&self) -> &IntGauge {
289 &self.gauge
290 }
291}
292
293////////////////////////////////
294// Stream API Wrappers!
295////////////////////////////////
296
297/// A wrapper around [`crate::metered_channel::Receiver`] that implements
298/// [`Stream`].
299#[derive(Debug)]
300pub struct ReceiverStream<T> {
301 inner: Receiver<T>,
302}
303
304impl<T> ReceiverStream<T> {
305 /// Create a new `ReceiverStream`.
306 pub fn new(recv: Receiver<T>) -> Self {
307 Self { inner: recv }
308 }
309
310 /// Get back the inner `Receiver`.
311 pub fn into_inner(self) -> Receiver<T> {
312 self.inner
313 }
314
315 /// Closes the receiving half of a channel without dropping it.
316 pub fn close(&mut self) {
317 self.inner.close()
318 }
319}
320
321impl<T> Stream for ReceiverStream<T> {
322 type Item = T;
323 /// Polls the inner `Receiver` for the next item, enabling the
324 /// `ReceiverStream` to yield values in a stream-like manner.
325 fn poll_next(
326 mut self: std::pin::Pin<&mut Self>,
327 cx: &mut Context<'_>,
328 ) -> Poll<Option<Self::Item>> {
329 self.inner.poll_recv(cx)
330 }
331}
332
333impl<T> AsRef<Receiver<T>> for ReceiverStream<T> {
334 /// Gets a reference to the inner `Receiver`.
335 fn as_ref(&self) -> &Receiver<T> {
336 &self.inner
337 }
338}
339
340impl<T> AsMut<Receiver<T>> for ReceiverStream<T> {
341 /// Gets a mutable reference to the inner `Receiver`.
342 fn as_mut(&mut self) -> &mut Receiver<T> {
343 &mut self.inner
344 }
345}
346
347impl<T> From<Receiver<T>> for ReceiverStream<T> {
348 /// Converts a `Receiver` into a `ReceiverStream`.
349 fn from(recv: Receiver<T>) -> Self {
350 Self::new(recv)
351 }
352}
353
354// TODO: facade PollSender
355// TODO: add prom metrics reporting for gauge and migrate all existing use
356// cases.
357
358////////////////////////////////////////////////////////////////
359// Constructor
360////////////////////////////////////////////////////////////////
361
362/// Similar to `mpsc::channel`, `channel` creates a pair of `Sender` and
363/// `Receiver` Deprecated: use `monitored_mpsc::channel` instead.
364#[track_caller]
365pub fn channel<T>(size: usize, gauge: &IntGauge) -> (Sender<T>, Receiver<T>) {
366 gauge.set(0);
367 let (sender, receiver) = mpsc::channel(size);
368 (
369 Sender {
370 inner: sender,
371 gauge: gauge.clone(),
372 },
373 Receiver {
374 inner: receiver,
375 gauge: gauge.clone(),
376 total: None,
377 },
378 )
379}
380
381/// Defines an asynchronous method `with_permit` for working with a permit to
382/// send a message.
383#[async_trait]
384pub trait WithPermit<T> {
385 async fn with_permit<F: Future + Send>(&self, f: F) -> Option<(Permit<T>, F::Output)>;
386}
387
388#[async_trait]
389impl<T: Send> WithPermit<T> for Sender<T> {
390 /// Asynchronously reserves a permit for sending a message and then executes
391 /// the given future (`f`). If a permit can be successfully reserved, it
392 /// returns a tuple containing the `Permit` and the result of the future.
393 /// If the permit reservation fails, `None` is returned, indicating that no
394 /// slot is available. This method ensures that the future only proceeds
395 /// if the channel has available capacity.
396 async fn with_permit<F: Future + Send>(&self, f: F) -> Option<(Permit<T>, F::Output)> {
397 let permit = self.reserve().await.ok()?;
398 Some((permit, f.await))
399 }
400}