iota_storage/sharded_lru.rs

use std::{
    collections::{HashMap, hash_map::RandomState},
    fmt::Debug,
    hash::{BuildHasher, Hash},
    num::NonZeroUsize,
};

use lru::LruCache;
use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};

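/// An LRU cache split into independently locked shards: each key is hashed to
/// one shard, so an operation only contends on that shard's `RwLock` rather
/// than on a single global lock. Lookups return cloned values.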
pub struct ShardedLruCache<K, V, S = RandomState> {
    shards: Vec<RwLock<LruCache<K, V>>>,
    hasher: S,
}

// SAFETY: every field is itself `Send`/`Sync` whenever `K`, `V`, and `S`
// satisfy these bounds, so the cache can be sent to and shared across threads.
unsafe impl<K: Send, V: Send, S: Send> Send for ShardedLruCache<K, V, S> {}
unsafe impl<K: Send + Sync, V: Send + Sync, S: Sync> Sync for ShardedLruCache<K, V, S> {}

impl<K, V> ShardedLruCache<K, V, RandomState>
where
    K: Send + Sync + Hash + Eq + Clone,
    V: Send + Sync + Clone,
{
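    /// Creates a cache with `num_shards` shards, each holding up to
    /// `capacity.div_ceil(num_shards)` entries, so the total capacity is
    /// rounded up to a multiple of `num_shards`.
    ///
    /// Panics if `num_shards` is zero or if `capacity` is zero.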
    pub fn new(capacity: u64, num_shards: u64) -> Self {
        let cap_per_shard = capacity.div_ceil(num_shards);
        let hasher = RandomState::default();
        Self {
            hasher,
            shards: (0..num_shards)
                .map(|_| {
                    RwLock::new(LruCache::new(
                        NonZeroUsize::new(cap_per_shard as usize).unwrap(),
                    ))
                })
                .collect(),
        }
    }
}

impl<K, V, S> ShardedLruCache<K, V, S>
where
    K: Hash + Eq + Clone + Debug,
    V: Clone,
    S: BuildHasher,
{
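    /// Maps a key to the index of the shard responsible for it.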
    fn shard_id(&self, key: &K) -> usize {
        let h = self.hasher.hash_one(key) as usize;
        h % self.shards.len()
    }

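    /// Acquires the read lock of the shard that owns `key`.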
    fn read_shard(&self, key: &K) -> RwLockReadGuard<'_, LruCache<K, V>> {
        let shard_idx = self.shard_id(key);
        self.shards[shard_idx].read()
    }

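    /// Acquires the write lock of the shard that owns `key`.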
    fn write_shard(&self, key: &K) -> RwLockWriteGuard<'_, LruCache<K, V>> {
        let shard_idx = self.shard_id(key);
        self.shards[shard_idx].write()
    }

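    /// Removes the entry for `key`, returning its value if it was cached.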
    pub fn invalidate(&self, key: &K) -> Option<V> {
        self.write_shard(key).pop(key)
    }

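    /// Removes every given key, grouping keys by shard first so that each
    /// shard's write lock is taken only once.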
    pub fn batch_invalidate(&self, keys: impl IntoIterator<Item = K>) {
        let mut grouped = HashMap::new();
        for key in keys.into_iter() {
            let shard_idx = self.shard_id(&key);
            grouped.entry(shard_idx).or_insert(vec![]).push(key);
        }
        for (shard_idx, keys) in grouped.into_iter() {
            let mut lock = self.shards[shard_idx].write();
            for key in keys {
                lock.pop(&key);
            }
        }
    }

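    /// Replaces the cached value for `key` with `f(old_value, value)`. If the
    /// key is not currently cached, the cache is left unchanged.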
    pub fn merge(&self, key: K, value: &V, f: fn(&V, &V) -> V) {
        let mut shard = self.write_shard(&key);
        let old_value = shard.get(&key);
        if let Some(old_value) = old_value {
            let new_value = f(old_value, value);
            shard.put(key, new_value);
        }
    }

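    /// Applies `merge` to every `(key, value)` pair, grouping the pairs by
    /// shard first so that each shard's write lock is taken only once.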
    pub fn batch_merge(&self, key_values: impl IntoIterator<Item = (K, V)>, f: fn(&V, &V) -> V) {
        let mut grouped = HashMap::new();
        for (key, value) in key_values.into_iter() {
            let shard_idx = self.shard_id(&key);
            grouped
                .entry(shard_idx)
                .or_insert(vec![])
                .push((key, value));
        }
        for (shard_idx, entries) in grouped.into_iter() {
            let mut shard = self.shards[shard_idx].write();
            for (key, value) in entries.into_iter() {
                let old_value = shard.get(&key);
                if let Some(old_value) = old_value {
                    let new_value = f(old_value, &value);
                    shard.put(key, new_value);
                }
            }
        }
    }

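    /// Returns a clone of the cached value, if any. Uses `peek`, so a hit does
    /// not refresh the entry's LRU position.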
    pub fn get(&self, key: &K) -> Option<V> {
        self.read_shard(key).peek(key).cloned()
    }

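    /// Returns the value cached for `key`, computing and inserting it with
    /// `init` on a miss. Checks under the shard's read lock first, then
    /// re-checks under the write lock so a value inserted by a racing thread
    /// is reused rather than overwritten. `init` runs while the write lock is
    /// held.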
    pub fn get_with(&self, key: K, init: impl FnOnce() -> V) -> V {
        // Fast path: look up under the read lock without promoting the entry.
        let shard = self.read_shard(&key);
        let value = shard.peek(&key);
        if let Some(value) = value {
            return value.clone();
        }
        drop(shard);
        // Slow path: re-check under the write lock in case another thread
        // inserted the value while the read lock was released.
        let mut shard = self.write_shard(&key);
        let value = shard.get(&key);
        if let Some(value) = value {
            return value.clone();
        }
        let value = init();
        let cloned_value = value.clone();
        shard.push(key, value);
        cloned_value
    }
}
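
// Illustrative usage sketch: a minimal test exercising the public API with
// `u64` keys and values. The scenario, test name, and `add` helper below are
// assumptions chosen for demonstration, not requirements of the cache itself.
#[cfg(test)]
mod tests {
    use super::*;

    // Hypothetical merge function: sums the cached value with the new one.
    fn add(old: &u64, new: &u64) -> u64 {
        old + new
    }

    #[test]
    fn basic_usage() {
        // 128 total entries spread over 4 shards (32 per shard).
        let cache: ShardedLruCache<u64, u64> = ShardedLruCache::new(128, 4);

        // A miss runs `init` and caches the result; a later hit reuses it.
        assert_eq!(cache.get_with(1, || 10), 10);
        assert_eq!(cache.get_with(1, || 99), 10);

        // `merge` only updates keys that are already cached.
        cache.merge(1, &5, add);
        assert_eq!(cache.get(&1), Some(15));
        cache.merge(2, &5, add);
        assert_eq!(cache.get(&2), None);

        // Invalidation removes the entry and returns the old value.
        assert_eq!(cache.invalidate(&1), Some(15));
        assert_eq!(cache.get(&1), None);
    }
}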