identity_core/common/
one_or_set.rs

1// Copyright 2020-2022 IOTA Stiftung
2// SPDX-License-Identifier: Apache-2.0
3
4use core::fmt::Debug;
5use core::fmt::Formatter;
6use core::hash::Hash;
7use core::iter;
8use core::mem::replace;
9use core::ops::Deref;
10use core::slice::from_ref;
11
12use serde::de;
13use serde::Deserialize;
14use serde::Serialize;
15
16use crate::common::KeyComparable;
17use crate::common::OrderedSet;
18use crate::error::Error;
19use crate::error::Result;
20
21/// A generic container that stores exactly one or more unique instances of a given type.
22///
23/// Similar to [`OneOrMany`](crate::common::OneOrMany) except instances are guaranteed to be unique,
24/// and only immutable references are allowed.
25#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize)]
26#[serde(transparent)]
27pub struct OneOrSet<T>(OneOrSetInner<T>)
28where
29  T: KeyComparable;
30
31// Private to prevent creations of empty `Set` variants.
32#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize)]
33#[serde(untagged)]
34enum OneOrSetInner<T>
35where
36  T: KeyComparable,
37{
38  /// A single instance of `T`.
39  One(T),
40  /// Multiple (one or more) unique instances of `T`.
41  #[serde(deserialize_with = "deserialize_non_empty_set")]
42  Set(OrderedSet<T>),
43}
44
45/// Deserializes an [`OrderedSet`] while enforcing that it is non-empty.
46fn deserialize_non_empty_set<'de, D, T: serde::Deserialize<'de> + KeyComparable>(
47  deserializer: D,
48) -> Result<OrderedSet<T>, D::Error>
49where
50  D: de::Deserializer<'de>,
51{
52  let set: OrderedSet<T> = OrderedSet::deserialize(deserializer)?;
53  if set.is_empty() {
54    return Err(de::Error::custom(Error::OneOrSetEmpty));
55  }
56
57  Ok(set)
58}
59
60impl<T> OneOrSet<T>
61where
62  T: KeyComparable,
63{
64  /// Constructs a new instance with a single item.
65  pub fn new_one(item: T) -> Self {
66    Self(OneOrSetInner::One(item))
67  }
68
69  /// Constructs a new instance from a set of unique items.
70  ///
71  /// Errors if the given set is empty.
72  pub fn new_set(set: OrderedSet<T>) -> Result<Self> {
73    if set.is_empty() {
74      return Err(Error::OneOrSetEmpty);
75    }
76    if set.len() == 1 {
77      Ok(Self::new_one(
78        set.into_vec().pop().expect("infallible OneOrSet new_set"),
79      ))
80    } else {
81      Ok(Self(OneOrSetInner::Set(set)))
82    }
83  }
84
85  /// Apply a map function to convert this into a new `OneOrSet<S>`.
86  pub fn map<S, F>(self, mut f: F) -> OneOrSet<S>
87  where
88    S: KeyComparable,
89    F: FnMut(T) -> S,
90  {
91    OneOrSet(match self.0 {
92      OneOrSetInner::One(item) => OneOrSetInner::One(f(item)),
93      OneOrSetInner::Set(set_t) => {
94        let set_s: OrderedSet<S> = set_t.into_vec().into_iter().map(f).collect();
95        // Key equivalence could differ between T and S.
96        if set_s.len() == 1 {
97          OneOrSetInner::One(set_s.into_vec().pop().expect("OneOrSet::map infallible"))
98        } else {
99          OneOrSetInner::Set(set_s)
100        }
101      }
102    })
103  }
104
105  /// Apply a map function to convert this into a new `OneOrSet<S>`.
106  pub fn try_map<S, F, E>(self, mut f: F) -> Result<OneOrSet<S>, E>
107  where
108    S: KeyComparable,
109    F: FnMut(T) -> Result<S, E>,
110  {
111    Ok(OneOrSet(match self.0 {
112      OneOrSetInner::One(item) => OneOrSetInner::One(f(item)?),
113      OneOrSetInner::Set(set_t) => {
114        let set_s: OrderedSet<S> = set_t
115          .into_vec()
116          .into_iter()
117          .map(f)
118          .collect::<Result<OrderedSet<S>, E>>()?;
119        // Key equivalence could differ between T and S.
120        if set_s.len() == 1 {
121          OneOrSetInner::One(set_s.into_vec().pop().expect("OneOrSet::try_map infallible"))
122        } else {
123          OneOrSetInner::Set(set_s)
124        }
125      }
126    }))
127  }
128
129  /// Returns the number of elements in the collection.
130  #[allow(clippy::len_without_is_empty)]
131  pub fn len(&self) -> usize {
132    match &self.0 {
133      OneOrSetInner::One(_) => 1,
134      OneOrSetInner::Set(inner) => inner.len(),
135    }
136  }
137
138  /// Returns a reference to the element at the given index.
139  pub fn get(&self, index: usize) -> Option<&T> {
140    match &self.0 {
141      OneOrSetInner::One(inner) if index == 0 => Some(inner),
142      OneOrSetInner::One(_) => None,
143      OneOrSetInner::Set(inner) => inner.get(index),
144    }
145  }
146
147  /// Returns `true` if the collection contains the given item's key.
148  pub fn contains<U>(&self, item: &U) -> bool
149  where
150    T: KeyComparable,
151    U: KeyComparable<Key = T::Key> + ?Sized,
152  {
153    match &self.0 {
154      OneOrSetInner::One(inner) => inner.key() == item.key(),
155      OneOrSetInner::Set(inner) => inner.contains(item),
156    }
157  }
158
159  /// Appends a new item to the end of the collection if its key is not present already.
160  ///
161  /// Returns whether or not the value was successfully inserted.
162  pub fn append(&mut self, item: T) -> bool
163  where
164    T: KeyComparable,
165  {
166    match &mut self.0 {
167      OneOrSetInner::One(inner) if inner.key() == item.key() => false,
168      OneOrSetInner::One(_) => match replace(&mut self.0, OneOrSetInner::Set(OrderedSet::new())) {
169        OneOrSetInner::One(inner) => {
170          self.0 = OneOrSetInner::Set(OrderedSet::from_iter([inner, item]));
171          true
172        }
173        OneOrSetInner::Set(_) => unreachable!(),
174      },
175      OneOrSetInner::Set(inner) => inner.append(item),
176    }
177  }
178
179  /// Returns an `Iterator` that yields items from the collection.
180  pub fn iter(&self) -> impl Iterator<Item = &T> + '_ {
181    OneOrSetIter::new(self)
182  }
183
184  /// Returns a reference to the contents as a slice.
185  pub fn as_slice(&self) -> &[T] {
186    self
187  }
188
189  /// Consumes the [`OneOrSet`] and returns the contents as a [`Vec`].
190  pub fn into_vec(self) -> Vec<T> {
191    match self.0 {
192      OneOrSetInner::One(inner) => vec![inner],
193      OneOrSetInner::Set(inner) => inner.into_vec(),
194    }
195  }
196}
197
198impl<T> Debug for OneOrSet<T>
199where
200  T: Debug + KeyComparable,
201{
202  fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
203    match &self.0 {
204      OneOrSetInner::One(inner) => Debug::fmt(inner, f),
205      OneOrSetInner::Set(inner) => Debug::fmt(inner, f),
206    }
207  }
208}
209
210impl<T> Deref for OneOrSet<T>
211where
212  T: KeyComparable,
213{
214  type Target = [T];
215
216  fn deref(&self) -> &Self::Target {
217    match &self.0 {
218      OneOrSetInner::One(inner) => from_ref(inner),
219      OneOrSetInner::Set(inner) => inner.as_slice(),
220    }
221  }
222}
223
224impl<T> AsRef<[T]> for OneOrSet<T>
225where
226  T: KeyComparable,
227{
228  fn as_ref(&self) -> &[T] {
229    self.as_slice()
230  }
231}
232
233impl<T> From<T> for OneOrSet<T>
234where
235  T: KeyComparable,
236{
237  fn from(other: T) -> Self {
238    OneOrSet::new_one(other)
239  }
240}
241
242impl<T> TryFrom<Vec<T>> for OneOrSet<T>
243where
244  T: KeyComparable,
245{
246  type Error = Error;
247
248  fn try_from(other: Vec<T>) -> std::result::Result<Self, Self::Error> {
249    let set: OrderedSet<T> = OrderedSet::try_from(other)?;
250    OneOrSet::new_set(set)
251  }
252}
253
254impl<T> TryFrom<OrderedSet<T>> for OneOrSet<T>
255where
256  T: KeyComparable,
257{
258  type Error = Error;
259
260  fn try_from(other: OrderedSet<T>) -> std::result::Result<Self, Self::Error> {
261    OneOrSet::new_set(other)
262  }
263}
264
265impl<T> From<OneOrSet<T>> for Vec<T>
266where
267  T: KeyComparable,
268{
269  fn from(other: OneOrSet<T>) -> Self {
270    other.into_vec()
271  }
272}
273
274impl<T> From<OneOrSet<T>> for OrderedSet<T>
275where
276  T: KeyComparable,
277{
278  fn from(other: OneOrSet<T>) -> Self {
279    match other.0 {
280      OneOrSetInner::One(item) => OrderedSet::from_iter(iter::once(item)),
281      OneOrSetInner::Set(set) => set,
282    }
283  }
284}
285
286// =============================================================================
287// Iterator
288// =============================================================================
289
290struct OneOrSetIter<'a, T>
291where
292  T: KeyComparable,
293{
294  inner: &'a OneOrSet<T>,
295  index: usize,
296}
297
298impl<'a, T> OneOrSetIter<'a, T>
299where
300  T: KeyComparable,
301{
302  fn new(inner: &'a OneOrSet<T>) -> Self {
303    Self { inner, index: 0 }
304  }
305}
306
307impl<'a, T> Iterator for OneOrSetIter<'a, T>
308where
309  T: KeyComparable,
310{
311  type Item = &'a T;
312
313  fn next(&mut self) -> Option<Self::Item> {
314    self.index += 1;
315    self.inner.get(self.index - 1)
316  }
317}
318
#[cfg(test)]
mod tests {
  use crate::convert::FromJson;
  use crate::convert::ToJson;

