1#![allow(clippy::await_holding_lock)]
6
7use std::{
8 borrow::Borrow,
9 collections::{BTreeMap, HashMap, VecDeque, btree_map::Iter},
10 marker::PhantomData,
11 ops::RangeBounds,
12 sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard},
13};
14
15use bincode::Options;
16use collectable::TryExtend;
17use ouroboros::self_referencing;
18use rand::distributions::{Alphanumeric, DistString};
19use rocksdb::Direction;
20use serde::{Serialize, de::DeserializeOwned};
21
22use crate::{
23 DbIterator, Map, TypedStoreError, be_fix_int_ser, rocks::errors::typed_store_err_from_bcs_err,
24};
25
/// An in-memory, lock-protected ordered key/value store used as a test
/// stand-in for a real RocksDB-backed table. Keys are stored in their
/// big-endian fixint (`be_fix_int_ser`) encoding and values as BCS bytes.
#[derive(Clone, Debug)]
pub struct TestDB<K, V> {
    /// Ordered table of encoded keys to encoded values; the `Arc` lets
    /// clones of this `TestDB` share the same underlying rows.
    pub rows: Arc<RwLock<BTreeMap<Vec<u8>, Vec<u8>>>>,
    /// Random identifier for this store; used to dedup/order lock
    /// acquisition in write batches.
    pub name: String,
    // Function-pointer marker ties K and V to the type without implying
    // ownership of (or auto-trait requirements on) values of either.
    _phantom: PhantomData<fn(K) -> V>,
}
34
35impl<K, V> TestDB<K, V> {
36 pub fn open() -> Self {
37 TestDB {
38 rows: Arc::new(RwLock::new(BTreeMap::new())),
39 name: Alphanumeric.sample_string(&mut rand::thread_rng(), 16),
40 _phantom: PhantomData,
41 }
42 }
43 pub fn batch(&self) -> TestDBWriteBatch {
44 TestDBWriteBatch::default()
45 }
46}
47
/// Self-referential iterator over a `TestDB`: owns the read guard on the
/// rows map together with a `BTreeMap` iterator borrowing from it, so the
/// lock is held for the iterator's lifetime.
#[self_referencing(pub_extras)]
pub struct TestDBIter<'a, K, V> {
    /// Read guard keeping the underlying map alive (and read-locked)
    /// while iterating.
    pub rows: RwLockReadGuard<'a, BTreeMap<Vec<u8>, Vec<u8>>>,
    #[borrows(mut rows)]
    #[covariant]
    /// Iterator over the guarded map's raw (encoded) entries.
    pub iter: Iter<'this, Vec<u8>, Vec<u8>>,
    phantom: PhantomData<(K, V)>,
    /// Iteration direction; `next` panics on `Reverse`.
    pub direction: Direction,
}
57
/// Self-referential keys-only iterator over a `TestDB`; holds the read
/// guard for its whole lifetime, like `TestDBIter`.
#[self_referencing(pub_extras)]
pub struct TestDBKeys<'a, K> {
    rows: RwLockReadGuard<'a, BTreeMap<Vec<u8>, Vec<u8>>>,
    #[borrows(mut rows)]
    #[covariant]
    /// Iterator over raw entries; only the key half is decoded.
    pub iter: Iter<'this, Vec<u8>, Vec<u8>>,
    phantom: PhantomData<K>,
}
66
/// Self-referential values-only iterator over a `TestDB`; holds the read
/// guard for its whole lifetime, like `TestDBIter`.
#[self_referencing(pub_extras)]
pub struct TestDBValues<'a, V> {
    rows: RwLockReadGuard<'a, BTreeMap<Vec<u8>, Vec<u8>>>,
    #[borrows(mut rows)]
    #[covariant]
    /// Iterator over raw entries; only the value half is decoded.
    pub iter: Iter<'this, Vec<u8>, Vec<u8>>,
    phantom: PhantomData<V>,
}
75
76impl<K: DeserializeOwned, V: DeserializeOwned> Iterator for TestDBIter<'_, K, V> {
77 type Item = Result<(K, V), TypedStoreError>;
78
79 fn next(&mut self) -> Option<Self::Item> {
80 let mut out: Option<Self::Item> = None;
81 let config = bincode::DefaultOptions::new()
82 .with_big_endian()
83 .with_fixint_encoding();
84 self.with_mut(|fields| {
85 let resp = match fields.direction {
86 Direction::Forward => fields.iter.next(),
87 Direction::Reverse => panic!("Reverse iteration not supported in test db"),
88 };
89 if let Some((raw_key, raw_value)) = resp {
90 let key: K = config.deserialize(raw_key).ok().unwrap();
91 let value: V = bcs::from_bytes(raw_value).ok().unwrap();
92 out = Some(Ok((key, value)));
93 }
94 });
95 out
96 }
97}
98
impl<'a, K: Serialize, V> TestDBIter<'a, K, V> {
    /// Advances the iterator past every entry whose serialization
    /// compares below the serialized `key`.
    ///
    /// NOTE(review): `peeked.unwrap()` is the whole `(&key, &value)` entry,
    /// so `be_fix_int_ser` here serializes the key/value *pair*, not just
    /// the key, before comparing against `serialized_key` — confirm this
    /// comparison is intended.
    /// NOTE(review): the temporary `Peekable` buffers one peeked entry;
    /// when it is dropped at the end of the closure, that buffered entry
    /// appears to be lost from the underlying iterator — verify the first
    /// non-skipped element is not silently dropped.
    pub fn skip_to(mut self, key: &K) -> Result<Self, TypedStoreError> {
        self.with_mut(|fields| {
            let serialized_key = be_fix_int_ser(key);
            let mut peekable = fields.iter.peekable();
            let mut peeked = peekable.peek();
            while peeked.is_some() {
                let serialized = be_fix_int_ser(peeked.unwrap());
                if serialized >= serialized_key {
                    break;
                } else {
                    peekable.next();
                    peeked = peekable.peek();
                }
            }
        });
        Ok(self)
    }

