iota_common/sync/
notify_read.rs1use std::{
6 collections::{HashMap, hash_map::DefaultHasher},
7 error::Error,
8 future::Future,
9 hash::{Hash, Hasher},
10 mem,
11 pin::Pin,
12 sync::atomic::{AtomicUsize, Ordering},
13 task::{Context, Poll},
14};
15
16use futures::future::{Either, join_all};
17use parking_lot::{Mutex, MutexGuard};
18use tokio::sync::oneshot;
19
20type Registrations<V> = Vec<oneshot::Sender<V>>;
21
22pub struct NotifyRead<K, V> {
23 pending: Vec<Mutex<HashMap<K, Registrations<V>>>>,
24 count_pending: AtomicUsize,
25}
26
27impl<K: Eq + Hash + Clone, V: Clone> NotifyRead<K, V> {
28 pub fn new() -> Self {
29 let pending = (0..255).map(|_| Default::default()).collect();
30 let count_pending = Default::default();
31 Self {
32 pending,
33 count_pending,
34 }
35 }
36
37 pub fn notify(&self, key: &K, value: &V) -> usize {
40 let registrations = self.pending(key).remove(key);
41 let Some(registrations) = registrations else {
42 return self.count_pending.load(Ordering::Relaxed);
43 };
44 let rem = self
45 .count_pending
46 .fetch_sub(registrations.len(), Ordering::Relaxed);
47 for registration in registrations {
48 registration.send(value.clone()).ok();
49 }
50 rem
51 }
52
53 pub fn register_one(&self, key: &K) -> Registration<K, V> {
54 self.count_pending.fetch_add(1, Ordering::Relaxed);
55 let (sender, receiver) = oneshot::channel();
56 self.register(key, sender);
57 Registration {
58 this: self,
59 registration: Some((key.clone(), receiver)),
60 }
61 }
62
63 pub fn register_all(&self, keys: &[K]) -> Vec<Registration<K, V>> {
64 self.count_pending.fetch_add(keys.len(), Ordering::Relaxed);
65 let mut registrations = vec![];
66 for key in keys.iter() {
67 let (sender, receiver) = oneshot::channel();
68 self.register(key, sender);
69 let registration = Registration {
70 this: self,
71 registration: Some((key.clone(), receiver)),
72 };
73 registrations.push(registration);
74 }
75 registrations
76 }
77
78 fn register(&self, key: &K, sender: oneshot::Sender<V>) {
79 self.pending(key)
80 .entry(key.clone())
81 .or_default()
82 .push(sender);
83 }
84
85 fn pending(&self, key: &K) -> MutexGuard<HashMap<K, Registrations<V>>> {
86 let mut state = DefaultHasher::new();
87 key.hash(&mut state);
88 let hash = state.finish();
89 let pending = self
90 .pending
91 .get((hash % self.pending.len() as u64) as usize)
92 .unwrap();
93 pending.lock()
94 }
95
96 pub fn num_pending(&self) -> usize {
97 self.count_pending.load(Ordering::Relaxed)
98 }
99
100 fn cleanup(&self, key: &K) {
101 let mut pending = self.pending(key);
102 let Some(registrations) = pending.get_mut(key) else {
104 return;
105 };
106 let mut count_deleted = 0usize;
107 registrations.retain(|s| {
108 let delete = s.is_closed();
109 if delete {
110 count_deleted += 1;
111 }
112 !delete
113 });
114 self.count_pending
115 .fetch_sub(count_deleted, Ordering::Relaxed);
116 if registrations.is_empty() {
117 pending.remove(key);
118 }
119 }
120}
121
122impl<K: Eq + Hash + Clone + Unpin, V: Clone + Unpin> NotifyRead<K, V> {
123 pub async fn read<E: Error>(
124 &self,
125 keys: &[K],
126 fetch: impl FnOnce(&[K]) -> Result<Vec<Option<V>>, E>,
127 ) -> Result<Vec<V>, E> {
128 let registrations = self.register_all(keys);
129
130 let results = fetch(keys)?;
131
132 let results = results
133 .into_iter()
134 .zip(registrations)
135 .map(|(a, r)| match a {
136 Some(ready) => Either::Left(futures::future::ready(ready)),
138 None => Either::Right(r),
139 });
140
141 Ok(join_all(results).await)
142 }
143}
144
145pub struct Registration<'a, K: Eq + Hash + Clone, V: Clone> {
149 this: &'a NotifyRead<K, V>,
150 registration: Option<(K, oneshot::Receiver<V>)>,
151}
152
153impl<K: Eq + Hash + Clone + Unpin, V: Clone + Unpin> Future for Registration<'_, K, V> {
154 type Output = V;
155
156 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
157 let receiver = self
158 .registration
159 .as_mut()
160 .map(|(_key, receiver)| receiver)
161 .expect("poll can not be called after drop");
162 let poll = Pin::new(receiver).poll(cx);
163 if poll.is_ready() {
164 self.registration.take();
166 }
167 poll.map(|r| r.expect("Sender never drops when registration is pending"))
168 }
169}
170
171impl<K: Eq + Hash + Clone, V: Clone> Drop for Registration<'_, K, V> {
172 fn drop(&mut self) {
173 if let Some((key, receiver)) = self.registration.take() {
174 mem::drop(receiver);
175 self.this.cleanup(&key)
177 }
178 }
179}
180impl<K: Eq + Hash + Clone, V: Clone> Default for NotifyRead<K, V> {
181 fn default() -> Self {
182 Self::new()
183 }
184}
185
#[cfg(test)]
mod tests {
    use futures::future::join_all;

    use super::*;

    #[tokio::test]
    pub async fn test_notify_read() {
        let notify_read = NotifyRead::<u64, u64>::new();
        let mut registrations = notify_read.register_all(&[1, 2, 3]);
        assert_eq!(3, notify_read.count_pending.load(Ordering::Relaxed));
        // Dropping a registration must unregister it.
        registrations.pop();
        assert_eq!(2, notify_read.count_pending.load(Ordering::Relaxed));
        notify_read.notify(&2, &2);
        notify_read.notify(&1, &1);
        let reads = join_all(registrations).await;
        assert_eq!(0, notify_read.count_pending.load(Ordering::Relaxed));
        assert_eq!(reads, vec![1, 2]);
        for pending in &notify_read.pending {
            assert!(pending.lock().is_empty());
        }
    }

    #[test]
    fn test_dropped_registration_is_cleaned_up() {
        let notify_read = NotifyRead::<u64, u64>::new();
        let registration = notify_read.register_one(&7);
        assert_eq!(1, notify_read.num_pending());
        drop(registration);
        assert_eq!(0, notify_read.num_pending());
        for pending in &notify_read.pending {
            assert!(pending.lock().is_empty());
        }
        // Nobody is waiting any more, so notifying is a no-op.
        assert_eq!(0, notify_read.notify(&7, &70));
    }

    #[tokio::test]
    async fn test_read_returns_fetched_values() {
        let notify_read = NotifyRead::<u64, u64>::new();
        let values = notify_read
            .read(&[1, 2], |keys| {
                Ok::<_, std::io::Error>(keys.iter().map(|k| Some(k * 10)).collect())
            })
            .await
            .unwrap();
        assert_eq!(values, vec![10, 20]);
        // Registrations for already-fetched values are dropped and cleaned up.
        assert_eq!(0, notify_read.num_pending());
    }
}
209}