1use std::{
6 collections::{
7 HashMap,
8 hash_map::{DefaultHasher, RandomState},
9 },
10 error::Error,
11 fmt,
12 hash::{BuildHasher, Hash, Hasher},
13 sync::{
14 Arc,
15 atomic::{AtomicBool, AtomicUsize, Ordering},
16 },
17 time::Duration,
18};
19
20use async_trait::async_trait;
21use iota_metrics::spawn_monitored_task;
22use tokio::{
23 sync::{
24 Mutex, OwnedMutexGuard, OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock, TryLockError,
25 },
26 task::JoinHandle,
27 time::Instant,
28};
29use tracing::info;
30
31#[async_trait]
32pub trait Lock: Send + Sync + Default {
33 type Guard;
34 type ReadGuard;
35 async fn lock_owned(self: Arc<Self>) -> Self::Guard;
36 fn try_lock_owned(self: Arc<Self>) -> Result<Self::Guard, TryLockError>;
37 async fn read_lock_owned(self: Arc<Self>) -> Self::ReadGuard;
38}
39
40#[async_trait::async_trait]
41impl Lock for Mutex<()> {
42 type Guard = OwnedMutexGuard<()>;
43 type ReadGuard = Self::Guard;
44
45 async fn lock_owned(self: Arc<Self>) -> Self::Guard {
46 self.lock_owned().await
47 }
48
49 fn try_lock_owned(self: Arc<Self>) -> Result<Self::Guard, TryLockError> {
50 self.try_lock_owned()
51 }
52
53 async fn read_lock_owned(self: Arc<Self>) -> Self::ReadGuard {
54 self.lock_owned().await
55 }
56}
57
58#[async_trait::async_trait]
59impl Lock for RwLock<()> {
60 type Guard = OwnedRwLockWriteGuard<()>;
61 type ReadGuard = OwnedRwLockReadGuard<()>;
62
63 async fn lock_owned(self: Arc<Self>) -> Self::Guard {
64 self.write_owned().await
65 }
66
67 fn try_lock_owned(self: Arc<Self>) -> Result<Self::Guard, TryLockError> {
68 self.try_write_owned()
69 }
70
71 async fn read_lock_owned(self: Arc<Self>) -> Self::ReadGuard {
72 self.read_owned().await
73 }
74}
75
76type InnerLockTable<K, L> = HashMap<K, Arc<L>>;
77pub struct LockTable<K: Hash, L: Lock> {
80 random_state: RandomState,
81 lock_table: Arc<Vec<RwLock<InnerLockTable<K, L>>>>,
82 _k: std::marker::PhantomData<K>,
83 _cleaner: JoinHandle<()>,
84 stop: Arc<AtomicBool>,
85 size: Arc<AtomicUsize>,
86}
87
88pub type MutexTable<K> = LockTable<K, Mutex<()>>;
89pub type RwLockTable<K> = LockTable<K, RwLock<()>>;
90
/// Reason a non-blocking acquisition from a lock table failed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TryAcquireLockError {
    /// The shard map that holds the key could not be locked immediately.
    LockTableLocked,
    /// The key's own lock is currently held.
    LockEntryLocked,
}

impl fmt::Display for TryAcquireLockError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Keep the historical prefix, but say which lock was contended.
        let reason = match self {
            Self::LockTableLocked => "lock table shard is locked",
            Self::LockEntryLocked => "lock entry is locked",
        };
        write!(f, "operation would block: {reason}")
    }
}