    /// Like `skip_to`, but also consumes entries that compare equal to the
    /// serialized `key` (the loop only stops on a strictly greater entry).
    /// NOTE(review): shares both caveats documented on `skip_to`.
    pub fn skip_prior_to(mut self, key: &K) -> Result<Self, TypedStoreError> {
        self.with_mut(|fields| {
            let serialized_key = be_fix_int_ser(key);
            let mut peekable = fields.iter.peekable();
            let mut peeked = peekable.peek();
            while peeked.is_some() {
                let serialized = be_fix_int_ser(peeked.unwrap());
                if serialized > serialized_key {
                    break;
                } else {
                    peekable.next();
                    peeked = peekable.peek();
                }
            }
        });
        Ok(self)
    }

    /// Exhausts the iterator (`Iterator::last` consumes all remaining
    /// entries), leaving it positioned at the end.
    pub fn skip_to_last(mut self) -> Self {
        self.with_mut(|fields| {
            fields.iter.last();
        });
        self
    }

    /// Flips the direction to `Reverse` and wraps `self`.
    /// NOTE(review): `TestDBIter::next` panics on `Reverse`, so the
    /// returned iterator cannot actually be driven — confirm intent.
    pub fn reverse(mut self) -> TestDBRevIter<'a, K, V> {
        self.with_mut(|fields| {
            *fields.direction = Direction::Reverse;
        });
        TestDBRevIter::new(self)
    }
}
161
/// Wrapper around a `TestDBIter` whose direction has been set to
/// `Reverse` by `TestDBIter::reverse`.
pub struct TestDBRevIter<'a, K, V> {
    iter: TestDBIter<'a, K, V>,
}
169
impl<'a, K, V> TestDBRevIter<'a, K, V> {
    /// Wraps an iterator (caller is expected to have set its direction;
    /// see `TestDBIter::reverse`).
    fn new(iter: TestDBIter<'a, K, V>) -> Self {
        Self { iter }
    }
}
175
impl<K: DeserializeOwned, V: DeserializeOwned> Iterator for TestDBRevIter<'_, K, V> {
    type Item = Result<(K, V), TypedStoreError>;

    /// Delegates to the inner iterator.
    /// NOTE(review): with the direction set to `Reverse`, the inner
    /// `next` panics ("Reverse iteration not supported in test db"),
    /// so this iterator is effectively unusable — confirm intent.
    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }
}
184
185impl<K: DeserializeOwned> Iterator for TestDBKeys<'_, K> {
186 type Item = Result<K, TypedStoreError>;
187
188 fn next(&mut self) -> Option<Self::Item> {
189 let mut out: Option<Self::Item> = None;
190 self.with_mut(|fields| {
191 let config = bincode::DefaultOptions::new()
192 .with_big_endian()
193 .with_fixint_encoding();
194 if let Some((raw_key, _)) = fields.iter.next() {
195 let key: K = config.deserialize(raw_key).ok().unwrap();
196 out = Some(Ok(key));
197 }
198 });
199 out
200 }
201}
202
203impl<V: DeserializeOwned> Iterator for TestDBValues<'_, V> {
204 type Item = Result<V, TypedStoreError>;
205
206 fn next(&mut self) -> Option<Self::Item> {
207 let mut out: Option<Self::Item> = None;
208 self.with_mut(|fields| {
209 if let Some((_, raw_value)) = fields.iter.next() {
210 let value: V = bcs::from_bytes(raw_value).ok().unwrap();
211 out = Some(Ok(value));
212 }
213 });
214 out
215 }
216}
217
218impl<'a, K, V> Map<'a, K, V> for TestDB<K, V>
219where
220 K: Serialize + DeserializeOwned,
221 V: Serialize + DeserializeOwned,
222{
223 type Error = TypedStoreError;
224
225 fn contains_key(&self, key: &K) -> Result<bool, Self::Error> {
226 let raw_key = be_fix_int_ser(key);
227 let locked = self.rows.read().unwrap();
228 Ok(locked.contains_key(&raw_key))
229 }
230
231 fn get(&self, key: &K) -> Result<Option<V>, Self::Error> {
232 let raw_key = be_fix_int_ser(key);
233 let locked = self.rows.read().unwrap();
234 let res = locked.get(&raw_key);
235 Ok(res.map(|raw_value| bcs::from_bytes(raw_value).ok().unwrap()))
236 }
237
238 fn insert(&self, key: &K, value: &V) -> Result<(), Self::Error> {
239 let raw_key = be_fix_int_ser(key);
240 let raw_value = bcs::to_bytes(value).map_err(typed_store_err_from_bcs_err)?;
241 let mut locked = self.rows.write().unwrap();
242 locked.insert(raw_key, raw_value);
243 Ok(())
244 }
245
246 fn remove(&self, key: &K) -> Result<(), Self::Error> {
247 let raw_key = be_fix_int_ser(key);
248 let mut locked = self.rows.write().unwrap();
249 locked.remove(&raw_key);
250 Ok(())
251 }
252
253 fn schedule_delete_all(&self) -> Result<(), TypedStoreError> {
254 let mut locked = self.rows.write().unwrap();
255 locked.clear();
256 Ok(())
257 }
258
259 fn is_empty(&self) -> bool {
260 let locked = self.rows.read().unwrap();
261 locked.is_empty()
262 }
263
264 fn safe_iter(&'a self) -> DbIterator<'a, (K, V)> {
265 Box::new(
266 TestDBIterBuilder {
267 rows: self.rows.read().unwrap(),
268 iter_builder: |rows: &mut RwLockReadGuard<'a, BTreeMap<Vec<u8>, Vec<u8>>>| {
269 rows.iter()
270 },
271 phantom: PhantomData,
272 direction: Direction::Forward,
273 }
274 .build(),
275 )
276 }
277
278 fn safe_iter_with_bounds(
279 &'a self,
280 _lower_bound: Option<K>,
281 _upper_bound: Option<K>,
282 ) -> DbIterator<'a, (K, V)> {
283 unimplemented!("unimplemented API");
284 }
285
286 fn safe_range_iter(&'a self, _range: impl RangeBounds<K>) -> DbIterator<'a, (K, V)> {
287 unimplemented!("unimplemented API");
288 }
289
290 fn try_catch_up_with_primary(&self) -> Result<(), Self::Error> {
291 Ok(())
292 }
293}
294
295impl<J, K, U, V> TryExtend<(J, U)> for TestDB<K, V>
296where
297 J: Borrow<K>,
298 U: Borrow<V>,
299 K: Serialize,
300 V: Serialize,
301{
302 type Error = TypedStoreError;
303
304 fn try_extend<T>(&mut self, iter: &mut T) -> Result<(), Self::Error>
305 where
306 T: Iterator<Item = (J, U)>,
307 {
308 let mut wb = self.batch();
309 wb.insert_batch(self, iter)?;
310 wb.write()
311 }
312
313 fn try_extend_from_slice(&mut self, slice: &[(J, U)]) -> Result<(), Self::Error> {
314 let slice_of_refs = slice.iter().map(|(k, v)| (k.borrow(), v.borrow()));
315 let mut wb = self.batch();
316 wb.insert_batch(self, slice_of_refs)?;
317 wb.write()
318 }
319}
320
/// (shared rows handle, store name, encoded keys to delete).
pub type DeleteBatchPayload = (
    Arc<RwLock<BTreeMap<Vec<u8>, Vec<u8>>>>,
    String,
    Vec<Vec<u8>>,
);
/// (shared rows handle, store name, encoded `(from, to)` bounds —
/// deletion is half-open `[from, to)`; see `TestDBWriteBatch::write`).
pub type DeleteRangePayload = (
    Arc<RwLock<BTreeMap<Vec<u8>, Vec<u8>>>>,
    String,
    (Vec<u8>, Vec<u8>),
);
/// (shared rows handle, store name, encoded key/value pairs to insert).
pub type InsertBatchPayload = (
    Arc<RwLock<BTreeMap<Vec<u8>, Vec<u8>>>>,
    String,
    Vec<(Vec<u8>, Vec<u8>)>,
);
// Pair used internally to sort/dedup the stores touched by a batch.
type DBAndName = (Arc<RwLock<BTreeMap<Vec<u8>, Vec<u8>>>>, String);
337
/// A single queued operation in a `TestDBWriteBatch`.
pub enum WriteBatchOp {
    /// Delete the listed encoded keys.
    DeleteBatch(DeleteBatchPayload),
    /// Delete all keys in an encoded `[from, to)` range.
    DeleteRange(DeleteRangePayload),
    /// Insert the listed encoded key/value pairs.
    InsertBatch(InsertBatchPayload),
}
343
/// A queue of operations (possibly spanning several `TestDB`s) applied
/// together by `write`, in the order they were pushed.
#[derive(Default)]
pub struct TestDBWriteBatch {
    /// Operations in insertion order.
    pub ops: VecDeque<WriteBatchOp>,
}
348
/// Self-referential holder pairing a rows handle with its own write
/// guard, so a batch can keep a store locked across multiple operations.
#[self_referencing]
pub struct DBLocked {
    db: Arc<RwLock<BTreeMap<Vec<u8>, Vec<u8>>>>,
    #[borrows(db)]
    #[covariant]
    db_guard: RwLockWriteGuard<'this, BTreeMap<Vec<u8>, Vec<u8>>>,
}
356
357impl TestDBWriteBatch {
358 pub fn write(self) -> Result<(), TypedStoreError> {
359 let mut dbs: Vec<DBAndName> = self
360 .ops
361 .iter()
362 .map(|op| match op {
363 WriteBatchOp::DeleteBatch((db, name, _)) => (db.clone(), name.clone()),
364 WriteBatchOp::DeleteRange((db, name, _)) => (db.clone(), name.clone()),
365 WriteBatchOp::InsertBatch((db, name, _)) => (db.clone(), name.clone()),
366 })
367 .collect();
368 dbs.sort_by_key(|(_k, v)| v.clone());
369 dbs.dedup_by_key(|(_k, v)| v.clone());
370 let mut db_locks = HashMap::new();
372 dbs.iter().for_each(|(db, name)| {
373 if !db_locks.contains_key(name) {
374 db_locks.insert(
375 name.clone(),
376 DBLockedBuilder {
377 db: db.clone(),
378 db_guard_builder: |db: &Arc<RwLock<BTreeMap<Vec<u8>, Vec<u8>>>>| {
379 db.write().unwrap()
380 },
381 }
382 .build(),
383 );
384 }
385 });
386 self.ops.iter().for_each(|op| match op {
387 WriteBatchOp::DeleteBatch((_, id, keys)) => {
388 let locked = db_locks.get_mut(id).unwrap();
389 locked.with_db_guard_mut(|db| {
390 keys.iter().for_each(|key| {
391 db.remove(key);
392 });
393 });
394 }
395 WriteBatchOp::DeleteRange((_, id, (from, to))) => {
396 let locked = db_locks.get_mut(id).unwrap();
397 locked.with_db_guard_mut(|db| {
398 db.retain(|k, _| k < from || k >= to);
399 });
400 }
401 WriteBatchOp::InsertBatch((_, id, key_values)) => {
402 let locked = db_locks.get_mut(id).unwrap();
403 locked.with_db_guard_mut(|db| {
404 key_values.iter().for_each(|(k, v)| {
405 db.insert(k.clone(), v.clone());
406 });
407 });
408 }
409 });
410 dbs.iter().rev().for_each(|(_db, id)| {
412 if db_locks.contains_key(id) {
413 db_locks.remove(id);
414 }
415 });
416 Ok(())
417 }
418 pub fn delete_batch<J: Borrow<K>, K: Serialize, V>(
420 &mut self,
421 db: &TestDB<K, V>,
422 purged_vals: impl IntoIterator<Item = J>,
423 ) -> Result<(), TypedStoreError> {
424 self.ops.push_back(WriteBatchOp::DeleteBatch((
425 db.rows.clone(),
426 db.name.clone(),
427 purged_vals
428 .into_iter()
429 .map(|key| be_fix_int_ser(&key.borrow()))
430 .collect(),
431 )));
432 Ok(())
433 }
434 pub fn delete_range<K: Serialize, V>(
437 &mut self,
438 db: &TestDB<K, V>,
439 from: &K,
440 to: &K,
441 ) -> Result<(), TypedStoreError> {
442 let raw_from = be_fix_int_ser(from);
443 let raw_to = be_fix_int_ser(to);
444 self.ops.push_back(WriteBatchOp::DeleteRange((
445 db.rows.clone(),
446 db.name.clone(),
447 (raw_from, raw_to),
448 )));
449 Ok(())
450 }
451 pub fn insert_batch<J: Borrow<K>, K: Serialize, U: Borrow<V>, V: Serialize>(
453 &mut self,
454 db: &TestDB<K, V>,
455 new_vals: impl IntoIterator<Item = (J, U)>,
456 ) -> Result<(), TypedStoreError> {
457 self.ops.push_back(WriteBatchOp::InsertBatch((
458 db.rows.clone(),
459 db.name.clone(),
460 new_vals
461 .into_iter()
462 .map(|(key, value)| {
463 (
464 be_fix_int_ser(&key.borrow()),
465 bcs::to_bytes(&value.borrow()).unwrap(),
466 )
467 })
468 .collect(),
469 )));
470 Ok(())
471 }
472}
473
#[cfg(test)]
mod test {
    use crate::{Map, test_db::TestDB};

