iota_storage/
sharded_lru.rs

// Copyright (c) Mysten Labs, Inc.
// Modifications Copyright (c) 2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    collections::{HashMap, hash_map::RandomState},
7    fmt::Debug,
8    hash::{BuildHasher, Hash},
9    num::NonZeroUsize,
10};
11
12use lru::LruCache;
13use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
14
15pub struct ShardedLruCache<K, V, S = RandomState> {
16    shards: Vec<RwLock<LruCache<K, V>>>,
17    hasher: S,
18}
19
20unsafe impl<K, V, S> Send for ShardedLruCache<K, V, S> {}
21unsafe impl<K, V, S> Sync for ShardedLruCache<K, V, S> {}
22
23impl<K, V> ShardedLruCache<K, V, RandomState>
24where
25    K: Send + Sync + Hash + Eq + Clone,
26    V: Send + Sync + Clone,
27{
28    pub fn new(capacity: u64, num_shards: u64) -> Self {
29        let cap_per_shard = capacity.div_ceil(num_shards);
30        let hasher = RandomState::default();
31        Self {
32            hasher,
33            shards: (0..num_shards)
34                .map(|_| {
35                    RwLock::new(LruCache::new(
36                        NonZeroUsize::new(cap_per_shard as usize).unwrap(),
37                    ))
38                })
39                .collect(),
40        }
41    }
42}
43
44impl<K, V, S> ShardedLruCache<K, V, S>
45where
46    K: Hash + Eq + Clone + Debug,
47    V: Clone,
48    S: BuildHasher,
49{
50    fn shard_id(&self, key: &K) -> usize {
51        let h = self.hasher.hash_one(key) as usize;
52        h % self.shards.len()
53    }
54
55    fn read_shard(&self, key: &K) -> RwLockReadGuard<'_, LruCache<K, V>> {
56        let shard_idx = self.shard_id(key);
57        self.shards[shard_idx].read()
58    }
59
60    fn write_shard(&self, key: &K) -> RwLockWriteGuard<'_, LruCache<K, V>> {
61        let shard_idx = self.shard_id(key);
62        self.shards[shard_idx].write()
63    }
64
65    pub fn invalidate(&self, key: &K) -> Option<V> {
66        self.write_shard(key).pop(key)
67    }
68
69    pub fn batch_invalidate(&self, keys: impl IntoIterator<Item = K>) {
70        let mut grouped = HashMap::new();
71        for key in keys.into_iter() {
72            let shard_idx = self.shard_id(&key);
73            grouped.entry(shard_idx).or_insert(vec![]).push(key);
74        }
75        for (shard_idx, keys) in grouped.into_iter() {
76            let mut lock = self.shards[shard_idx].write();
77            for key in keys {
78                lock.pop(&key);
79            }
80        }
81    }
82
83    pub fn merge(&self, key: K, value: &V, f: fn(&V, &V) -> V) {
84        let mut shard = self.write_shard(&key);
85        let old_value = shard.get(&key);
86        if let Some(old_value) = old_value {
87            let new_value = f(old_value, value);
88            shard.put(key, new_value);
89        }
90    }
91
92    pub fn batch_merge(&self, key_values: impl IntoIterator<Item = (K, V)>, f: fn(&V, &V) -> V) {
93        let mut grouped = HashMap::new();
94        for (key, value) in key_values.into_iter() {
95            let shard_idx = self.shard_id(&key);
96            grouped
97                .entry(shard_idx)
98                .or_insert(vec![])
99                .push((key, value));
100        }
101        for (shard_idx, keys) in grouped.into_iter() {
102            let mut shard = self.shards[shard_idx].write();
103            for (key, value) in keys.into_iter() {
104                let old_value = shard.get(&key);
105                if let Some(old_value) = old_value {
106                    let new_value = f(old_value, &value);
107                    shard.put(key, new_value);
108                }
109            }
110        }
111    }
112
113    pub fn get(&self, key: &K) -> Option<V> {
114        self.read_shard(key).peek(key).cloned()
115    }
116
117    pub fn get_with(&self, key: K, init: impl FnOnce() -> V) -> V {
118        let shard = self.read_shard(&key);
119        let value = shard.peek(&key);
120        if let Some(value) = value {
121            return value.clone();
122        }
123        drop(shard);
124        let mut shard = self.write_shard(&key);
125        let value = shard.get(&key);
126        if let Some(value) = value {
127            return value.clone();
128        }
129        let value = init();
130        let cloned_value = value.clone();
131        shard.push(key, value);
132        cloned_value
133    }
134}