iota_storage/mutex_table.rs

// Copyright (c) Mysten Labs, Inc.
// Modifications Copyright (c) 2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

use std::{
    collections::{
        HashMap,
        hash_map::{DefaultHasher, RandomState},
    },
    error::Error,
    fmt,
    hash::{BuildHasher, Hash, Hasher},
    sync::{
        Arc,
        atomic::{AtomicBool, AtomicUsize, Ordering},
    },
    time::Duration,
};

use async_trait::async_trait;
use iota_metrics::spawn_monitored_task;
use tokio::{
    sync::{
        Mutex, OwnedMutexGuard, OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock, TryLockError,
    },
    task::JoinHandle,
    time::Instant,
};
use tracing::info;

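/// Abstraction over an owned asynchronous lock, so that `LockTable` can be
/// instantiated with either `tokio::sync::Mutex<()>` (exclusive access only)
/// or `tokio::sync::RwLock<()>` (which also offers shared read access).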
#[async_trait]
pub trait Lock: Send + Sync + Default {
    type Guard;
    type ReadGuard;
    async fn lock_owned(self: Arc<Self>) -> Self::Guard;
    fn try_lock_owned(self: Arc<Self>) -> Result<Self::Guard, TryLockError>;
    async fn read_lock_owned(self: Arc<Self>) -> Self::ReadGuard;
}

#[async_trait::async_trait]
impl Lock for Mutex<()> {
    type Guard = OwnedMutexGuard<()>;
    type ReadGuard = Self::Guard;

    async fn lock_owned(self: Arc<Self>) -> Self::Guard {
        self.lock_owned().await
    }

    fn try_lock_owned(self: Arc<Self>) -> Result<Self::Guard, TryLockError> {
        self.try_lock_owned()
    }

    async fn read_lock_owned(self: Arc<Self>) -> Self::ReadGuard {
        self.lock_owned().await
    }
}

#[async_trait::async_trait]
impl Lock for RwLock<()> {
    type Guard = OwnedRwLockWriteGuard<()>;
    type ReadGuard = OwnedRwLockReadGuard<()>;

    async fn lock_owned(self: Arc<Self>) -> Self::Guard {
        self.write_owned().await
    }

    fn try_lock_owned(self: Arc<Self>) -> Result<Self::Guard, TryLockError> {
        self.try_write_owned()
    }

    async fn read_lock_owned(self: Arc<Self>) -> Self::ReadGuard {
        self.read_owned().await
    }
}

type InnerLockTable<K, L> = HashMap<K, Arc<L>>;
// MutexTable supports mutual exclusion on keys such as TransactionDigest or
// ObjectDigest
pub struct LockTable<K: Hash, L: Lock> {
    random_state: RandomState,
    lock_table: Arc<Vec<RwLock<InnerLockTable<K, L>>>>,
    _k: std::marker::PhantomData<K>,
    _cleaner: JoinHandle<()>,
    stop: Arc<AtomicBool>,
    size: Arc<AtomicUsize>,
}

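/// A `LockTable` backed by `tokio::sync::Mutex`, i.e. exclusive locking only.
///
/// A minimal usage sketch (the `String` key type and the shard count below are
/// illustrative, not requirements of this module):
///
/// ```ignore
/// let table = MutexTable::<String>::new(128);
/// let _guard = table.acquire_lock("some-digest".to_string()).await;
/// // "some-digest" stays locked until `_guard` is dropped
/// ```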
pub type MutexTable<K> = LockTable<K, Mutex<()>>;
pub type RwLockTable<K> = LockTable<K, RwLock<()>>;

#[derive(Debug)]
pub enum TryAcquireLockError {
    LockTableLocked,
    LockEntryLocked,
}

impl fmt::Display for TryAcquireLockError {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(fmt, "operation would block")
    }
}

impl Error for TryAcquireLockError {}
pub type MutexGuard = tokio::sync::OwnedMutexGuard<()>;
pub type RwLockGuard = tokio::sync::OwnedRwLockReadGuard<()>;

impl<K: Hash + Eq + Send + Sync + 'static, L: Lock + 'static> LockTable<K, L> {
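    /// Creates a table with `num_shards` shards (fixed to 4 under `msim`) and
    /// spawns a background task that removes entries whose locks are no longer
    /// held. Cleanup runs once `cleanup_entries_threshold` entries have
    /// accumulated or `cleanup_period` has elapsed, whichever comes first,
    /// after an initial delay of `cleanup_initial_delay`.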
    pub fn new_with_cleanup(
        num_shards: usize,
        cleanup_period: Duration,
        cleanup_initial_delay: Duration,
        cleanup_entries_threshold: usize,
    ) -> Self {
        let num_shards = if cfg!(msim) { 4 } else { num_shards };

        let lock_table: Arc<Vec<RwLock<InnerLockTable<K, L>>>> = Arc::new(
            (0..num_shards)
                .map(|_| RwLock::new(HashMap::new()))
                .collect(),
        );
        let cloned = lock_table.clone();
        let stop = Arc::new(AtomicBool::new(false));
        let stop_cloned = stop.clone();
        let size: Arc<AtomicUsize> = Arc::new(AtomicUsize::new(0));
        let size_cloned = size.clone();
        Self {
            random_state: RandomState::new(),
            lock_table,
            _k: std::marker::PhantomData {},
            _cleaner: spawn_monitored_task!(async move {
                tokio::time::sleep(cleanup_initial_delay).await;
                let mut previous_cleanup_instant = Instant::now();
                while !stop_cloned.load(Ordering::SeqCst) {
                    if size_cloned.load(Ordering::SeqCst) >= cleanup_entries_threshold
                        || previous_cleanup_instant.elapsed() >= cleanup_period
                    {
                        let num_removed = Self::cleanup(cloned.clone());
                        size_cloned.fetch_sub(num_removed, Ordering::SeqCst);
                        previous_cleanup_instant = Instant::now();
                    }
                    tokio::time::sleep(Duration::from_secs(1)).await;
                }
                info!("Stopping mutex table cleanup!");
            }),
            stop,
            size,
        }
    }

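    /// Convenience constructor using the default cleanup settings: a 10 second
    /// period, a 10 second initial delay and a 10_000 entry threshold.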
    pub fn new(num_shards: usize) -> Self {
        Self::new_with_cleanup(
            num_shards,
            Duration::from_secs(10),
            Duration::from_secs(10),
            10_000,
        )
    }

    pub fn size(&self) -> usize {
        self.size.load(Ordering::SeqCst)
    }

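    /// Sweeps every shard and drops entries whose lock is no longer referenced
    /// outside the table (i.e. `Arc::strong_count == 1`). Shards that cannot be
    /// write-locked immediately are skipped and picked up on a later pass.
    /// Returns the number of entries removed.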
    pub fn cleanup(lock_table: Arc<Vec<RwLock<InnerLockTable<K, L>>>>) -> usize {
        let mut num_removed: usize = 0;
        for shard in lock_table.iter() {
            let map = shard.try_write();
            if map.is_err() {
                continue;
            }
            map.unwrap().retain(|_k, v| {
                // MutexMap::(try_|)acquire_locks will lock the map and call Arc::clone on
                // the entry. This check ensures that we only drop an entry from the map if
                // this is the only remaining copy of its lock, which alone is sufficient
                // to guarantee that nobody holds or can still acquire it.
                if Arc::strong_count(v) == 1 {
                    num_removed += 1;
                    false
                } else {
                    true
                }
            });
        }
        num_removed
    }

    fn get_lock_idx(&self, key: &K) -> usize {
        let mut hasher = if !cfg!(test) {
            self.random_state.build_hasher()
        } else {
            // be deterministic for tests
            DefaultHasher::new()
        };

        key.hash(&mut hasher);
        // unwrap ok - converting u64 -> usize
        let hash: usize = hasher.finish().try_into().unwrap();
        hash % self.lock_table.len()
    }

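    /// Acquires write locks for all the given keys. Keys are sorted and
    /// deduplicated first, so a duplicated key cannot deadlock against itself
    /// and concurrent callers acquire overlapping key sets in a consistent
    /// order.
    ///
    /// A minimal sketch (assuming some `table: MutexTable<String>`; the key
    /// values are illustrative):
    ///
    /// ```ignore
    /// let keys = vec!["a".to_string(), "b".to_string()];
    /// let guards = table.acquire_locks(keys.into_iter()).await;
    /// // both keys stay locked until `guards` is dropped
    /// ```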
    pub async fn acquire_locks<I>(&self, object_iter: I) -> Vec<L::Guard>
    where
        I: Iterator<Item = K>,
        K: Ord,
    {
        let mut objects: Vec<K> = object_iter.into_iter().collect();
        objects.sort_unstable();
        objects.dedup();

        let mut guards = Vec::with_capacity(objects.len());
        for object in objects.into_iter() {
            guards.push(self.acquire_lock(object).await);
        }
        guards
    }

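    /// Acquires read locks for all the given keys (sorted and deduplicated, as
    /// in `acquire_locks`). Note that for `MutexTable` the "read" guard is
    /// still an exclusive mutex guard; only `RwLockTable` allows concurrent
    /// readers.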
    pub async fn acquire_read_locks(&self, mut objects: Vec<K>) -> Vec<L::ReadGuard>
    where
        K: Ord,
    {
        objects.sort_unstable();
        objects.dedup();
        let mut guards = Vec::with_capacity(objects.len());
        for object in objects.into_iter() {
            guards.push(self.get_lock(object).await.read_lock_owned().await);
        }
        guards
    }

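    /// Returns the lock entry for `k`, inserting it if absent. The shard is
    /// first probed under a read lock; the write lock is only taken on a miss,
    /// so looking up an existing entry does not serialize readers of the same
    /// shard.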
    pub async fn get_lock(&self, k: K) -> Arc<L> {
        let lock_idx = self.get_lock_idx(&k);
        let element = {
            let map = self.lock_table[lock_idx].read().await;
            map.get(&k).cloned()
        };
        if let Some(element) = element {
            element
        } else {
            // element doesn't exist
            let element = {
                let mut map = self.lock_table[lock_idx].write().await;
                map.entry(k)
                    .or_insert_with(|| {
                        self.size.fetch_add(1, Ordering::SeqCst);
                        Arc::new(L::default())
                    })
                    .clone()
            };
            element
        }
    }

    pub async fn acquire_lock(&self, k: K) -> L::Guard {
        self.get_lock(k).await.lock_owned().await
    }

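    /// Non-blocking variant of `acquire_lock`. Fails with
    /// `TryAcquireLockError::LockTableLocked` if the shard itself cannot be
    /// locked without waiting, or with `TryAcquireLockError::LockEntryLocked`
    /// if the per-key lock is already held.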
    pub fn try_acquire_lock(&self, k: K) -> Result<L::Guard, TryAcquireLockError> {
        let lock_idx = self.get_lock_idx(&k);
        let element = {
            let map = self.lock_table[lock_idx]
                .try_read()
                .map_err(|_| TryAcquireLockError::LockTableLocked)?;
            map.get(&k).cloned()
        };
        if let Some(element) = element {
            let lock = element.try_lock_owned();
            Ok(lock.map_err(|_| TryAcquireLockError::LockEntryLocked)?)
        } else {
            // element doesn't exist
            let element = {
                let mut map = self.lock_table[lock_idx]
                    .try_write()
                    .map_err(|_| TryAcquireLockError::LockTableLocked)?;
                map.entry(k)
                    .or_insert_with(|| {
                        self.size.fetch_add(1, Ordering::SeqCst);
                        Arc::new(L::default())
                    })
                    .clone()
            };
            let lock = element.try_lock_owned();
            lock.map_err(|_| TryAcquireLockError::LockEntryLocked)
        }
    }
}

impl<K: Hash, L: Lock> Drop for LockTable<K, L> {
    fn drop(&mut self) {
        self.stop.store(true, Ordering::SeqCst);
    }
}

#[tokio::test]
// Tests that the mutex table provides parallelism at the individual mutex
// level, i.e. that waiting on one entry's lock does not block other entries
// in the same bucket.
async fn test_mutex_table_concurrent_in_same_bucket() {
    use tokio::time::{sleep, timeout};
    let mutex_table = Arc::new(MutexTable::<String>::new(1));
    let john = mutex_table.try_acquire_lock("john".to_string());
    john.unwrap();
    {
        let mutex_table = mutex_table.clone();
        spawn_monitored_task!(async move {
            mutex_table.acquire_lock("john".to_string()).await;
        });
    }
    sleep(Duration::from_millis(50)).await;
    let jane = mutex_table.try_acquire_lock("jane".to_string());
    jane.unwrap();

    let mutex_table = Arc::new(MutexTable::<String>::new(1));
    let _john = mutex_table.acquire_lock("john".to_string()).await;
    {
        let mutex_table = mutex_table.clone();
        spawn_monitored_task!(async move {
            mutex_table.acquire_lock("john".to_string()).await;
        });
    }
    sleep(Duration::from_millis(50)).await;
    let jane = timeout(
        Duration::from_secs(1),
        mutex_table.acquire_lock("jane".to_string()),
    )
    .await;
    jane.unwrap();
}

#[tokio::test]
async fn test_mutex_table() {
    // Disable bg cleanup by passing Duration::MAX as the initial delay
    let mutex_table =
        MutexTable::<String>::new_with_cleanup(1, Duration::from_secs(10), Duration::MAX, 1000);
    let john1 = mutex_table.try_acquire_lock("john".to_string());
    assert!(john1.is_ok());
    let john2 = mutex_table.try_acquire_lock("john".to_string());
    assert!(john2.is_err());
    drop(john1);
    let john2 = mutex_table.try_acquire_lock("john".to_string());
    assert!(john2.is_ok());
    let jane = mutex_table.try_acquire_lock("jane".to_string());
    assert!(jane.is_ok());
    MutexTable::cleanup(mutex_table.lock_table.clone());
    let map = mutex_table.lock_table.first().as_ref().unwrap().try_read();
    assert!(map.is_ok());
    assert_eq!(map.unwrap().len(), 2);
    drop(john2);
    MutexTable::cleanup(mutex_table.lock_table.clone());
    let map = mutex_table.lock_table.first().as_ref().unwrap().try_read();
    assert!(map.is_ok());
    assert_eq!(map.unwrap().len(), 1);
    drop(jane);
    MutexTable::cleanup(mutex_table.lock_table.clone());
    let map = mutex_table.lock_table.first().as_ref().unwrap().try_read();
    assert!(map.is_ok());
    assert!(map.unwrap().is_empty());
}

#[tokio::test]
async fn test_acquire_locks() {
    let mutex_table =
        RwLockTable::<String>::new_with_cleanup(1, Duration::from_secs(10), Duration::MAX, 1000);
    let object_1 = "object 1".to_string();
    let object_2 = "object 2".to_string();
    let object_3 = "object 3".to_string();

    // ensure that locks can be acquired even when the input contains duplicates
    let objects = vec![
        object_1.clone(),
        object_2.clone(),
        object_2,
        object_1.clone(),
        object_3,
        object_1,
    ];

    let locks = mutex_table.acquire_locks(objects.clone().into_iter()).await;
    assert_eq!(locks.len(), 3);

    for object in objects.clone() {
        assert!(mutex_table.try_acquire_lock(object).is_err());
    }

    drop(locks);
    let locks = mutex_table.acquire_locks(objects.into_iter()).await;
    assert_eq!(locks.len(), 3);
}

#[tokio::test]
async fn test_read_locks() {
    let mutex_table =
        RwLockTable::<String>::new_with_cleanup(1, Duration::from_secs(10), Duration::MAX, 1000);
    let lock = "lock".to_string();
    let locks1 = mutex_table.acquire_read_locks(vec![lock.clone()]).await;
    assert!(mutex_table.try_acquire_lock(lock.clone()).is_err());
    let locks2 = mutex_table.acquire_read_locks(vec![lock.clone()]).await;
    drop(locks1);
    drop(locks2);
    assert!(mutex_table.try_acquire_lock(lock.clone()).is_ok());
}

#[tokio::test(flavor = "current_thread", start_paused = true)]
async fn test_mutex_table_bg_cleanup() {
    let mutex_table = MutexTable::<String>::new_with_cleanup(
        1,
        Duration::from_secs(5),
        Duration::from_secs(1),
        1000,
    );
    let lock1 = mutex_table.try_acquire_lock("lock1".to_string());
    let lock2 = mutex_table.try_acquire_lock("lock2".to_string());
    let lock3 = mutex_table.try_acquire_lock("lock3".to_string());
    let lock4 = mutex_table.try_acquire_lock("lock4".to_string());
    let lock5 = mutex_table.try_acquire_lock("lock5".to_string());
    assert!(lock1.is_ok());
    assert!(lock2.is_ok());
    assert!(lock3.is_ok());
    assert!(lock4.is_ok());
    assert!(lock5.is_ok());
    // Trigger cleanup
    MutexTable::cleanup(mutex_table.lock_table.clone());
    // Try acquiring the locks again; these should still fail because the
    // original locks have not been released
    let lock11 = mutex_table.try_acquire_lock("lock1".to_string());
    let lock22 = mutex_table.try_acquire_lock("lock2".to_string());
    let lock33 = mutex_table.try_acquire_lock("lock3".to_string());
    let lock44 = mutex_table.try_acquire_lock("lock4".to_string());
    let lock55 = mutex_table.try_acquire_lock("lock5".to_string());
    assert!(lock11.is_err());
    assert!(lock22.is_err());
    assert!(lock33.is_err());
    assert!(lock44.is_err());
    assert!(lock55.is_err());
    // drop all locks
    drop(lock1);
    drop(lock2);
    drop(lock3);
    drop(lock4);
    drop(lock5);
    // Wait for bg cleanup to be triggered
    tokio::time::sleep(Duration::from_secs(10)).await;
    for entry in mutex_table.lock_table.iter() {
        let locked = entry.read().await;
        assert!(locked.is_empty());
    }
}

#[tokio::test(flavor = "current_thread", start_paused = true)]
async fn test_mutex_table_bg_cleanup_with_size_threshold() {
    // Set up the table so that cleanup is never triggered by the time period,
    // only by the size threshold
    let mutex_table =
        MutexTable::<String>::new_with_cleanup(1, Duration::MAX, Duration::from_secs(1), 5);
    let lock1 = mutex_table.try_acquire_lock("lock1".to_string());
    let lock2 = mutex_table.try_acquire_lock("lock2".to_string());
    let lock3 = mutex_table.try_acquire_lock("lock3".to_string());
    let lock4 = mutex_table.try_acquire_lock("lock4".to_string());
    let lock5 = mutex_table.try_acquire_lock("lock5".to_string());
    assert!(lock1.is_ok());
    assert!(lock2.is_ok());
    assert!(lock3.is_ok());
    assert!(lock4.is_ok());
    assert!(lock5.is_ok());
    // Trigger cleanup
    MutexTable::cleanup(mutex_table.lock_table.clone());
    // Try acquiring the locks again; these should still fail because the
    // original locks have not been released
    let lock11 = mutex_table.try_acquire_lock("lock1".to_string());
    let lock22 = mutex_table.try_acquire_lock("lock2".to_string());
    let lock33 = mutex_table.try_acquire_lock("lock3".to_string());
    let lock44 = mutex_table.try_acquire_lock("lock4".to_string());
    let lock55 = mutex_table.try_acquire_lock("lock5".to_string());
    assert!(lock11.is_err());
    assert!(lock22.is_err());
    assert!(lock33.is_err());
    assert!(lock44.is_err());
    assert!(lock55.is_err());
    assert_eq!(mutex_table.size(), 5);
    // drop all locks
    drop(lock1);
    drop(lock2);
    drop(lock3);
    drop(lock4);
    drop(lock5);
    tokio::task::yield_now().await;
    // Wait for bg cleanup to be triggered because of the size threshold
    tokio::time::advance(Duration::from_secs(5)).await;
    tokio::task::yield_now().await;
    assert_eq!(mutex_table.size(), 0);
    for entry in mutex_table.lock_table.iter() {
        let locked = entry.read().await;
        assert!(locked.is_empty());
    }
}