iota_storage/
mutex_table.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    collections::{
7        HashMap,
8        hash_map::{DefaultHasher, RandomState},
9    },
10    error::Error,
11    fmt,
12    hash::{BuildHasher, Hash, Hasher},
13    sync::{
14        Arc,
15        atomic::{AtomicBool, AtomicUsize, Ordering},
16    },
17    time::Duration,
18};
19
20use iota_metrics::spawn_monitored_task;
21use parking_lot::{ArcMutexGuard, ArcRwLockReadGuard, ArcRwLockWriteGuard, Mutex, RwLock};
22use tokio::{task::JoinHandle, time::Instant};
23use tracing::info;
24
// Guard aliases for parking_lot locks that are locked through an `Arc`
// (`lock_arc` / `read_arc` / `write_arc`). These guards are returned as plain
// values by the `Lock` trait methods below.
type OwnedMutexGuard<T> = ArcMutexGuard<parking_lot::RawMutex, T>;
type OwnedRwLockReadGuard<T> = ArcRwLockReadGuard<parking_lot::RawRwLock, T>;
type OwnedRwLockWriteGuard<T> = ArcRwLockWriteGuard<parking_lot::RawRwLock, T>;
28
/// Lock primitive stored per key inside a `LockTable`.
///
/// Every method consumes an `Arc<Self>` and returns a guard value, so callers
/// can clone the `Arc` out of the table and lock it afterwards.
pub trait Lock: Send + Sync + Default {
    /// Guard type produced by `lock_owned` and `try_lock_owned`.
    type Guard;
    /// Guard type produced by `read_lock_owned`.
    type ReadGuard;
    /// Acquires the lock and returns its guard.
    fn lock_owned(self: Arc<Self>) -> Self::Guard;
    /// Attempts to acquire the lock; returns `None` when the attempt fails.
    fn try_lock_owned(self: Arc<Self>) -> Option<Self::Guard>;
    /// Acquires the lock for reading and returns its guard.
    fn read_lock_owned(self: Arc<Self>) -> Self::ReadGuard;
}
36
37impl Lock for Mutex<()> {
38    type Guard = OwnedMutexGuard<()>;
39    type ReadGuard = Self::Guard;
40
41    fn lock_owned(self: Arc<Self>) -> Self::Guard {
42        self.lock_arc()
43    }
44
45    fn try_lock_owned(self: Arc<Self>) -> Option<Self::Guard> {
46        self.try_lock_arc()
47    }
48
49    fn read_lock_owned(self: Arc<Self>) -> Self::ReadGuard {
50        self.lock_arc()
51    }
52}
53
54impl Lock for RwLock<()> {
55    type Guard = OwnedRwLockWriteGuard<()>;
56    type ReadGuard = OwnedRwLockReadGuard<()>;
57
58    fn lock_owned(self: Arc<Self>) -> Self::Guard {
59        self.write_arc()
60    }
61
62    fn try_lock_owned(self: Arc<Self>) -> Option<Self::Guard> {
63        self.try_write_arc()
64    }
65
66    fn read_lock_owned(self: Arc<Self>) -> Self::ReadGuard {
67        self.read_arc()
68    }
69}
70
// Map stored in a single shard: key -> shared lock for that key.
type InnerLockTable<K, L> = HashMap<K, Arc<L>>;
// LockTable is a sharded table of per-key locks. Its MutexTable instantiation
// supports mutual exclusion on keys such as TransactionDigest or
// ObjectDigest
pub struct LockTable<K: Hash, L: Lock> {
    // Hasher factory used to pick the shard of a key (a fixed hasher is used
    // instead in test builds, see `get_lock_idx`).
    random_state: RandomState,
    // The shards. Each shard is an `RwLock`-protected map; the `Arc` lets the
    // background cleanup task share the vector with the table.
    lock_table: Arc<Vec<RwLock<InnerLockTable<K, L>>>>,
    _k: std::marker::PhantomData<K>,
    // Handle of the background cleanup task spawned in `new_with_cleanup`.
    _cleaner: JoinHandle<()>,
    // Set to `true` on drop; the cleanup task exits its loop once it sees it.
    stop: Arc<AtomicBool>,
    // Entry counter: incremented on every insertion, decremented by the
    // background task by the number of entries it removed.
    size: Arc<AtomicUsize>,
}
82
/// Lock table whose per-key lock is a `Mutex<()>`.
pub type MutexTable<K> = LockTable<K, Mutex<()>>;
/// Lock table whose per-key lock is an `RwLock<()>`.
pub type RwLockTable<K> = LockTable<K, RwLock<()>>;
85
/// Reason why a non-blocking lock acquisition did not succeed.
#[derive(Debug)]
pub enum TryAcquireLockError {
    /// The shard holding the key could not be locked without blocking.
    LockTableLocked,
    /// The lock of the entry itself could not be taken without blocking.
    LockEntryLocked,
}

