iota_storage/sharded_lru.rs

use std::{
    collections::{HashMap, hash_map::RandomState},
    fmt::Debug,
    future::Future,
    hash::{BuildHasher, Hash},
    num::NonZeroUsize,
};

use lru::LruCache;
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};

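/// An LRU cache split across multiple independently locked shards. Each key is
/// hashed to a fixed shard, so operations on keys in different shards can
/// proceed concurrently without contending on the same `RwLock`.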
pub struct ShardedLruCache<K, V, S = RandomState> {
    shards: Vec<RwLock<LruCache<K, V>>>,
    hasher: S,
}

// These impls assert `Send`/`Sync` unconditionally; the cache must only be
// used with key, value, and hasher types that are themselves thread-safe.
unsafe impl<K, V, S> Send for ShardedLruCache<K, V, S> {}
unsafe impl<K, V, S> Sync for ShardedLruCache<K, V, S> {}

impl<K, V> ShardedLruCache<K, V, RandomState>
where
    K: Send + Sync + Hash + Eq + Clone,
    V: Send + Sync + Clone,
{
    /// Creates a cache with `num_shards` shards whose capacities sum to at
    /// least `capacity`. Panics if `num_shards` or the resulting per-shard
    /// capacity is zero.
    pub fn new(capacity: u64, num_shards: u64) -> Self {
        let cap_per_shard = capacity.div_ceil(num_shards);
        let hasher = RandomState::default();
        Self {
            hasher,
            shards: (0..num_shards)
                .map(|_| {
                    RwLock::new(LruCache::new(
                        NonZeroUsize::new(cap_per_shard as usize).unwrap(),
                    ))
                })
                .collect(),
        }
    }
}

impl<K, V, S> ShardedLruCache<K, V, S>
where
    K: Hash + Eq + Clone + Debug,
    V: Clone,
    S: BuildHasher,
{
    /// Maps a key to the index of the shard responsible for it.
    fn shard_id(&self, key: &K) -> usize {
        let h = self.hasher.hash_one(key) as usize;
        h % self.shards.len()
    }

    async fn read_shard(&self, key: &K) -> RwLockReadGuard<'_, LruCache<K, V>> {
        let shard_idx = self.shard_id(key);
        self.shards[shard_idx].read().await
    }

    async fn write_shard(&self, key: &K) -> RwLockWriteGuard<'_, LruCache<K, V>> {
        let shard_idx = self.shard_id(key);
        self.shards[shard_idx].write().await
    }

    /// Removes `key` from its shard, returning the removed value if it was
    /// cached.
    pub async fn invalidate(&self, key: &K) -> Option<V> {
        self.write_shard(key).await.pop(key)
    }

    /// Removes many keys, grouping them by shard so that each shard's write
    /// lock is taken only once.
    pub async fn batch_invalidate(&self, keys: impl IntoIterator<Item = K>) {
        let mut grouped = HashMap::new();
        for key in keys.into_iter() {
            let shard_idx = self.shard_id(&key);
            grouped.entry(shard_idx).or_insert(vec![]).push(key);
        }
        for (shard_idx, keys) in grouped.into_iter() {
            let mut lock = self.shards[shard_idx].write().await;
            for key in keys {
                lock.pop(&key);
            }
        }
    }

    /// Combines the cached value for `key` with `value` using `f`. If the key
    /// is not cached, this is a no-op.
    pub async fn merge(&self, key: K, value: &V, f: fn(&V, &V) -> V) {
        let mut shard = self.write_shard(&key).await;
        let old_value = shard.get(&key);
        if let Some(old_value) = old_value {
            let new_value = f(old_value, value);
            shard.put(key, new_value);
        }
    }

    /// Applies the same merge to many key/value pairs, grouped by shard so
    /// that each shard's write lock is taken only once. Pairs whose key is
    /// not cached are skipped.
    pub async fn batch_merge(
        &self,
        key_values: impl IntoIterator<Item = (K, V)>,
        f: fn(&V, &V) -> V,
    ) {
        let mut grouped = HashMap::new();
        for (key, value) in key_values.into_iter() {
            let shard_idx = self.shard_id(&key);
            grouped
                .entry(shard_idx)
                .or_insert(vec![])
                .push((key, value));
        }
        for (shard_idx, keys) in grouped.into_iter() {
            let mut shard = self.shards[shard_idx].write().await;
            for (key, value) in keys.into_iter() {
                let old_value = shard.get(&key);
                if let Some(old_value) = old_value {
                    let new_value = f(old_value, &value);
                    shard.put(key, new_value);
                }
            }
        }
    }

    /// Returns a clone of the cached value without updating its LRU position
    /// (the lookup uses `peek`, not `get`).
    pub async fn get(&self, key: &K) -> Option<V> {
        self.read_shard(key).await.peek(key).cloned()
    }

    /// Returns the cached value for `key`, or resolves `init`, caches its
    /// output, and returns it.
    pub async fn get_with(&self, key: K, init: impl Future<Output = V>) -> V {
        // Fast path: look up under the read lock.
        let shard = self.read_shard(&key).await;
        let value = shard.peek(&key);
        if let Some(value) = value {
            return value.clone();
        }
        drop(shard);
        // Slow path: re-check under the write lock, since another task may
        // have inserted the value after the read lock was released.
        let mut shard = self.write_shard(&key).await;
        let value = shard.get(&key);
        if let Some(value) = value {
            return value.clone();
        }
        // `init` is awaited while the shard's write lock is held, blocking all
        // other access to this shard until it completes.
        let value = init.await;
        let cloned_value = value.clone();
        shard.push(key, value);
        cloned_value
    }
}
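
// A minimal usage sketch: a test exercising `get_with`, `get`, `merge`, and
// `invalidate`. It assumes a dev-dependency on tokio with the `macros` and
// `rt` features enabled for the `#[tokio::test]` attribute.
#[cfg(test)]
mod tests {
    use super::ShardedLruCache;

    #[tokio::test]
    async fn basic_usage() {
        // 128 entries of total capacity spread over 8 shards (16 per shard).
        let cache: ShardedLruCache<u64, String> = ShardedLruCache::new(128, 8);

        // Miss: the init future runs and its output is cached.
        let v = cache.get_with(1, async { "one".to_string() }).await;
        assert_eq!(v, "one");

        // Hit: the cached value is returned; the new init future never runs.
        let v = cache.get_with(1, async { "unused".to_string() }).await;
        assert_eq!(v, "one");

        // `merge` combines the cached value with the provided one.
        cache
            .merge(1, &"!".to_string(), |old, new| format!("{old}{new}"))
            .await;
        assert_eq!(cache.get(&1).await, Some("one!".to_string()));

        // `invalidate` removes the entry and returns it.
        assert_eq!(cache.invalidate(&1).await, Some("one!".to_string()));
        assert_eq!(cache.get(&1).await, None);
    }
}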