    // Present and absent keys are distinguished by contains_key.
    #[test]
    fn test_contains_key() {
        let db = TestDB::open();
        db.insert(&123456789, &"123456789".to_string())
            .expect("Failed to insert");
        assert!(
            db.contains_key(&123456789)
                .expect("Failed to call contains key")
        );
        assert!(
            !db.contains_key(&000000000)
                .expect("Failed to call contains key")
        );
    }

    // get returns the stored value for a present key, None for a missing one.
    #[test]
    fn test_get() {
        let db = TestDB::open();
        db.insert(&123456789, &"123456789".to_string())
            .expect("Failed to insert");
        assert_eq!(
            Some("123456789".to_string()),
            db.get(&123456789).expect("Failed to get")
        );
        assert_eq!(None, db.get(&000000000).expect("Failed to get"));
    }

    // multi_get preserves the query order and reports misses as None.
    #[test]
    fn test_multi_get() {
        let db = TestDB::open();
        db.insert(&123, &"123".to_string())
            .expect("Failed to insert");
        db.insert(&456, &"456".to_string())
            .expect("Failed to insert");

        let result = db.multi_get([123, 456, 789]).expect("Failed to multi get");

        assert_eq!(result.len(), 3);
        assert_eq!(result[0], Some("123".to_string()));
        assert_eq!(result[1], Some("456".to_string()));
        assert_eq!(result[2], None);
    }

