iota_storage/
sharded_lru.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    collections::{HashMap, hash_map::RandomState},
7    fmt::Debug,
8    future::Future,
9    hash::{BuildHasher, Hash},
10    num::NonZeroUsize,
11};
12
13use lru::LruCache;
14use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
15
16pub struct ShardedLruCache<K, V, S = RandomState> {
17    shards: Vec<RwLock<LruCache<K, V>>>,
18    hasher: S,
19}
20
21unsafe impl<K, V, S> Send for ShardedLruCache<K, V, S> {}
22unsafe impl<K, V, S> Sync for ShardedLruCache<K, V, S> {}
23
24impl<K, V> ShardedLruCache<K, V, RandomState>
25where
26    K: Send + Sync + Hash + Eq + Clone,
27    V: Send + Sync + Clone,
28{
29    pub fn new(capacity: u64, num_shards: u64) -> Self {
30        let cap_per_shard = capacity.div_ceil(num_shards);
31        let hasher = RandomState::default();
32        Self {
33            hasher,
34            shards: (0..num_shards)
35                .map(|_| {
36                    RwLock::new(LruCache::new(
37                        NonZeroUsize::new(cap_per_shard as usize).unwrap(),
38                    ))
39                })
40                .collect(),
41        }
42    }
43}
44
45impl<K, V, S> ShardedLruCache<K, V, S>
46where
47    K: Hash + Eq + Clone + Debug,
48    V: Clone,
49    S: BuildHasher,
50{
51    fn shard_id(&self, key: &K) -> usize {
52        let h = self.hasher.hash_one(key) as usize;
53        h % self.shards.len()
54    }
55
56    async fn read_shard(&self, key: &K) -> RwLockReadGuard<'_, LruCache<K, V>> {
57        let shard_idx = self.shard_id(key);
58        self.shards[shard_idx].read().await
59    }
60
61    async fn write_shard(&self, key: &K) -> RwLockWriteGuard<'_, LruCache<K, V>> {
62        let shard_idx = self.shard_id(key);
63        self.shards[shard_idx].write().await
64    }
65
66    pub async fn invalidate(&self, key: &K) -> Option<V> {
67        self.write_shard(key).await.pop(key)
68    }
69
70    pub async fn batch_invalidate(&self, keys: impl IntoIterator<Item = K>) {
71        let mut grouped = HashMap::new();
72        for key in keys.into_iter() {
73            let shard_idx = self.shard_id(&key);
74            grouped.entry(shard_idx).or_insert(vec![]).push(key);
75        }
76        for (shard_idx, keys) in grouped.into_iter() {
77            let mut lock = self.shards[shard_idx].write().await;
78            for key in keys {
79                lock.pop(&key);
80            }
81        }
82    }
83
84    pub async fn merge(&self, key: K, value: &V, f: fn(&V, &V) -> V) {
85        let mut shard = self.write_shard(&key).await;
86        let old_value = shard.get(&key);
87        if let Some(old_value) = old_value {
88            let new_value = f(old_value, value);
89            shard.put(key, new_value);
90        }
91    }
92
93    pub async fn batch_merge(
94        &self,
95        key_values: impl IntoIterator<Item = (K, V)>,
96        f: fn(&V, &V) -> V,
97    ) {
98        let mut grouped = HashMap::new();
99        for (key, value) in key_values.into_iter() {
100            let shard_idx = self.shard_id(&key);
101            grouped
102                .entry(shard_idx)
103                .or_insert(vec![])
104                .push((key, value));
105        }
106        for (shard_idx, keys) in grouped.into_iter() {
107            let mut shard = self.shards[shard_idx].write().await;
108            for (key, value) in keys.into_iter() {
109                let old_value = shard.get(&key);
110                if let Some(old_value) = old_value {
111                    let new_value = f(old_value, &value);
112                    shard.put(key, new_value);
113                }
114            }
115        }
116    }
117
118    pub async fn get(&self, key: &K) -> Option<V> {
119        self.read_shard(key).await.peek(key).cloned()
120    }
121
122    pub async fn get_with(&self, key: K, init: impl Future<Output = V>) -> V {
123        let shard = self.read_shard(&key).await;
124        let value = shard.peek(&key);
125        if let Some(value) = value {
126            return value.clone();
127        }
128        drop(shard);
129        let mut shard = self.write_shard(&key).await;
130        let value = shard.get(&key);
131        if let Some(value) = value {
132            return value.clone();
133        }
134        let value = init.await;
135        let cloned_value = value.clone();
136        shard.push(key, value);
137        cloned_value
138    }
139}