  use super::*;

  #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
  struct MockKeyU8(u8);

  #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
  struct MockKeyBool(bool);

  impl KeyComparable for MockKeyU8 {
    type Key = u8;

    fn key(&self) -> &Self::Key {
      &self.0
    }
  }

  impl KeyComparable for MockKeyBool {
    type Key = bool;

    fn key(&self) -> &Self::Key {
      &self.0
    }
  }

  #[test]
  fn test_new_set() {
    // VALID: non-empty set.
    let ordered_set: OrderedSet<MockKeyU8> = OrderedSet::from_iter([1, 2, 3].map(MockKeyU8));
    let new_set: OneOrSet<MockKeyU8> = OneOrSet::new_set(ordered_set.clone()).unwrap();
    let try_from_set: OneOrSet<MockKeyU8> = OneOrSet::try_from(ordered_set.clone()).unwrap();
    assert_eq!(new_set, try_from_set);
    assert_eq!(OrderedSet::from(new_set), ordered_set);

    // INVALID: empty set.
    let empty: OrderedSet<MockKeyU8> = OrderedSet::new();
    assert!(matches!(OneOrSet::new_set(empty.clone()), Err(Error::OneOrSetEmpty)));
    assert!(matches!(OneOrSet::try_from(empty), Err(Error::OneOrSetEmpty)));
  }