    // remove deletes an entry so subsequent gets return None.
    #[test]
    fn test_remove() {
        let db = TestDB::open();
        db.insert(&123456789, &"123456789".to_string())
            .expect("Failed to insert");
        assert!(db.get(&123456789).expect("Failed to get").is_some());

        db.remove(&123456789).expect("Failed to remove");
        assert!(db.get(&123456789).expect("Failed to get").is_none());
    }

    // safe_iter yields decoded pairs and terminates with None.
    #[test]
    fn test_iter() {
        let db = TestDB::open();
        db.insert(&123456789, &"123456789".to_string())
            .expect("Failed to insert");

        let mut iter = db.safe_iter();
        assert_eq!(Some(Ok((123456789, "123456789".to_string()))), iter.next());
        assert_eq!(None, iter.next());
    }

    // NOTE(review): despite its name, this only checks forward (ascending
    // key) iteration; `reverse()` is never called — and calling it would
    // panic, since reverse iteration is unsupported in TestDB.
    #[test]
    fn test_iter_reverse() {
        let db = TestDB::open();
        db.insert(&1, &"1".to_string()).expect("Failed to insert");
        db.insert(&2, &"2".to_string()).expect("Failed to insert");
        db.insert(&3, &"3".to_string()).expect("Failed to insert");
        let mut iter = db.safe_iter();

        assert_eq!(Some(Ok((1, "1".to_string()))), iter.next());
        assert_eq!(Some(Ok((2, "2".to_string()))), iter.next());
        assert_eq!(Some(Ok((3, "3".to_string()))), iter.next());
        assert_eq!(None, iter.next());
    }