// Both variants render as the same message.
impl fmt::Display for TryAcquireLockError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("operation would block")
    }
}

impl Error for TryAcquireLockError {}
/// Guard of a unit mutex locked through an `Arc`.
pub type MutexGuard = OwnedMutexGuard<()>;
/// Read guard of a unit reader-writer lock locked through an `Arc`. Despite
/// the generic name this aliases the *read* guard, not the write guard.
pub type RwLockGuard = OwnedRwLockReadGuard<()>;
101
102impl<K: Hash + Eq + Send + Sync + 'static, L: Lock + 'static> LockTable<K, L> {
103    pub fn new_with_cleanup(
104        num_shards: usize,
105        cleanup_period: Duration,
106        cleanup_initial_delay: Duration,
107        cleanup_entries_threshold: usize,
108    ) -> Self {
109        let num_shards = if cfg!(msim) { 4 } else { num_shards };
110
111        let lock_table: Arc<Vec<RwLock<InnerLockTable<K, L>>>> = Arc::new(
112            (0..num_shards)
113                .map(|_| RwLock::new(HashMap::new()))
114                .collect(),
115        );
116        let cloned = lock_table.clone();
117        let stop = Arc::new(AtomicBool::new(false));
118        let stop_cloned = stop.clone();
119        let size: Arc<AtomicUsize> = Arc::new(AtomicUsize::new(0));
120        let size_cloned = size.clone();
121        Self {
122            random_state: RandomState::new(),
123            lock_table,
124            _k: std::marker::PhantomData {},
125            _cleaner: spawn_monitored_task!(async move {
126                tokio::time::sleep(cleanup_initial_delay).await;
127                let mut previous_cleanup_instant = Instant::now();
128                while !stop_cloned.load(Ordering::SeqCst) {
129                    if size_cloned.load(Ordering::SeqCst) >= cleanup_entries_threshold
130                        || previous_cleanup_instant.elapsed() >= cleanup_period
131                    {
132                        let num_removed = Self::cleanup(cloned.clone());
133                        size_cloned.fetch_sub(num_removed, Ordering::SeqCst);
134                        previous_cleanup_instant = Instant::now();
135                    }
136                    tokio::time::sleep(Duration::from_secs(1)).await;
137                }
138                info!("Stopping mutex table cleanup!");
139            }),
140            stop,
141            size,
142        }
143    }
144
145    pub fn new(num_shards: usize) -> Self {
146        Self::new_with_cleanup(
147            num_shards,
148            Duration::from_secs(10),
149            Duration::from_secs(10),
150            10_000,
151        )
152    }
153
    /// Returns the current value of the entry counter.
    ///
    /// The counter goes up by one whenever `get_lock` or `try_acquire_lock`
    /// inserts a new entry, and the background task lowers it by the number of
    /// entries its cleanup pass removed. Calling `cleanup` directly does not
    /// touch the counter, so the value may exceed the number of stored entries.
    pub fn size(&self) -> usize {
        self.size.load(Ordering::SeqCst)
    }
157
158    pub fn cleanup(lock_table: Arc<Vec<RwLock<InnerLockTable<K, L>>>>) -> usize {
159        let mut num_removed: usize = 0;
160        for shard in lock_table.iter() {
161            let map = shard.try_write();
162            if map.is_none() {
163                continue;
164            }
165            map.unwrap().retain(|_k, v| {
166                // MutexMap::(try_|)acquire_locks will lock the map and call Arc::clone on the
167                // entry This check ensures that we only drop entry from the map
168                // if this is the only mutex copy This check is also likely
169                // sufficient e.g. you don't even need try_lock below, but keeping it just in
170                // case
171                if Arc::strong_count(v) == 1 {
172                    num_removed += 1;
173                    false
174                } else {
175                    true
176                }
177            });
178        }
179        num_removed
180    }
181
182    fn get_lock_idx(&self, key: &K) -> usize {
183        let mut hasher = if !cfg!(test) {
184            self.random_state.build_hasher()
185        } else {
186            // be deterministic for tests
187            DefaultHasher::new()
188        };
189
190        key.hash(&mut hasher);
191        // unwrap ok - converting u64 -> usize
192        let hash: usize = hasher.finish().try_into().unwrap();
193        hash % self.lock_table.len()
194    }
195
196    pub fn acquire_locks<I>(&self, object_iter: I) -> Vec<L::Guard>
197    where
198        I: Iterator<Item = K>,
199        K: Ord,
200    {
201        let mut objects: Vec<K> = object_iter.into_iter().collect();
202        objects.sort_unstable();
203        objects.dedup();
204
205        let mut guards = Vec::with_capacity(objects.len());
206        for object in objects.into_iter() {
207            guards.push(self.acquire_lock(object));
208        }
209        guards
210    }
211
212    pub fn acquire_read_locks(&self, mut objects: Vec<K>) -> Vec<L::ReadGuard>
213    where
214        K: Ord,
215    {
216        objects.sort_unstable();
217        objects.dedup();
218        let mut guards = Vec::with_capacity(objects.len());
219        for object in objects.into_iter() {
220            guards.push(self.get_lock(object).read_lock_owned());
221        }
222        guards
223    }
224
225    pub fn get_lock(&self, k: K) -> Arc<L> {
226        let lock_idx = self.get_lock_idx(&k);
227        let element = {
228            let map = self.lock_table[lock_idx].read();
229            map.get(&k).cloned()
230        };
231        if let Some(element) = element {
232            element
233        } else {
234            // element doesn't exist
235            let element = {
236                let mut map = self.lock_table[lock_idx].write();
237                map.entry(k)
238                    .or_insert_with(|| {
239                        self.size.fetch_add(1, Ordering::SeqCst);
240                        Arc::new(L::default())
241                    })
242                    .clone()
243            };
244            element
245        }
246    }
247
248    pub fn acquire_lock(&self, k: K) -> L::Guard {
249        self.get_lock(k).lock_owned()
250    }
251
252    pub fn try_acquire_lock(&self, k: K) -> Result<L::Guard, TryAcquireLockError> {
253        let lock_idx = self.get_lock_idx(&k);
254        let element = {
255            let map = self.lock_table[lock_idx]
256                .try_read()
257                .ok_or(TryAcquireLockError::LockTableLocked)?;
258            map.get(&k).cloned()
259        };
260        if let Some(element) = element {
261            let lock = element.try_lock_owned();
262            lock.ok_or(TryAcquireLockError::LockEntryLocked)
263        } else {
264            // element doesn't exist
265            let element = {
266                let mut map = self.lock_table[lock_idx]
267                    .try_write()
268                    .ok_or(TryAcquireLockError::LockTableLocked)?;
269                map.entry(k)
270                    .or_insert_with(|| {
271                        self.size.fetch_add(1, Ordering::SeqCst);
272                        Arc::new(L::default())
273                    })
274                    .clone()
275            };
276            let lock = element.try_lock_owned();
277            lock.ok_or(TryAcquireLockError::LockEntryLocked)
278        }
279    }
280}
281
282impl<K: Hash, L: Lock> Drop for LockTable<K, L> {
283    fn drop(&mut self) {
284        self.stop.store(true, Ordering::SeqCst);
285    }
286}
287
288#[tokio::test]
289// Tests that mutex table provides parallelism on the individual mutex level,
290// e.g. that locks for different entries do not block entire bucket if it needs
291// to wait on individual lock
292async fn test_mutex_table_concurrent_in_same_bucket() {
293    use tokio::time::{sleep, timeout};
294    let mutex_table = Arc::new(MutexTable::<String>::new(1));
295    let john = mutex_table.try_acquire_lock("john".to_string());
296    let _ = john.unwrap();
297    {
298        let mutex_table = mutex_table.clone();
299        std::thread::spawn(move || {
300            let _ = mutex_table.acquire_lock("john".to_string());
301        });
302    }
303    sleep(Duration::from_millis(50)).await;
304    let jane = mutex_table.try_acquire_lock("jane".to_string());
305    let _ = jane.unwrap();
306
307    let mutex_table = Arc::new(MutexTable::<String>::new(1));
308    let _john = mutex_table.acquire_lock("john".to_string());
309    {
310        let mutex_table = mutex_table.clone();
311        std::thread::spawn(move || {
312            let _ = mutex_table.acquire_lock("john".to_string());
313        });
314    }
315    sleep(Duration::from_millis(50)).await;
316    let jane = timeout(
317        Duration::from_secs(1),
318        tokio::task::spawn_blocking(move || {
319            let _ = mutex_table.acquire_lock("jane".to_string());
320        }),
321    )
322    .await;
323    let _ = jane.unwrap();
324}
325
326#[tokio::test]
327async fn test_mutex_table() {
328    // Disable bg cleanup with Duration.MAX for initial delay
329    let mutex_table =
330        MutexTable::<String>::new_with_cleanup(1, Duration::from_secs(10), Duration::MAX, 1000);
331    let john1 = mutex_table.try_acquire_lock("john".to_string());
332    assert!(john1.is_ok());
333    let john2 = mutex_table.try_acquire_lock("john".to_string());
334    assert!(john2.is_err());
335    drop(john1);
336    let john2 = mutex_table.try_acquire_lock("john".to_string());
337    assert!(john2.is_ok());
338    let jane = mutex_table.try_acquire_lock("jane".to_string());
339    assert!(jane.is_ok());
340    MutexTable::cleanup(mutex_table.lock_table.clone());
341    let map = mutex_table.lock_table.first().as_ref().unwrap().try_read();
342    assert!(map.is_some());
343    assert_eq!(map.unwrap().len(), 2);
344    drop(john2);
345    MutexTable::cleanup(mutex_table.lock_table.clone());
346    let map = mutex_table.lock_table.first().as_ref().unwrap().try_read();
347    assert!(map.is_some());
348    assert_eq!(map.unwrap().len(), 1);
349    drop(jane);
350    MutexTable::cleanup(mutex_table.lock_table.clone());
351    let map = mutex_table.lock_table.first().as_ref().unwrap().try_read();
352    assert!(map.is_some());
353    assert!(map.unwrap().is_empty());
354}
355
356#[tokio::test]
357async fn test_acquire_locks() {
358    let mutex_table =
359        RwLockTable::<String>::new_with_cleanup(1, Duration::from_secs(10), Duration::MAX, 1000);
360    let object_1 = "object 1".to_string();
361    let object_2 = "object 2".to_string();
362    let object_3 = "object 3".to_string();
363
364    // ensure even with duplicate objects we succeed acquiring their locks
365    let objects = vec![
366        object_1.clone(),
367        object_2.clone(),
368        object_2,
369        object_1.clone(),
370        object_3,
371        object_1,
372    ];
373
374    let locks = mutex_table.acquire_locks(objects.clone().into_iter());
375    assert_eq!(locks.len(), 3);
376
377    for object in objects.clone() {
378        assert!(mutex_table.try_acquire_lock(object).is_err());
379    }
380
381    drop(locks);
382    let locks = mutex_table.acquire_locks(objects.into_iter());
383    assert_eq!(locks.len(), 3);
384}
385
386#[tokio::test]
387async fn test_read_locks() {
388    let mutex_table =
389        RwLockTable::<String>::new_with_cleanup(1, Duration::from_secs(10), Duration::MAX, 1000);
390    let lock = "lock".to_string();
391    let locks1 = mutex_table.acquire_read_locks(vec![lock.clone()]);
392    assert!(mutex_table.try_acquire_lock(lock.clone()).is_err());
393    let locks2 = mutex_table.acquire_read_locks(vec![lock.clone()]);
394    drop(locks1);
395    drop(locks2);
396    assert!(mutex_table.try_acquire_lock(lock.clone()).is_ok());
397}
398
399#[tokio::test(flavor = "current_thread", start_paused = true)]
400async fn test_mutex_table_bg_cleanup() {
401    let mutex_table = MutexTable::<String>::new_with_cleanup(
402        1,
403        Duration::from_secs(5),
404        Duration::from_secs(1),
405        1000,
406    );
407    let lock1 = mutex_table.try_acquire_lock("lock1".to_string());
408    let lock2 = mutex_table.try_acquire_lock("lock2".to_string());
409    let lock3 = mutex_table.try_acquire_lock("lock3".to_string());
410    let lock4 = mutex_table.try_acquire_lock("lock4".to_string());
411    let lock5 = mutex_table.try_acquire_lock("lock5".to_string());
412    assert!(lock1.is_ok());
413    assert!(lock2.is_ok());
414    assert!(lock3.is_ok());
415    assert!(lock4.is_ok());
416    assert!(lock5.is_ok());
417    // Trigger cleanup
418    MutexTable::cleanup(mutex_table.lock_table.clone());
419    // Try acquiring locks again, these should still fail because locks have not
420    // been released
421    let lock11 = mutex_table.try_acquire_lock("lock1".to_string());
422    let lock22 = mutex_table.try_acquire_lock("lock2".to_string());
423    let lock33 = mutex_table.try_acquire_lock("lock3".to_string());
424    let lock44 = mutex_table.try_acquire_lock("lock4".to_string());
425    let lock55 = mutex_table.try_acquire_lock("lock5".to_string());
426    assert!(lock11.is_err());
427    assert!(lock22.is_err());
428    assert!(lock33.is_err());
429    assert!(lock44.is_err());
430    assert!(lock55.is_err());
431    // drop all locks
432    drop(lock1);
433    drop(lock2);
434    drop(lock3);
435    drop(lock4);
436    drop(lock5);
437    // Wait for bg cleanup to be triggered
438    tokio::time::sleep(Duration::from_secs(10)).await;
439    for entry in mutex_table.lock_table.iter() {
440        let locked = entry.read();
441        assert!(locked.is_empty());
442    }
443}
444
445#[tokio::test(flavor = "current_thread", start_paused = true)]
446async fn test_mutex_table_bg_cleanup_with_size_threshold() {
447    // set up the table to never trigger cleanup because of time period but only
448    // size threshold
449    let mutex_table =
450        MutexTable::<String>::new_with_cleanup(1, Duration::MAX, Duration::from_secs(1), 5);
451    let lock1 = mutex_table.try_acquire_lock("lock1".to_string());
452    let lock2 = mutex_table.try_acquire_lock("lock2".to_string());
453    let lock3 = mutex_table.try_acquire_lock("lock3".to_string());
454    let lock4 = mutex_table.try_acquire_lock("lock4".to_string());
455    let lock5 = mutex_table.try_acquire_lock("lock5".to_string());
456    assert!(lock1.is_ok());
457    assert!(lock2.is_ok());
458    assert!(lock3.is_ok());
459    assert!(lock4.is_ok());
460    assert!(lock5.is_ok());
461    // Trigger cleanup
462    MutexTable::cleanup(mutex_table.lock_table.clone());
463    // Try acquiring locks again, these should still fail because locks have not
464    // been released
465    let lock11 = mutex_table.try_acquire_lock("lock1".to_string());
466    let lock22 = mutex_table.try_acquire_lock("lock2".to_string());
467    let lock33 = mutex_table.try_acquire_lock("lock3".to_string());
468    let lock44 = mutex_table.try_acquire_lock("lock4".to_string());
469    let lock55 = mutex_table.try_acquire_lock("lock5".to_string());
470    assert!(lock11.is_err());
471    assert!(lock22.is_err());
472    assert!(lock33.is_err());
473    assert!(lock44.is_err());
474    assert!(lock55.is_err());
475    assert_eq!(mutex_table.size(), 5);
476    // drop all locks
477    drop(lock1);
478    drop(lock2);
479    drop(lock3);
480    drop(lock4);
481    drop(lock5);
482    tokio::task::yield_now().await;
483    // Wait for bg cleanup to be triggered because of size threshold
484    tokio::time::advance(Duration::from_secs(5)).await;
485    tokio::task::yield_now().await;
486    assert_eq!(mutex_table.size(), 0);
487    for entry in mutex_table.lock_table.iter() {
488        let locked = entry.read();
489        assert!(locked.is_empty());
490    }
491}