1use std::{
6 collections::{
7 HashMap,
8 hash_map::{DefaultHasher, RandomState},
9 },
10 error::Error,
11 fmt,
12 hash::{BuildHasher, Hash, Hasher},
13 sync::{
14 Arc,
15 atomic::{AtomicBool, AtomicUsize, Ordering},
16 },
17 time::Duration,
18};
19
20use iota_metrics::spawn_monitored_task;
21use parking_lot::{ArcMutexGuard, ArcRwLockReadGuard, ArcRwLockWriteGuard, Mutex, RwLock};
22use tokio::{task::JoinHandle, time::Instant};
23use tracing::info;
24
25type OwnedMutexGuard<T> = ArcMutexGuard<parking_lot::RawMutex, T>;
26type OwnedRwLockReadGuard<T> = ArcRwLockReadGuard<parking_lot::RawRwLock, T>;
27type OwnedRwLockWriteGuard<T> = ArcRwLockWriteGuard<parking_lot::RawRwLock, T>;
28
/// A lock object that a lock table stores behind an `Arc` and locks through that `Arc`.
///
/// The returned guards take ownership of the `Arc` they were created from, so they do not
/// borrow the table.
pub trait Lock: Send + Sync + Default {
    /// Guard returned for exclusive access.
    type Guard;
    /// Guard returned for shared access. Lock types without a shared mode may make this
    /// exclusive.
    type ReadGuard;

    /// Acquires exclusive access, waiting until it is available.
    fn lock_owned(self: Arc<Self>) -> Self::Guard;

    /// Attempts exclusive access without waiting; yields `None` if it is not available.
    fn try_lock_owned(self: Arc<Self>) -> Option<Self::Guard>;