    // NOTE(review): this test makes no assertions and never iterates
    // values — it only checks that insert does not error.
    #[test]
    fn test_values() {
        let db = TestDB::open();

        db.insert(&123456789, &"123456789".to_string())
            .expect("Failed to insert");
    }

    // A batched insert lands every pair in the store after write().
    #[test]
    fn test_insert_batch() {
        let db = TestDB::open();
        let keys_vals = (1..100).map(|i| (i, i.to_string()));
        let mut wb = db.batch();
        wb.insert_batch(&db, keys_vals.clone())
            .expect("Failed to batch insert");
        wb.write().expect("Failed to execute batch");
        for (k, v) in keys_vals {
            let val = db.get(&k).expect("Failed to get inserted key");
            assert_eq!(Some(v), val);
        }
    }

    // One batch can target two distinct stores; each receives its own pairs.
    #[test]
    fn test_insert_batch_across_cf() {
        let db_cf_1 = TestDB::open();
        let keys_vals_1 = (1..100).map(|i| (i, i.to_string()));

        let db_cf_2 = TestDB::open();
        let keys_vals_2 = (1000..1100).map(|i| (i, i.to_string()));

        let mut wb = db_cf_1.batch();
        wb.insert_batch(&db_cf_1, keys_vals_1.clone())
            .expect("Failed to batch insert");
        wb.insert_batch(&db_cf_2, keys_vals_2.clone())
            .expect("Failed to batch insert");
        wb.write().expect("Failed to execute batch");
        for (k, v) in keys_vals_1 {
            let val = db_cf_1.get(&k).expect("Failed to get inserted key");
            assert_eq!(Some(v), val);
        }

        for (k, v) in keys_vals_2 {
            let val = db_cf_2.get(&k).expect("Failed to get inserted key");
            assert_eq!(Some(v), val);
        }
    }