  #[test]
  fn test_new_set_single_item() {
    // A one-item set is normalized to the `One` variant.
    let single: OrderedSet<MockKeyU8> = OrderedSet::from_iter([MockKeyU8(7)]);
    let collection: OneOrSet<MockKeyU8> = OneOrSet::new_set(single).unwrap();
    assert_eq!(collection, OneOrSet::new_one(MockKeyU8(7)));
    assert_eq!(collection.0, OneOrSetInner::One(MockKeyU8(7)));
    assert_eq!(collection.len(), 1);
  }

  #[test]
  fn test_append_from_one() {
    let mut collection: OneOrSet<MockKeyU8> = OneOrSet::new_one(MockKeyU8(42));
    assert_eq!(collection.len(), 1);

    // Ignores duplicates.
    assert!(!collection.append(MockKeyU8(42)));
    assert_eq!(collection, OneOrSet::new_one(MockKeyU8(42)));
    assert_eq!(collection.len(), 1);

    // Becomes Set.
    assert!(collection.append(MockKeyU8(128)));
    assert_eq!(
      collection,
      OneOrSet::new_set(OrderedSet::from_iter([42, 128].map(MockKeyU8))).unwrap()
    );
    assert_eq!(collection.len(), 2);

    assert!(collection.append(MockKeyU8(200)));
    assert_eq!(
      collection,
      OneOrSet::new_set(OrderedSet::from_iter([42, 128, 200].map(MockKeyU8))).unwrap()
    );
    assert_eq!(collection.len(), 3);
  }