impl Error for TryAcquireLockError {}
104pub type MutexGuard = tokio::sync::OwnedMutexGuard<()>;
105pub type RwLockGuard = tokio::sync::OwnedRwLockReadGuard<()>;
106
107impl<K: Hash + Eq + Send + Sync + 'static, L: Lock + 'static> LockTable<K, L> {
108 pub fn new_with_cleanup(
109 num_shards: usize,
110 cleanup_period: Duration,
111 cleanup_initial_delay: Duration,
112 cleanup_entries_threshold: usize,
113 ) -> Self {
114 let num_shards = if cfg!(msim) { 4 } else { num_shards };
115
116 let lock_table: Arc<Vec<RwLock<InnerLockTable<K, L>>>> = Arc::new(
117 (0..num_shards)
118 .map(|_| RwLock::new(HashMap::new()))
119 .collect(),
120 );
121 let cloned = lock_table.clone();
122 let stop = Arc::new(AtomicBool::new(false));
123 let stop_cloned = stop.clone();
124 let size: Arc<AtomicUsize> = Arc::new(AtomicUsize::new(0));
125 let size_cloned = size.clone();
126 Self {
127 random_state: RandomState::new(),
128 lock_table,
129 _k: std::marker::PhantomData {},
130 _cleaner: spawn_monitored_task!(async move {
131 tokio::time::sleep(cleanup_initial_delay).await;
132 let mut previous_cleanup_instant = Instant::now();
133 while !stop_cloned.load(Ordering::SeqCst) {
134 if size_cloned.load(Ordering::SeqCst) >= cleanup_entries_threshold
135 || previous_cleanup_instant.elapsed() >= cleanup_period
136 {
137 let num_removed = Self::cleanup(cloned.clone());
138 size_cloned.fetch_sub(num_removed, Ordering::SeqCst);
139 previous_cleanup_instant = Instant::now();
140 }
141 tokio::time::sleep(Duration::from_secs(1)).await;
142 }
143 info!("Stopping mutex table cleanup!");
144 }),
145 stop,
146 size,
147 }
148 }
149
150 pub fn new(num_shards: usize) -> Self {
151 Self::new_with_cleanup(
152 num_shards,
153 Duration::from_secs(10),
154 Duration::from_secs(10),
155 10_000,
156 )
157 }
158
159 pub fn size(&self) -> usize {
160 self.size.load(Ordering::SeqCst)
161 }
162
163 pub fn cleanup(lock_table: Arc<Vec<RwLock<InnerLockTable<K, L>>>>) -> usize {
164 let mut num_removed: usize = 0;
165 for shard in lock_table.iter() {
166 let map = shard.try_write();
167 if map.is_err() {
168 continue;
169 }
170 map.unwrap().retain(|_k, v| {
171 if Arc::strong_count(v) == 1 {
177 num_removed += 1;
178 false
179 } else {
180 true
181 }
182 });
183 }
184 num_removed
185 }
186
187 fn get_lock_idx(&self, key: &K) -> usize {
188 let mut hasher = if !cfg!(test) {
189 self.random_state.build_hasher()
190 } else {
191 DefaultHasher::new()
193 };
194
195 key.hash(&mut hasher);
196 let hash: usize = hasher.finish().try_into().unwrap();
198 hash % self.lock_table.len()
199 }
200
201 pub async fn acquire_locks<I>(&self, object_iter: I) -> Vec<L::Guard>
202 where
203 I: Iterator<Item = K>,
204 K: Ord,
205 {
206 let mut objects: Vec<K> = object_iter.into_iter().collect();
207 objects.sort_unstable();
208 objects.dedup();
209
210 let mut guards = Vec::with_capacity(objects.len());
211 for object in objects.into_iter() {
212 guards.push(self.acquire_lock(object).await);
213 }
214 guards
215 }
216
217 pub async fn acquire_read_locks(&self, mut objects: Vec<K>) -> Vec<L::ReadGuard>
218 where
219 K: Ord,
220 {
221 objects.sort_unstable();
222 objects.dedup();
223 let mut guards = Vec::with_capacity(objects.len());
224 for object in objects.into_iter() {
225 guards.push(self.get_lock(object).await.read_lock_owned().await);
226 }
227 guards
228 }
229
230 pub async fn get_lock(&self, k: K) -> Arc<L> {
231 let lock_idx = self.get_lock_idx(&k);
232 let element = {
233 let map = self.lock_table[lock_idx].read().await;
234 map.get(&k).cloned()
235 };
236 if let Some(element) = element {
237 element
238 } else {
239 let element = {
241 let mut map = self.lock_table[lock_idx].write().await;
242 map.entry(k)
243 .or_insert_with(|| {
244 self.size.fetch_add(1, Ordering::SeqCst);
245 Arc::new(L::default())
246 })
247 .clone()
248 };
249 element
250 }
251 }
252
253 pub async fn acquire_lock(&self, k: K) -> L::Guard {
254 self.get_lock(k).await.lock_owned().await
255 }
256
257 pub fn try_acquire_lock(&self, k: K) -> Result<L::Guard, TryAcquireLockError> {
258 let lock_idx = self.get_lock_idx(&k);
259 let element = {
260 let map = self.lock_table[lock_idx]
261 .try_read()
262 .map_err(|_| TryAcquireLockError::LockTableLocked)?;
263 map.get(&k).cloned()
264 };
265 if let Some(element) = element {
266 let lock = element.try_lock_owned();
267 Ok(lock.map_err(|_| TryAcquireLockError::LockEntryLocked)?)
268 } else {
269 let element = {
271 let mut map = self.lock_table[lock_idx]
272 .try_write()
273 .map_err(|_| TryAcquireLockError::LockTableLocked)?;
274 map.entry(k)
275 .or_insert_with(|| {
276 self.size.fetch_add(1, Ordering::SeqCst);
277 Arc::new(L::default())
278 })
279 .clone()
280 };
281 let lock = element.try_lock_owned();
282 lock.map_err(|_| TryAcquireLockError::LockEntryLocked)
283 }
284 }
285}
286
287impl<K: Hash, L: Lock> Drop for LockTable<K, L> {
288 fn drop(&mut self) {
289 self.stop.store(true, Ordering::SeqCst);
290 }
291}
292
293#[tokio::test]
294async fn test_mutex_table_concurrent_in_same_bucket() {
298 use tokio::time::{sleep, timeout};
299 let mutex_table = Arc::new(MutexTable::<String>::new(1));
300 let john = mutex_table.try_acquire_lock("john".to_string());
301 john.unwrap();
302 {
303 let mutex_table = mutex_table.clone();
304 spawn_monitored_task!(async move {
305 mutex_table.acquire_lock("john".to_string()).await;
306 });
307 }
308 sleep(Duration::from_millis(50)).await;
309 let jane = mutex_table.try_acquire_lock("jane".to_string());
310 jane.unwrap();
311
312 let mutex_table = Arc::new(MutexTable::<String>::new(1));
313 let _john = mutex_table.acquire_lock("john".to_string()).await;
314 {
315 let mutex_table = mutex_table.clone();
316 spawn_monitored_task!(async move {
317 mutex_table.acquire_lock("john".to_string()).await;
318 });
319 }
320 sleep(Duration::from_millis(50)).await;
321 let jane = timeout(
322 Duration::from_secs(1),
323 mutex_table.acquire_lock("jane".to_string()),
324 )
325 .await;
326 jane.unwrap();
327}
328
329#[tokio::test]
330async fn test_mutex_table() {
331 let mutex_table =
333 MutexTable::<String>::new_with_cleanup(1, Duration::from_secs(10), Duration::MAX, 1000);
334 let john1 = mutex_table.try_acquire_lock("john".to_string());
335 assert!(john1.is_ok());
336 let john2 = mutex_table.try_acquire_lock("john".to_string());
337 assert!(john2.is_err());
338 drop(john1);
339 let john2 = mutex_table.try_acquire_lock("john".to_string());
340 assert!(john2.is_ok());
341 let jane = mutex_table.try_acquire_lock("jane".to_string());
342 assert!(jane.is_ok());
343 MutexTable::cleanup(mutex_table.lock_table.clone());
344 let map = mutex_table.lock_table.first().as_ref().unwrap().try_read();
345 assert!(map.is_ok());
346 assert_eq!(map.unwrap().len(), 2);
347 drop(john2);
348 MutexTable::cleanup(mutex_table.lock_table.clone());
349 let map = mutex_table.lock_table.first().as_ref().unwrap().try_read();
350 assert!(map.is_ok());
351 assert_eq!(map.unwrap().len(), 1);
352 drop(jane);
353 MutexTable::cleanup(mutex_table.lock_table.clone());
354 let map = mutex_table.lock_table.first().as_ref().unwrap().try_read();
355 assert!(map.is_ok());
356 assert!(map.unwrap().is_empty());
357}
358
359#[tokio::test]
360async fn test_acquire_locks() {
361 let mutex_table =
362 RwLockTable::<String>::new_with_cleanup(1, Duration::from_secs(10), Duration::MAX, 1000);
363 let object_1 = "object 1".to_string();
364 let object_2 = "object 2".to_string();
365 let object_3 = "object 3".to_string();
366
367 let objects = vec![
369 object_1.clone(),
370 object_2.clone(),
371 object_2,
372 object_1.clone(),
373 object_3,
374 object_1,
375 ];
376
377 let locks = mutex_table.acquire_locks(objects.clone().into_iter()).await;
378 assert_eq!(locks.len(), 3);
379
380 for object in objects.clone() {
381 assert!(mutex_table.try_acquire_lock(object).is_err());
382 }
383
384 drop(locks);
385 let locks = mutex_table.acquire_locks(objects.into_iter()).await;
386 assert_eq!(locks.len(), 3);
387}
388
389#[tokio::test]
390async fn test_read_locks() {
391 let mutex_table =
392 RwLockTable::<String>::new_with_cleanup(1, Duration::from_secs(10), Duration::MAX, 1000);
393 let lock = "lock".to_string();
394 let locks1 = mutex_table.acquire_read_locks(vec![lock.clone()]).await;
395 assert!(mutex_table.try_acquire_lock(lock.clone()).is_err());
396 let locks2 = mutex_table.acquire_read_locks(vec![lock.clone()]).await;
397 drop(locks1);
398 drop(locks2);
399 assert!(mutex_table.try_acquire_lock(lock.clone()).is_ok());
400}
401
402#[tokio::test(flavor = "current_thread", start_paused = true)]
403async fn test_mutex_table_bg_cleanup() {
404 let mutex_table = MutexTable::<String>::new_with_cleanup(
405 1,
406 Duration::from_secs(5),
407 Duration::from_secs(1),
408 1000,
409 );
410 let lock1 = mutex_table.try_acquire_lock("lock1".to_string());
411 let lock2 = mutex_table.try_acquire_lock("lock2".to_string());
412 let lock3 = mutex_table.try_acquire_lock("lock3".to_string());
413 let lock4 = mutex_table.try_acquire_lock("lock4".to_string());
414 let lock5 = mutex_table.try_acquire_lock("lock5".to_string());
415 assert!(lock1.is_ok());
416 assert!(lock2.is_ok());
417 assert!(lock3.is_ok());
418 assert!(lock4.is_ok());
419 assert!(lock5.is_ok());
420 MutexTable::cleanup(mutex_table.lock_table.clone());
422 let lock11 = mutex_table.try_acquire_lock("lock1".to_string());
425 let lock22 = mutex_table.try_acquire_lock("lock2".to_string());
426 let lock33 = mutex_table.try_acquire_lock("lock3".to_string());
427 let lock44 = mutex_table.try_acquire_lock("lock4".to_string());
428 let lock55 = mutex_table.try_acquire_lock("lock5".to_string());
429 assert!(lock11.is_err());
430 assert!(lock22.is_err());
431 assert!(lock33.is_err());
432 assert!(lock44.is_err());
433 assert!(lock55.is_err());
434 drop(lock1);
436 drop(lock2);
437 drop(lock3);
438 drop(lock4);
439 drop(lock5);
440 tokio::time::sleep(Duration::from_secs(10)).await;
442 for entry in mutex_table.lock_table.iter() {
443 let locked = entry.read().await;
444 assert!(locked.is_empty());
445 }
446}
447
448#[tokio::test(flavor = "current_thread", start_paused = true)]
449async fn test_mutex_table_bg_cleanup_with_size_threshold() {
450 let mutex_table =
453 MutexTable::<String>::new_with_cleanup(1, Duration::MAX, Duration::from_secs(1), 5);
454 let lock1 = mutex_table.try_acquire_lock("lock1".to_string());
455 let lock2 = mutex_table.try_acquire_lock("lock2".to_string());
456 let lock3 = mutex_table.try_acquire_lock("lock3".to_string());
457 let lock4 = mutex_table.try_acquire_lock("lock4".to_string());
458 let lock5 = mutex_table.try_acquire_lock("lock5".to_string());
459 assert!(lock1.is_ok());
460 assert!(lock2.is_ok());
461 assert!(lock3.is_ok());
462 assert!(lock4.is_ok());
463 assert!(lock5.is_ok());
464 MutexTable::cleanup(mutex_table.lock_table.clone());
466 let lock11 = mutex_table.try_acquire_lock("lock1".to_string());
469 let lock22 = mutex_table.try_acquire_lock("lock2".to_string());
470 let lock33 = mutex_table.try_acquire_lock("lock3".to_string());
471 let lock44 = mutex_table.try_acquire_lock("lock4".to_string());
472 let lock55 = mutex_table.try_acquire_lock("lock5".to_string());
473 assert!(lock11.is_err());
474 assert!(lock22.is_err());
475 assert!(lock33.is_err());
476 assert!(lock44.is_err());
477 assert!(lock55.is_err());
478 assert_eq!(mutex_table.size(), 5);
479 drop(lock1);
481 drop(lock2);
482 drop(lock3);
483 drop(lock4);
484 drop(lock5);
485 tokio::task::yield_now().await;
486 tokio::time::advance(Duration::from_secs(5)).await;
488 tokio::task::yield_now().await;
489 assert_eq!(mutex_table.size(), 0);
490 for entry in mutex_table.lock_table.iter() {
491 let locked = entry.read().await;
492 assert!(locked.is_empty());
493 }
494}