// iota_common/sync/notify_read.rs

// Copyright (c) Mysten Labs, Inc.
// Modifications Copyright (c) 2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    collections::{HashMap, hash_map::DefaultHasher},
7    error::Error,
8    future::Future,
9    hash::{Hash, Hasher},
10    mem,
11    pin::Pin,
12    sync::atomic::{AtomicUsize, Ordering},
13    task::{Context, Poll},
14};
15
16use futures::future::{Either, join_all};
17use parking_lot::{Mutex, MutexGuard};
18use tokio::sync::oneshot;
19
20type Registrations<V> = Vec<oneshot::Sender<V>>;
21
22pub struct NotifyRead<K, V> {
23    pending: Vec<Mutex<HashMap<K, Registrations<V>>>>,
24    count_pending: AtomicUsize,
25}
26
27impl<K: Eq + Hash + Clone, V: Clone> NotifyRead<K, V> {
28    pub fn new() -> Self {
29        let pending = (0..255).map(|_| Default::default()).collect();
30        let count_pending = Default::default();
31        Self {
32            pending,
33            count_pending,
34        }
35    }
36
37    /// Asynchronously notifies waiters and return number of remaining pending
38    /// registration
39    pub fn notify(&self, key: &K, value: &V) -> usize {
40        let registrations = self.pending(key).remove(key);
41        let Some(registrations) = registrations else {
42            return self.count_pending.load(Ordering::Relaxed);
43        };
44        let rem = self
45            .count_pending
46            .fetch_sub(registrations.len(), Ordering::Relaxed);
47        for registration in registrations {
48            registration.send(value.clone()).ok();
49        }
50        rem
51    }
52
53    pub fn register_one(&self, key: &K) -> Registration<K, V> {
54        self.count_pending.fetch_add(1, Ordering::Relaxed);
55        let (sender, receiver) = oneshot::channel();
56        self.register(key, sender);
57        Registration {
58            this: self,
59            registration: Some((key.clone(), receiver)),
60        }
61    }
62
63    pub fn register_all(&self, keys: &[K]) -> Vec<Registration<K, V>> {
64        self.count_pending.fetch_add(keys.len(), Ordering::Relaxed);
65        let mut registrations = vec![];
66        for key in keys.iter() {
67            let (sender, receiver) = oneshot::channel();
68            self.register(key, sender);
69            let registration = Registration {
70                this: self,
71                registration: Some((key.clone(), receiver)),
72            };
73            registrations.push(registration);
74        }
75        registrations
76    }
77
78    fn register(&self, key: &K, sender: oneshot::Sender<V>) {
79        self.pending(key)
80            .entry(key.clone())
81            .or_default()
82            .push(sender);
83    }
84
85    fn pending(&self, key: &K) -> MutexGuard<HashMap<K, Registrations<V>>> {
86        let mut state = DefaultHasher::new();
87        key.hash(&mut state);
88        let hash = state.finish();
89        let pending = self
90            .pending
91            .get((hash % self.pending.len() as u64) as usize)
92            .unwrap();
93        pending.lock()
94    }
95
96    pub fn num_pending(&self) -> usize {
97        self.count_pending.load(Ordering::Relaxed)
98    }
99
100    fn cleanup(&self, key: &K) {
101        let mut pending = self.pending(key);
102        // it is possible that registration was fulfilled before we get here
103        let Some(registrations) = pending.get_mut(key) else {
104            return;
105        };
106        let mut count_deleted = 0usize;
107        registrations.retain(|s| {
108            let delete = s.is_closed();
109            if delete {
110                count_deleted += 1;
111            }
112            !delete
113        });
114        self.count_pending
115            .fetch_sub(count_deleted, Ordering::Relaxed);
116        if registrations.is_empty() {
117            pending.remove(key);
118        }
119    }
120}
121
122impl<K: Eq + Hash + Clone + Unpin, V: Clone + Unpin> NotifyRead<K, V> {
123    pub async fn read<E: Error>(
124        &self,
125        keys: &[K],
126        fetch: impl FnOnce(&[K]) -> Result<Vec<Option<V>>, E>,
127    ) -> Result<Vec<V>, E> {
128        let registrations = self.register_all(keys);
129
130        let results = fetch(keys)?;
131
132        let results = results
133            .into_iter()
134            .zip(registrations)
135            .map(|(a, r)| match a {
136                // Note that Some() clause also drops registration that is already fulfilled
137                Some(ready) => Either::Left(futures::future::ready(ready)),
138                None => Either::Right(r),
139            });
140
141        Ok(join_all(results).await)
142    }
143}
144
145/// Registration resolves to the value but also provides safe cancellation
146/// When Registration is dropped before it is resolved, we de-register from the
147/// pending list
148pub struct Registration<'a, K: Eq + Hash + Clone, V: Clone> {
149    this: &'a NotifyRead<K, V>,
150    registration: Option<(K, oneshot::Receiver<V>)>,
151}
152
153impl<K: Eq + Hash + Clone + Unpin, V: Clone + Unpin> Future for Registration<'_, K, V> {
154    type Output = V;
155
156    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
157        let receiver = self
158            .registration
159            .as_mut()
160            .map(|(_key, receiver)| receiver)
161            .expect("poll can not be called after drop");
162        let poll = Pin::new(receiver).poll(cx);
163        if poll.is_ready() {
164            // When polling complete we no longer need to cancel
165            self.registration.take();
166        }
167        poll.map(|r| r.expect("Sender never drops when registration is pending"))
168    }
169}
170
171impl<K: Eq + Hash + Clone, V: Clone> Drop for Registration<'_, K, V> {
172    fn drop(&mut self) {
173        if let Some((key, receiver)) = self.registration.take() {
174            mem::drop(receiver);
175            // Receiver is dropped before cleanup
176            self.this.cleanup(&key)
177        }
178    }
179}
180impl<K: Eq + Hash + Clone, V: Clone> Default for NotifyRead<K, V> {
181    fn default() -> Self {
182        Self::new()
183    }
184}
185
#[cfg(test)]
mod tests {
    use futures::future::join_all;

    use super::*;

    #[tokio::test]
    pub async fn test_notify_read() {
        let notify_read = NotifyRead::<u64, u64>::new();

        let mut waiters = notify_read.register_all(&[1, 2, 3]);
        assert_eq!(notify_read.num_pending(), 3);

        // Dropping the registration for key 3 must deregister it.
        drop(waiters.pop());
        assert_eq!(notify_read.num_pending(), 2);

        notify_read.notify(&2, &2);
        notify_read.notify(&1, &1);

        let values = join_all(waiters).await;
        assert_eq!(notify_read.num_pending(), 0);
        assert_eq!(values, vec![1, 2]);

        // Every shard must be empty once all registrations are resolved.
        assert!(notify_read.pending.iter().all(|shard| shard.lock().is_empty()));
    }
}