  #[test]
  fn test_append_from_set() {
    let mut collection: OneOrSet<MockKeyU8> = OneOrSet::new_set((0..42).map(MockKeyU8).collect()).unwrap();
    assert_eq!(collection.len(), 42);

    // Appends to end.
    assert!(collection.append(MockKeyU8(42)));
    let expected: OneOrSet<MockKeyU8> = OneOrSet::new_set((0..=42).map(MockKeyU8).collect()).unwrap();
    assert_eq!(collection, expected);
    assert_eq!(collection.len(), 43);

    // Ignores duplicates.
    for i in 0..=42 {
      assert!(!collection.append(MockKeyU8(i)));
      assert_eq!(collection, expected);
      assert_eq!(collection.len(), 43);
    }
  }

  #[test]
  fn test_contains() {
    // One.
    let one: OneOrSet<MockKeyU8> = OneOrSet::new_one(MockKeyU8(1));
    assert!(one.contains(&1u8));
    assert!(!one.contains(&2u8));
    assert!(!one.contains(&3u8));

    // Set.
    let set: OneOrSet<MockKeyU8> = OneOrSet::new_set((1..=3).map(MockKeyU8).collect()).unwrap();
    assert!(set.contains(&1u8));
    assert!(set.contains(&2u8));
    assert!(set.contains(&3u8));
    assert!(!set.contains(&4u8));
  }

  #[test]
  fn test_get() {
    // One.
    let one: OneOrSet<MockKeyU8> = OneOrSet::new_one(MockKeyU8(1));
    assert_eq!(one.get(0), Some(&MockKeyU8(1)));
    assert_eq!(one.get(1), None);
    assert_eq!(one.get(2), None);

    // Set.
    let set: OneOrSet<MockKeyU8> = OneOrSet::new_set((1..=3).map(MockKeyU8).collect()).unwrap();
    assert_eq!(set.get(0), Some(&MockKeyU8(1)));
    assert_eq!(set.get(1), Some(&MockKeyU8(2)));
    assert_eq!(set.get(2), Some(&MockKeyU8(3)));
    assert_eq!(set.get(3), None);
  }