    // Deletes queued after inserts in the same batch win: only even keys
    // (those not in the odd deletion set) survive.
    #[test]
    fn test_delete_batch() {
        let db: TestDB<i32, String> = TestDB::open();

        let keys_vals = (1..100).map(|i| (i, i.to_string()));
        let mut wb = db.batch();
        wb.insert_batch(&db, keys_vals)
            .expect("Failed to batch insert");

        let deletion_keys = (1..100).step_by(2);
        wb.delete_batch(&db, deletion_keys)
            .expect("Failed to batch delete");

        wb.write().expect("Failed to execute batch");

        db.safe_iter().for_each(|item| {
            assert!(item.unwrap().0 % 2 == 0);
        });
    }

    // delete_range is half-open: [50, 100) is removed, 100 survives.
    #[test]
    fn test_delete_range() {
        let db: TestDB<i32, String> = TestDB::open();

        let keys_vals = (0..101).map(|i| (i, i.to_string()));
        let mut wb = db.batch();
        wb.insert_batch(&db, keys_vals)
            .expect("Failed to batch insert");

        wb.delete_range(&db, &50, &100)
            .expect("Failed to delete range");

        wb.write().expect("Failed to execute batch");

        for k in 0..50 {
            assert!(db.contains_key(&k).expect("Failed to query legal key"),);
        }
        for k in 50..100 {
            assert!(!db.contains_key(&k).expect("Failed to query legal key"));
        }

        assert!(db.contains_key(&100).expect("Failed to query legal key"));
    }