    /// Acquires shared access, waiting until it is available.
    fn read_lock_owned(self: Arc<Self>) -> Self::ReadGuard;
}
36
37impl Lock for Mutex<()> {
38 type Guard = OwnedMutexGuard<()>;
39 type ReadGuard = Self::Guard;
40
41 fn lock_owned(self: Arc<Self>) -> Self::Guard {
42 self.lock_arc()
43 }
44
45 fn try_lock_owned(self: Arc<Self>) -> Option<Self::Guard> {
46 self.try_lock_arc()
47 }
48
49 fn read_lock_owned(self: Arc<Self>) -> Self::ReadGuard {
50 self.lock_arc()
51 }
52}
53
54impl Lock for RwLock<()> {
55 type Guard = OwnedRwLockWriteGuard<()>;
56 type ReadGuard = OwnedRwLockReadGuard<()>;
57
58 fn lock_owned(self: Arc<Self>) -> Self::Guard {
59 self.write_arc()
60 }
61
62 fn try_lock_owned(self: Arc<Self>) -> Option<Self::Guard> {
63 self.try_write_arc()
64 }
65
66 fn read_lock_owned(self: Arc<Self>) -> Self::ReadGuard {
67 self.read_arc()
68 }
69}
70
71type InnerLockTable<K, L> = HashMap<K, Arc<L>>;
72pub struct LockTable<K: Hash, L: Lock> {
75 random_state: RandomState,
76 lock_table: Arc<Vec<RwLock<InnerLockTable<K, L>>>>,
77 _k: std::marker::PhantomData<K>,
78 _cleaner: JoinHandle<()>,
79 stop: Arc<AtomicBool>,
80 size: Arc<AtomicUsize>,
81}
82
83pub type MutexTable<K> = LockTable<K, Mutex<()>>;
84pub type RwLockTable<K> = LockTable<K, RwLock<()>>;
85
/// Reason a non-blocking acquisition (`LockTable::try_acquire_lock`) failed.
///
/// Fieldless, so it is cheap to copy and can be compared directly in callers and tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TryAcquireLockError {
    /// The shard map containing the key could not be locked without waiting.
    LockTableLocked,
    /// The key's own lock is currently held by someone else.
    LockEntryLocked,
}
91
92impl fmt::Display for TryAcquireLockError {
93 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
94 write!(fmt, "operation would block")
95 }
96}
97
98impl Error for TryAcquireLockError {}
99pub type MutexGuard = OwnedMutexGuard<()>;
100pub type RwLockGuard = OwnedRwLockReadGuard<()>;
101
102impl<K: Hash + Eq + Send + Sync + 'static, L: Lock + 'static> LockTable<K, L> {
103 pub fn new_with_cleanup(
104 num_shards: usize,
105 cleanup_period: Duration,
106 cleanup_initial_delay: Duration,
107 cleanup_entries_threshold: usize,
108 ) -> Self {
109 let num_shards = if cfg!(msim) { 4 } else { num_shards };
110
111 let lock_table: Arc<Vec<RwLock<InnerLockTable<K, L>>>> = Arc::new(
112 (0..num_shards)
113 .map(|_| RwLock::new(HashMap::new()))
114 .collect(),
115 );
116 let cloned = lock_table.clone();
117 let stop = Arc::new(AtomicBool::new(false));
118 let stop_cloned = stop.clone();
119 let size: Arc<AtomicUsize> = Arc::new(AtomicUsize::new(0));
120 let size_cloned = size.clone();
121 Self {
122 random_state: RandomState::new(),
123 lock_table,
124 _k: std::marker::PhantomData {},
125 _cleaner: spawn_monitored_task!(async move {
126 tokio::time::sleep(cleanup_initial_delay).await;
127 let mut previous_cleanup_instant = Instant::now();
128 while !stop_cloned.load(Ordering::SeqCst) {
129 if size_cloned.load(Ordering::SeqCst) >= cleanup_entries_threshold
130 || previous_cleanup_instant.elapsed() >= cleanup_period
131 {
132 let num_removed = Self::cleanup(cloned.clone());
133 size_cloned.fetch_sub(num_removed, Ordering::SeqCst);
134 previous_cleanup_instant = Instant::now();
135 }
136 tokio::time::sleep(Duration::from_secs(1)).await;
137 }
138 info!("Stopping mutex table cleanup!");
139 }),
140 stop,
141 size,
142 }
143 }
144
145 pub fn new(num_shards: usize) -> Self {
146 Self::new_with_cleanup(
147 num_shards,
148 Duration::from_secs(10),
149 Duration::from_secs(10),
150 10_000,
151 )
152 }
153
154 pub fn size(&self) -> usize {
155 self.size.load(Ordering::SeqCst)
156 }
157
158 pub fn cleanup(lock_table: Arc<Vec<RwLock<InnerLockTable<K, L>>>>) -> usize {
159 let mut num_removed: usize = 0;
160 for shard in lock_table.iter() {
161 let map = shard.try_write();
162 if map.is_none() {
163 continue;
164 }
165 map.unwrap().retain(|_k, v| {
166 if Arc::strong_count(v) == 1 {
172 num_removed += 1;
173 false
174 } else {
175 true
176 }
177 });
178 }
179 num_removed
180 }
181
182 fn get_lock_idx(&self, key: &K) -> usize {
183 let mut hasher = if !cfg!(test) {
184 self.random_state.build_hasher()
185 } else {
186 DefaultHasher::new()
188 };
189
190 key.hash(&mut hasher);
191 let hash: usize = hasher.finish().try_into().unwrap();
193 hash % self.lock_table.len()
194 }
195
196 pub fn acquire_locks<I>(&self, object_iter: I) -> Vec<L::Guard>
197 where
198 I: Iterator<Item = K>,
199 K: Ord,
200 {
201 let mut objects: Vec<K> = object_iter.into_iter().collect();
202 objects.sort_unstable();
203 objects.dedup();
204
205 let mut guards = Vec::with_capacity(objects.len());
206 for object in objects.into_iter() {
207 guards.push(self.acquire_lock(object));
208 }
209 guards
210 }
211
212 pub fn acquire_read_locks(&self, mut objects: Vec<K>) -> Vec<L::ReadGuard>
213 where
214 K: Ord,
215 {
216 objects.sort_unstable();
217 objects.dedup();
218 let mut guards = Vec::with_capacity(objects.len());
219 for object in objects.into_iter() {
220 guards.push(self.get_lock(object).read_lock_owned());
221 }
222 guards
223 }
224
225 pub fn get_lock(&self, k: K) -> Arc<L> {
226 let lock_idx = self.get_lock_idx(&k);
227 let element = {
228 let map = self.lock_table[lock_idx].read();
229 map.get(&k).cloned()
230 };
231 if let Some(element) = element {
232 element
233 } else {
234 let element = {
236 let mut map = self.lock_table[lock_idx].write();
237 map.entry(k)
238 .or_insert_with(|| {
239 self.size.fetch_add(1, Ordering::SeqCst);
240 Arc::new(L::default())
241 })
242 .clone()
243 };
244 element
245 }
246 }
247
248 pub fn acquire_lock(&self, k: K) -> L::Guard {
249 self.get_lock(k).lock_owned()
250 }
251
252 pub fn try_acquire_lock(&self, k: K) -> Result<L::Guard, TryAcquireLockError> {
253 let lock_idx = self.get_lock_idx(&k);
254 let element = {
255 let map = self.lock_table[lock_idx]
256 .try_read()
257 .ok_or(TryAcquireLockError::LockTableLocked)?;
258 map.get(&k).cloned()
259 };
260 if let Some(element) = element {
261 let lock = element.try_lock_owned();
262 lock.ok_or(TryAcquireLockError::LockEntryLocked)
263 } else {
264 let element = {
266 let mut map = self.lock_table[lock_idx]
267 .try_write()
268 .ok_or(TryAcquireLockError::LockTableLocked)?;
269 map.entry(k)
270 .or_insert_with(|| {
271 self.size.fetch_add(1, Ordering::SeqCst);
272 Arc::new(L::default())
273 })
274 .clone()
275 };
276 let lock = element.try_lock_owned();
277 lock.ok_or(TryAcquireLockError::LockEntryLocked)
278 }
279 }
280}
281
282impl<K: Hash, L: Lock> Drop for LockTable<K, L> {
283 fn drop(&mut self) {
284 self.stop.store(true, Ordering::SeqCst);
285 }
286}
287
288#[tokio::test]
289async fn test_mutex_table_concurrent_in_same_bucket() {
293 use tokio::time::{sleep, timeout};
294 let mutex_table = Arc::new(MutexTable::<String>::new(1));
295 let john = mutex_table.try_acquire_lock("john".to_string());
296 let _ = john.unwrap();
297 {
298 let mutex_table = mutex_table.clone();
299 std::thread::spawn(move || {
300 let _ = mutex_table.acquire_lock("john".to_string());
301 });
302 }
303 sleep(Duration::from_millis(50)).await;
304 let jane = mutex_table.try_acquire_lock("jane".to_string());
305 let _ = jane.unwrap();
306
307 let mutex_table = Arc::new(MutexTable::<String>::new(1));
308 let _john = mutex_table.acquire_lock("john".to_string());
309 {
310 let mutex_table = mutex_table.clone();
311 std::thread::spawn(move || {
312 let _ = mutex_table.acquire_lock("john".to_string());
313 });
314 }
315 sleep(Duration::from_millis(50)).await;
316 let jane = timeout(
317 Duration::from_secs(1),
318 tokio::task::spawn_blocking(move || {
319 let _ = mutex_table.acquire_lock("jane".to_string());
320 }),
321 )
322 .await;
323 let _ = jane.unwrap();
324}
325
326#[tokio::test]
327async fn test_mutex_table() {
328 let mutex_table =
330 MutexTable::<String>::new_with_cleanup(1, Duration::from_secs(10), Duration::MAX, 1000);
331 let john1 = mutex_table.try_acquire_lock("john".to_string());
332 assert!(john1.is_ok());
333 let john2 = mutex_table.try_acquire_lock("john".to_string());
334 assert!(john2.is_err());
335 drop(john1);
336 let john2 = mutex_table.try_acquire_lock("john".to_string());
337 assert!(john2.is_ok());
338 let jane = mutex_table.try_acquire_lock("jane".to_string());
339 assert!(jane.is_ok());
340 MutexTable::cleanup(mutex_table.lock_table.clone());
341 let map = mutex_table.lock_table.first().as_ref().unwrap().try_read();
342 assert!(map.is_some());
343 assert_eq!(map.unwrap().len(), 2);
344 drop(john2);
345 MutexTable::cleanup(mutex_table.lock_table.clone());
346 let map = mutex_table.lock_table.first().as_ref().unwrap().try_read();
347 assert!(map.is_some());
348 assert_eq!(map.unwrap().len(), 1);
349 drop(jane);
350 MutexTable::cleanup(mutex_table.lock_table.clone());
351 let map = mutex_table.lock_table.first().as_ref().unwrap().try_read();
352 assert!(map.is_some());
353 assert!(map.unwrap().is_empty());
354}
355
356#[tokio::test]
357async fn test_acquire_locks() {
358 let mutex_table =
359 RwLockTable::<String>::new_with_cleanup(1, Duration::from_secs(10), Duration::MAX, 1000);
360 let object_1 = "object 1".to_string();
361 let object_2 = "object 2".to_string();
362 let object_3 = "object 3".to_string();
363
364 let objects = vec![
366 object_1.clone(),
367 object_2.clone(),
368 object_2,
369 object_1.clone(),
370 object_3,
371 object_1,
372 ];
373
374 let locks = mutex_table.acquire_locks(objects.clone().into_iter());
375 assert_eq!(locks.len(), 3);
376
377 for object in objects.clone() {
378 assert!(mutex_table.try_acquire_lock(object).is_err());
379 }
380
381 drop(locks);
382 let locks = mutex_table.acquire_locks(objects.into_iter());
383 assert_eq!(locks.len(), 3);
384}
385
386#[tokio::test]
387async fn test_read_locks() {
388 let mutex_table =
389 RwLockTable::<String>::new_with_cleanup(1, Duration::from_secs(10), Duration::MAX, 1000);
390 let lock = "lock".to_string();
391 let locks1 = mutex_table.acquire_read_locks(vec![lock.clone()]);
392 assert!(mutex_table.try_acquire_lock(lock.clone()).is_err());
393 let locks2 = mutex_table.acquire_read_locks(vec![lock.clone()]);
394 drop(locks1);
395 drop(locks2);
396 assert!(mutex_table.try_acquire_lock(lock.clone()).is_ok());
397}
398
399#[tokio::test(flavor = "current_thread", start_paused = true)]
400async fn test_mutex_table_bg_cleanup() {
401 let mutex_table = MutexTable::<String>::new_with_cleanup(
402 1,
403 Duration::from_secs(5),
404 Duration::from_secs(1),
405 1000,
406 );
407 let lock1 = mutex_table.try_acquire_lock("lock1".to_string());
408 let lock2 = mutex_table.try_acquire_lock("lock2".to_string());
409 let lock3 = mutex_table.try_acquire_lock("lock3".to_string());
410 let lock4 = mutex_table.try_acquire_lock("lock4".to_string());
411 let lock5 = mutex_table.try_acquire_lock("lock5".to_string());
412 assert!(lock1.is_ok());
413 assert!(lock2.is_ok());
414 assert!(lock3.is_ok());
415 assert!(lock4.is_ok());
416 assert!(lock5.is_ok());
417 MutexTable::cleanup(mutex_table.lock_table.clone());
419 let lock11 = mutex_table.try_acquire_lock("lock1".to_string());
422 let lock22 = mutex_table.try_acquire_lock("lock2".to_string());
423 let lock33 = mutex_table.try_acquire_lock("lock3".to_string());
424 let lock44 = mutex_table.try_acquire_lock("lock4".to_string());
425 let lock55 = mutex_table.try_acquire_lock("lock5".to_string());
426 assert!(lock11.is_err());
427 assert!(lock22.is_err());
428 assert!(lock33.is_err());
429 assert!(lock44.is_err());
430 assert!(lock55.is_err());
431 drop(lock1);
433 drop(lock2);
434 drop(lock3);
435 drop(lock4);
436 drop(lock5);
437 tokio::time::sleep(Duration::from_secs(10)).await;
439 for entry in mutex_table.lock_table.iter() {
440 let locked = entry.read();
441 assert!(locked.is_empty());
442 }
443}
444
445#[tokio::test(flavor = "current_thread", start_paused = true)]
446async fn test_mutex_table_bg_cleanup_with_size_threshold() {
447 let mutex_table =
450 MutexTable::<String>::new_with_cleanup(1, Duration::MAX, Duration::from_secs(1), 5);
451 let lock1 = mutex_table.try_acquire_lock("lock1".to_string());
452 let lock2 = mutex_table.try_acquire_lock("lock2".to_string());
453 let lock3 = mutex_table.try_acquire_lock("lock3".to_string());
454 let lock4 = mutex_table.try_acquire_lock("lock4".to_string());
455 let lock5 = mutex_table.try_acquire_lock("lock5".to_string());
456 assert!(lock1.is_ok());
457 assert!(lock2.is_ok());
458 assert!(lock3.is_ok());
459 assert!(lock4.is_ok());
460 assert!(lock5.is_ok());
461 MutexTable::cleanup(mutex_table.lock_table.clone());
463 let lock11 = mutex_table.try_acquire_lock("lock1".to_string());
466 let lock22 = mutex_table.try_acquire_lock("lock2".to_string());
467 let lock33 = mutex_table.try_acquire_lock("lock3".to_string());
468 let lock44 = mutex_table.try_acquire_lock("lock4".to_string());
469 let lock55 = mutex_table.try_acquire_lock("lock5".to_string());
470 assert!(lock11.is_err());
471 assert!(lock22.is_err());
472 assert!(lock33.is_err());
473 assert!(lock44.is_err());
474 assert!(lock55.is_err());
475 assert_eq!(mutex_table.size(), 5);
476 drop(lock1);
478 drop(lock2);
479 drop(lock3);
480 drop(lock4);
481 drop(lock5);
482 tokio::task::yield_now().await;
483 tokio::time::advance(Duration::from_secs(5)).await;
485 tokio::task::yield_now().await;
486 assert_eq!(mutex_table.size(), 0);
487 for entry in mutex_table.lock_table.iter() {
488 let locked = entry.read();
489 assert!(locked.is_empty());
490 }
491}