  #[test]
  fn test_map() {
    // One.
    let one: OneOrSet<MockKeyU8> = OneOrSet::new_one(MockKeyU8(1));
    let one_add: OneOrSet<MockKeyU8> = one.map(|item| MockKeyU8(item.0 + 1));
    assert_eq!(one_add, OneOrSet::new_one(MockKeyU8(2)));

    // Set.
    let set: OneOrSet<MockKeyU8> = OneOrSet::new_set((1..=3).map(MockKeyU8).collect()).unwrap();
    let set_add: OneOrSet<MockKeyU8> = set.map(|item| MockKeyU8(item.0 + 10));
    assert_eq!(set_add, OneOrSet::new_set((11..=13).map(MockKeyU8).collect()).unwrap());

    // Set reduced to one.
    let set_many: OneOrSet<MockKeyU8> = OneOrSet::new_set([2, 4, 6, 8].into_iter().map(MockKeyU8).collect()).unwrap();
    assert_eq!(set_many.len(), 4);
    let set_bool: OneOrSet<MockKeyBool> = set_many.map(|item| MockKeyBool(item.0 % 2 == 0));
    assert_eq!(set_bool, OneOrSet::new_one(MockKeyBool(true)));
    assert_eq!(set_bool.0, OneOrSetInner::One(MockKeyBool(true)));
    assert_eq!(set_bool.len(), 1);
  }

  #[test]
  fn test_try_map() {
    // One - OK
    let one: OneOrSet<MockKeyU8> = OneOrSet::new_one(MockKeyU8(1));
    let one_add: OneOrSet<MockKeyU8> = one
      .try_map(|item| {
        if item.key() == &1 {
          Ok(MockKeyU8(item.0 + 1))
        } else {
          Err(Error::OneOrSetEmpty)
        }
      })
      .unwrap();
    assert_eq!(one_add, OneOrSet::new_one(MockKeyU8(2)));

    // One - ERROR
    let one_err: OneOrSet<MockKeyU8> = OneOrSet::new_one(MockKeyU8(1));
    let result_one: Result<OneOrSet<MockKeyBool>> = one_err.try_map(|item| {
      if item.key() == &1 {
        Err(Error::OneOrSetEmpty)
      } else {
        Ok(MockKeyBool(false))
      }
    });
    assert!(matches!(result_one, Err(Error::OneOrSetEmpty)));

    // Set - OK
    let set: OneOrSet<MockKeyU8> = OneOrSet::new_set((1..=3).map(MockKeyU8).collect()).unwrap();
    let set_add: OneOrSet<MockKeyU8> = set
      .try_map(|item| {
        if item.key() < &4 {
          Ok(MockKeyU8(item.0 + 10))
        } else {
          Err(Error::OneOrSetEmpty)
        }
      })
      .unwrap();
    assert_eq!(set_add, OneOrSet::new_set((11..=13).map(MockKeyU8).collect()).unwrap());

    // Set - ERROR
    let set_err: OneOrSet<MockKeyU8> = OneOrSet::new_set((1..=3).map(MockKeyU8).collect()).unwrap();
    let result_set: Result<OneOrSet<MockKeyU8>> = set_err.try_map(|item| {
      if item.key() < &4 {
        Err(Error::OneOrSetEmpty)
      } else {
        Ok(MockKeyU8(item.0))
      }
    });
    assert!(matches!(result_set, Err(Error::OneOrSetEmpty)));

    // Set reduced to one - OK
    let set_many: OneOrSet<MockKeyU8> = OneOrSet::new_set([2, 4, 6, 8].into_iter().map(MockKeyU8).collect()).unwrap();
    assert_eq!(set_many.len(), 4);
    let set_bool: OneOrSet<MockKeyBool> = set_many
      .try_map(|item| {
        if item.key() % 2 == 0 {
          Ok(MockKeyBool(item.0 % 2 == 0))
        } else {
          Err(Error::OneOrSetEmpty)
        }
      })
      .unwrap();
    assert_eq!(set_bool, OneOrSet::new_one(MockKeyBool(true)));
    assert_eq!(set_bool.0, OneOrSetInner::One(MockKeyBool(true)));
    assert_eq!(set_bool.len(), 1);
  }