    // is_empty flips from true to false once a batch insert is written.
    #[test]
    fn test_is_empty() {
        let db: TestDB<i32, String> = TestDB::open();

        assert!(db.is_empty());

        let keys_vals = (0..101).map(|i| (i, i.to_string()));
        let mut wb = db.batch();
        wb.insert_batch(&db, keys_vals)
            .expect("Failed to batch insert");

        wb.write().expect("Failed to execute batch");

        assert!(db.safe_iter().count() > 1);
        assert!(!db.is_empty());
    }

    // multi_insert stores every pair retrievably.
    #[test]
    fn test_multi_insert() {
        let db: TestDB<i32, String> = TestDB::open();

        let keys_vals = (0..101).map(|i| (i, i.to_string()));

        db.multi_insert(keys_vals.clone())
            .expect("Failed to multi-insert");

        for (k, v) in keys_vals {
            let val = db.get(&k).expect("Failed to get inserted key");
            assert_eq!(Some(v), val);
        }
    }

    // multi_remove deletes exactly the first 50 keys, leaving the rest.
    #[test]
    fn test_multi_remove() {
        let db: TestDB<i32, String> = TestDB::open();

        let keys_vals = (0..101).map(|i| (i, i.to_string()));

        db.multi_insert(keys_vals.clone())
            .expect("Failed to multi-insert");

        for (k, v) in keys_vals.clone() {
            let val = db.get(&k).expect("Failed to get inserted key");
            assert_eq!(Some(v), val);
        }

        db.multi_remove(keys_vals.clone().map(|kv| kv.0).take(50))
            .expect("Failed to multi-remove");
        assert_eq!(db.safe_iter().count(), 101 - 50);

        for (k, v) in keys_vals.skip(50) {
            let val = db.get(&k).expect("Failed to get inserted key");
            assert_eq!(Some(v), val);
        }
    }
}