  #[test]
  fn test_iter() {
    // One.
    let one: OneOrSet<MockKeyU8> = OneOrSet::new_one(MockKeyU8(1));
    let mut one_iter = one.iter();
    assert_eq!(one_iter.next(), Some(&MockKeyU8(1)));
    assert_eq!(one_iter.next(), None);
    assert_eq!(one_iter.next(), None);

    // Set.
    let set: OneOrSet<MockKeyU8> = OneOrSet::new_set((1..=3).map(MockKeyU8).collect()).unwrap();
    let mut set_iter = set.iter();
    assert_eq!(set_iter.next(), Some(&MockKeyU8(1)));
    assert_eq!(set_iter.next(), Some(&MockKeyU8(2)));
    assert_eq!(set_iter.next(), Some(&MockKeyU8(3)));
    assert_eq!(set_iter.next(), None);
    // Stays exhausted.
    assert_eq!(set_iter.next(), None);

    // Collecting yields every item in order.
    let collected: Vec<&MockKeyU8> = set.iter().collect();
    assert_eq!(collected, vec![&MockKeyU8(1), &MockKeyU8(2), &MockKeyU8(3)]);
  }

  #[test]
  fn test_conversions() {
    // Vec -> OneOrSet -> slice / Vec.
    let set: OneOrSet<MockKeyU8> = OneOrSet::try_from(vec![MockKeyU8(1), MockKeyU8(2)]).unwrap();
    assert_eq!(set.len(), 2);
    assert_eq!(set.as_slice(), &[MockKeyU8(1), MockKeyU8(2)][..]);
    assert_eq!(Vec::from(set.clone()), vec![MockKeyU8(1), MockKeyU8(2)]);
    assert_eq!(set.into_vec(), vec![MockKeyU8(1), MockKeyU8(2)]);

    // A one-element Vec is normalized to the `One` variant.
    let single: OneOrSet<MockKeyU8> = OneOrSet::try_from(vec![MockKeyU8(9)]).unwrap();
    assert_eq!(single, OneOrSet::new_one(MockKeyU8(9)));

    // Item -> OneOrSet -> OrderedSet.
    let one: OneOrSet<MockKeyU8> = OneOrSet::from(MockKeyU8(5));
    assert_eq!(one.as_slice(), &[MockKeyU8(5)][..]);
    assert_eq!(OrderedSet::from(one).into_vec(), vec![MockKeyU8(5)]);

    // An empty Vec is rejected.
    let empty: Vec<MockKeyU8> = Vec::new();
    let result: Result<OneOrSet<MockKeyU8>> = OneOrSet::try_from(empty);
    assert!(result.is_err());
  }

  #[test]
  fn test_serde() {
    // VALID: one.
    {
      let one: OneOrSet<MockKeyU8> = OneOrSet::new_one(MockKeyU8(1));
      let ser: String = one.to_json().unwrap();
      let de: OneOrSet<MockKeyU8> = OneOrSet::from_json(&ser).unwrap();
      assert_eq!(ser, "1");
      assert_eq!(de, one);
    }

    // VALID: set.
    {
      let set: OneOrSet<MockKeyU8> = OneOrSet::new_set((1..=3).map(MockKeyU8).collect()).unwrap();
      let ser: String = set.to_json().unwrap();
      let de: OneOrSet<MockKeyU8> = OneOrSet::from_json(&ser).unwrap();
      assert_eq!(ser, "[1,2,3]");
      assert_eq!(de, set);
    }

    // VALID: single-element array.
    {
      let de: OneOrSet<MockKeyU8> = OneOrSet::from_json("[1]").unwrap();
      assert_eq!(de.len(), 1);
      assert_eq!(de.get(0), Some(&MockKeyU8(1)));
    }

    // INVALID: empty.
    {
      let empty: Result<OneOrSet<MockKeyU8>> = OneOrSet::from_json("");
      assert!(empty.is_err());
      let empty_set: Result<OneOrSet<MockKeyU8>> = OneOrSet::from_json("[]");
      assert!(empty_set.is_err());
      let empty_space: Result<OneOrSet<MockKeyU8>> = OneOrSet::from_json("[ ]");
      assert!(empty_space.is_err());
    }
  }
}