1use core::fmt::Debug;
5use core::fmt::Formatter;
6use core::hash::Hash;
7use core::iter;
8use core::mem::replace;
9use core::ops::Deref;
10use core::slice::from_ref;
11
12use serde::de;
13use serde::Deserialize;
14use serde::Serialize;
15
16use crate::common::KeyComparable;
17use crate::common::OrderedSet;
18use crate::error::Error;
19use crate::error::Result;
20
21#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize)]
26#[serde(transparent)]
27pub struct OneOrSet<T>(OneOrSetInner<T>)
28where
29 T: KeyComparable;
30
31#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize)]
33#[serde(untagged)]
34enum OneOrSetInner<T>
35where
36 T: KeyComparable,
37{
38 One(T),
40 #[serde(deserialize_with = "deserialize_non_empty_set")]
42 Set(OrderedSet<T>),
43}
44
45fn deserialize_non_empty_set<'de, D, T: serde::Deserialize<'de> + KeyComparable>(
47 deserializer: D,
48) -> Result<OrderedSet<T>, D::Error>
49where
50 D: de::Deserializer<'de>,
51{
52 let set: OrderedSet<T> = OrderedSet::deserialize(deserializer)?;
53 if set.is_empty() {
54 return Err(de::Error::custom(Error::OneOrSetEmpty));
55 }
56
57 Ok(set)
58}
59
60impl<T> OneOrSet<T>
61where
62 T: KeyComparable,
63{
64 pub fn new_one(item: T) -> Self {
66 Self(OneOrSetInner::One(item))
67 }
68
69 pub fn new_set(set: OrderedSet<T>) -> Result<Self> {
73 if set.is_empty() {
74 return Err(Error::OneOrSetEmpty);
75 }
76 if set.len() == 1 {
77 Ok(Self::new_one(
78 set.into_vec().pop().expect("infallible OneOrSet new_set"),
79 ))
80 } else {
81 Ok(Self(OneOrSetInner::Set(set)))
82 }
83 }
84
85 pub fn map<S, F>(self, mut f: F) -> OneOrSet<S>
87 where
88 S: KeyComparable,
89 F: FnMut(T) -> S,
90 {
91 OneOrSet(match self.0 {
92 OneOrSetInner::One(item) => OneOrSetInner::One(f(item)),
93 OneOrSetInner::Set(set_t) => {
94 let set_s: OrderedSet<S> = set_t.into_vec().into_iter().map(f).collect();
95 if set_s.len() == 1 {
97 OneOrSetInner::One(set_s.into_vec().pop().expect("OneOrSet::map infallible"))
98 } else {
99 OneOrSetInner::Set(set_s)
100 }
101 }
102 })
103 }
104
105 pub fn try_map<S, F, E>(self, mut f: F) -> Result<OneOrSet<S>, E>
107 where
108 S: KeyComparable,
109 F: FnMut(T) -> Result<S, E>,
110 {
111 Ok(OneOrSet(match self.0 {
112 OneOrSetInner::One(item) => OneOrSetInner::One(f(item)?),
113 OneOrSetInner::Set(set_t) => {
114 let set_s: OrderedSet<S> = set_t
115 .into_vec()
116 .into_iter()
117 .map(f)
118 .collect::<Result<OrderedSet<S>, E>>()?;
119 if set_s.len() == 1 {
121 OneOrSetInner::One(set_s.into_vec().pop().expect("OneOrSet::try_map infallible"))
122 } else {
123 OneOrSetInner::Set(set_s)
124 }
125 }
126 }))
127 }
128
129 #[allow(clippy::len_without_is_empty)]
131 pub fn len(&self) -> usize {
132 match &self.0 {
133 OneOrSetInner::One(_) => 1,
134 OneOrSetInner::Set(inner) => inner.len(),
135 }
136 }
137
138 pub fn get(&self, index: usize) -> Option<&T> {
140 match &self.0 {
141 OneOrSetInner::One(inner) if index == 0 => Some(inner),
142 OneOrSetInner::One(_) => None,
143 OneOrSetInner::Set(inner) => inner.get(index),
144 }
145 }
146
147 pub fn contains<U>(&self, item: &U) -> bool
149 where
150 T: KeyComparable,
151 U: KeyComparable<Key = T::Key> + ?Sized,
152 {
153 match &self.0 {
154 OneOrSetInner::One(inner) => inner.key() == item.key(),
155 OneOrSetInner::Set(inner) => inner.contains(item),
156 }
157 }
158
159 pub fn append(&mut self, item: T) -> bool
163 where
164 T: KeyComparable,
165 {
166 match &mut self.0 {
167 OneOrSetInner::One(inner) if inner.key() == item.key() => false,
168 OneOrSetInner::One(_) => match replace(&mut self.0, OneOrSetInner::Set(OrderedSet::new())) {
169 OneOrSetInner::One(inner) => {
170 self.0 = OneOrSetInner::Set(OrderedSet::from_iter([inner, item]));
171 true
172 }
173 OneOrSetInner::Set(_) => unreachable!(),
174 },
175 OneOrSetInner::Set(inner) => inner.append(item),
176 }
177 }
178
179 pub fn iter(&self) -> impl Iterator<Item = &T> + '_ {
181 OneOrSetIter::new(self)
182 }
183
184 pub fn as_slice(&self) -> &[T] {
186 self
187 }
188
189 pub fn into_vec(self) -> Vec<T> {
191 match self.0 {
192 OneOrSetInner::One(inner) => vec![inner],
193 OneOrSetInner::Set(inner) => inner.into_vec(),
194 }
195 }
196}
197
198impl<T> Debug for OneOrSet<T>
199where
200 T: Debug + KeyComparable,
201{
202 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
203 match &self.0 {
204 OneOrSetInner::One(inner) => Debug::fmt(inner, f),
205 OneOrSetInner::Set(inner) => Debug::fmt(inner, f),
206 }
207 }
208}
209
210impl<T> Deref for OneOrSet<T>
211where
212 T: KeyComparable,
213{
214 type Target = [T];
215
216 fn deref(&self) -> &Self::Target {
217 match &self.0 {
218 OneOrSetInner::One(inner) => from_ref(inner),
219 OneOrSetInner::Set(inner) => inner.as_slice(),
220 }
221 }
222}
223
224impl<T> AsRef<[T]> for OneOrSet<T>
225where
226 T: KeyComparable,
227{
228 fn as_ref(&self) -> &[T] {
229 self.as_slice()
230 }
231}
232
233impl<T> From<T> for OneOrSet<T>
234where
235 T: KeyComparable,
236{
237 fn from(other: T) -> Self {
238 OneOrSet::new_one(other)
239 }
240}
241
242impl<T> TryFrom<Vec<T>> for OneOrSet<T>
243where
244 T: KeyComparable,
245{
246 type Error = Error;
247
248 fn try_from(other: Vec<T>) -> std::result::Result<Self, Self::Error> {
249 let set: OrderedSet<T> = OrderedSet::try_from(other)?;
250 OneOrSet::new_set(set)
251 }
252}
253
254impl<T> TryFrom<OrderedSet<T>> for OneOrSet<T>
255where
256 T: KeyComparable,
257{
258 type Error = Error;
259
260 fn try_from(other: OrderedSet<T>) -> std::result::Result<Self, Self::Error> {
261 OneOrSet::new_set(other)
262 }
263}
264
265impl<T> From<OneOrSet<T>> for Vec<T>
266where
267 T: KeyComparable,
268{
269 fn from(other: OneOrSet<T>) -> Self {
270 other.into_vec()
271 }
272}
273
274impl<T> From<OneOrSet<T>> for OrderedSet<T>
275where
276 T: KeyComparable,
277{
278 fn from(other: OneOrSet<T>) -> Self {
279 match other.0 {
280 OneOrSetInner::One(item) => OrderedSet::from_iter(iter::once(item)),
281 OneOrSetInner::Set(set) => set,
282 }
283 }
284}
285
286struct OneOrSetIter<'a, T>
291where
292 T: KeyComparable,
293{
294 inner: &'a OneOrSet<T>,
295 index: usize,
296}
297
298impl<'a, T> OneOrSetIter<'a, T>
299where
300 T: KeyComparable,
301{
302 fn new(inner: &'a OneOrSet<T>) -> Self {
303 Self { inner, index: 0 }
304 }
305}
306
307impl<'a, T> Iterator for OneOrSetIter<'a, T>
308where
309 T: KeyComparable,
310{
311 type Item = &'a T;
312
313 fn next(&mut self) -> Option<Self::Item> {
314 self.index += 1;
315 self.inner.get(self.index - 1)
316 }
317}
318
#[cfg(test)]
mod tests {
    use crate::convert::FromJson;
    use crate::convert::ToJson;

    use super::*;

    /// Test item whose key is its `u8` payload.
    #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
    struct MockKeyU8(u8);

    /// Test item whose key is its `bool` payload; lets mapping tests merge many items into one key.
    #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
    struct MockKeyBool(bool);

    impl KeyComparable for MockKeyU8 {
        type Key = u8;

        fn key(&self) -> &Self::Key {
            &self.0
        }
    }

    impl KeyComparable for MockKeyBool {
        type Key = bool;

        fn key(&self) -> &Self::Key {
            &self.0
        }
    }

    #[test]
    fn test_new_set() {
        let ordered_set: OrderedSet<MockKeyU8> = OrderedSet::from_iter([1, 2, 3].map(MockKeyU8));
        let new_set: OneOrSet<MockKeyU8> = OneOrSet::new_set(ordered_set.clone()).unwrap();
        let try_from_set: OneOrSet<MockKeyU8> = OneOrSet::try_from(ordered_set.clone()).unwrap();
        assert_eq!(new_set, try_from_set);
        assert_eq!(OrderedSet::from(new_set), ordered_set);

        // A one-element set is stored as a single item.
        let single: OneOrSet<MockKeyU8> =
            OneOrSet::new_set(OrderedSet::from_iter([MockKeyU8(5)])).unwrap();
        assert_eq!(single, OneOrSet::new_one(MockKeyU8(5)));
        assert_eq!(single.len(), 1);

        // Empty sets are rejected.
        let empty: OrderedSet<MockKeyU8> = OrderedSet::new();
        assert!(matches!(OneOrSet::new_set(empty.clone()), Err(Error::OneOrSetEmpty)));
        assert!(matches!(OneOrSet::try_from(empty), Err(Error::OneOrSetEmpty)));
    }

    #[test]
    fn test_append_from_one() {
        let mut collection: OneOrSet<MockKeyU8> = OneOrSet::new_one(MockKeyU8(42));
        assert_eq!(collection.len(), 1);

        // A duplicate key is rejected and leaves the single item untouched.
        assert!(!collection.append(MockKeyU8(42)));
        assert_eq!(collection, OneOrSet::new_one(MockKeyU8(42)));
        assert_eq!(collection.len(), 1);

        // A new key promotes the single item to a set.
        assert!(collection.append(MockKeyU8(128)));
        assert_eq!(
            collection,
            OneOrSet::new_set(OrderedSet::from_iter([42, 128].map(MockKeyU8))).unwrap()
        );
        assert_eq!(collection.len(), 2);

        assert!(collection.append(MockKeyU8(200)));
        assert_eq!(
            collection,
            OneOrSet::new_set(OrderedSet::from_iter([42, 128, 200].map(MockKeyU8))).unwrap()
        );
        assert_eq!(collection.len(), 3);
    }

    #[test]
    fn test_append_from_set() {
        let mut collection: OneOrSet<MockKeyU8> =
            OneOrSet::new_set((0..42).map(MockKeyU8).collect()).unwrap();
        assert_eq!(collection.len(), 42);

        assert!(collection.append(MockKeyU8(42)));
        let expected: OneOrSet<MockKeyU8> =
            OneOrSet::new_set((0..=42).map(MockKeyU8).collect()).unwrap();
        assert_eq!(collection, expected);
        assert_eq!(collection.len(), 43);

        // Every key is already present, so each append is rejected and changes nothing.
        for i in 0..=42 {
            assert!(!collection.append(MockKeyU8(i)));
            assert_eq!(collection, expected);
            assert_eq!(collection.len(), 43);
        }
    }

    #[test]
    fn test_contains() {
        let one: OneOrSet<MockKeyU8> = OneOrSet::new_one(MockKeyU8(1));
        assert!(one.contains(&1u8));
        assert!(!one.contains(&2u8));
        assert!(!one.contains(&3u8));

        let set: OneOrSet<MockKeyU8> = OneOrSet::new_set((1..=3).map(MockKeyU8).collect()).unwrap();
        assert!(set.contains(&1u8));
        assert!(set.contains(&2u8));
        assert!(set.contains(&3u8));
        assert!(!set.contains(&4u8));
    }

    #[test]
    fn test_get() {
        let one: OneOrSet<MockKeyU8> = OneOrSet::new_one(MockKeyU8(1));
        assert_eq!(one.get(0), Some(&MockKeyU8(1)));
        assert_eq!(one.get(1), None);
        assert_eq!(one.get(2), None);

        let set: OneOrSet<MockKeyU8> = OneOrSet::new_set((1..=3).map(MockKeyU8).collect()).unwrap();
        assert_eq!(set.get(0), Some(&MockKeyU8(1)));
        assert_eq!(set.get(1), Some(&MockKeyU8(2)));
        assert_eq!(set.get(2), Some(&MockKeyU8(3)));
        assert_eq!(set.get(3), None);
    }

    #[test]
    fn test_map() {
        let one: OneOrSet<MockKeyU8> = OneOrSet::new_one(MockKeyU8(1));
        let one_add: OneOrSet<MockKeyU8> = one.map(|item| MockKeyU8(item.0 + 1));
        assert_eq!(one_add, OneOrSet::new_one(MockKeyU8(2)));

        let set: OneOrSet<MockKeyU8> = OneOrSet::new_set((1..=3).map(MockKeyU8).collect()).unwrap();
        let set_add: OneOrSet<MockKeyU8> = set.map(|item| MockKeyU8(item.0 + 10));
        assert_eq!(set_add, OneOrSet::new_set((11..=13).map(MockKeyU8).collect()).unwrap());

        // Mapping every item to the same key collapses the set into a single item.
        let set_many: OneOrSet<MockKeyU8> =
            OneOrSet::new_set([2, 4, 6, 8].into_iter().map(MockKeyU8).collect()).unwrap();
        assert_eq!(set_many.len(), 4);
        let set_bool: OneOrSet<MockKeyBool> = set_many.map(|item| MockKeyBool(item.0 % 2 == 0));
        assert_eq!(set_bool, OneOrSet::new_one(MockKeyBool(true)));
        assert_eq!(set_bool.0, OneOrSetInner::One(MockKeyBool(true)));
        assert_eq!(set_bool.len(), 1);
    }

    #[test]
    fn test_try_map() {
        let one: OneOrSet<MockKeyU8> = OneOrSet::new_one(MockKeyU8(1));
        let one_add: OneOrSet<MockKeyU8> = one
            .try_map(|item| {
                if item.key() == &1 {
                    Ok(MockKeyU8(item.0 + 1))
                } else {
                    Err(Error::OneOrSetEmpty)
                }
            })
            .unwrap();
        assert_eq!(one_add, OneOrSet::new_one(MockKeyU8(2)));

        let one_err: OneOrSet<MockKeyU8> = OneOrSet::new_one(MockKeyU8(1));
        let result_one: Result<OneOrSet<MockKeyBool>> = one_err.try_map(|item| {
            if item.key() == &1 {
                Err(Error::OneOrSetEmpty)
            } else {
                Ok(MockKeyBool(false))
            }
        });
        assert!(matches!(result_one, Err(Error::OneOrSetEmpty)));

        let set: OneOrSet<MockKeyU8> = OneOrSet::new_set((1..=3).map(MockKeyU8).collect()).unwrap();
        let set_add: OneOrSet<MockKeyU8> = set
            .try_map(|item| {
                if item.key() < &4 {
                    Ok(MockKeyU8(item.0 + 10))
                } else {
                    Err(Error::OneOrSetEmpty)
                }
            })
            .unwrap();
        assert_eq!(set_add, OneOrSet::new_set((11..=13).map(MockKeyU8).collect()).unwrap());

        let set_err: OneOrSet<MockKeyU8> =
            OneOrSet::new_set((1..=3).map(MockKeyU8).collect()).unwrap();
        let result_set: Result<OneOrSet<MockKeyU8>> = set_err.try_map(|item| {
            if item.key() < &4 {
                Err(Error::OneOrSetEmpty)
            } else {
                Ok(MockKeyU8(item.0))
            }
        });
        assert!(matches!(result_set, Err(Error::OneOrSetEmpty)));

        let set_many: OneOrSet<MockKeyU8> =
            OneOrSet::new_set([2, 4, 6, 8].into_iter().map(MockKeyU8).collect()).unwrap();
        assert_eq!(set_many.len(), 4);
        let set_bool: OneOrSet<MockKeyBool> = set_many
            .try_map(|item| {
                if item.key() % 2 == 0 {
                    Ok(MockKeyBool(item.0 % 2 == 0))
                } else {
                    Err(Error::OneOrSetEmpty)
                }
            })
            .unwrap();
        assert_eq!(set_bool, OneOrSet::new_one(MockKeyBool(true)));
        assert_eq!(set_bool.0, OneOrSetInner::One(MockKeyBool(true)));
        assert_eq!(set_bool.len(), 1);
    }

    #[test]
    fn test_iter() {
        let one: OneOrSet<MockKeyU8> = OneOrSet::new_one(MockKeyU8(1));
        let mut one_iter = one.iter();
        assert_eq!(one_iter.next(), Some(&MockKeyU8(1)));
        assert_eq!(one_iter.next(), None);
        assert_eq!(one_iter.next(), None);

        let set: OneOrSet<MockKeyU8> = OneOrSet::new_set((1..=3).map(MockKeyU8).collect()).unwrap();
        let mut set_iter = set.iter();
        assert_eq!(set_iter.next(), Some(&MockKeyU8(1)));
        assert_eq!(set_iter.next(), Some(&MockKeyU8(2)));
        assert_eq!(set_iter.next(), Some(&MockKeyU8(3)));
        assert_eq!(set_iter.next(), None);
        assert_eq!(set_iter.next(), None);
    }

    #[test]
    fn test_conversions() {
        let one: OneOrSet<MockKeyU8> = OneOrSet::from(MockKeyU8(7));
        assert_eq!(one, OneOrSet::new_one(MockKeyU8(7)));
        assert_eq!(one.as_slice(), &[MockKeyU8(7)]);
        let one_vec: Vec<MockKeyU8> = one.clone().into();
        assert_eq!(one_vec, vec![MockKeyU8(7)]);
        assert_eq!(OrderedSet::from(one), OrderedSet::from_iter([MockKeyU8(7)]));

        let set: OneOrSet<MockKeyU8> =
            OneOrSet::try_from(vec![MockKeyU8(1), MockKeyU8(2)]).unwrap();
        assert_eq!(set.len(), 2);
        assert_eq!(set.as_slice(), &[MockKeyU8(1), MockKeyU8(2)]);
        assert_eq!(set.into_vec(), vec![MockKeyU8(1), MockKeyU8(2)]);

        // A one-element vector collapses into a single item.
        let single: OneOrSet<MockKeyU8> = OneOrSet::try_from(vec![MockKeyU8(3)]).unwrap();
        assert_eq!(single, OneOrSet::new_one(MockKeyU8(3)));

        let empty: Vec<MockKeyU8> = Vec::new();
        assert!(OneOrSet::<MockKeyU8>::try_from(empty).is_err());
    }

    #[test]
    fn test_serde() {
        // A single item serializes as the bare item.
        {
            let one: OneOrSet<MockKeyU8> = OneOrSet::new_one(MockKeyU8(1));
            let ser: String = one.to_json().unwrap();
            let de: OneOrSet<MockKeyU8> = OneOrSet::from_json(&ser).unwrap();
            assert_eq!(ser, "1");
            assert_eq!(de, one);
        }

        // Several items serialize as an array.
        {
            let set: OneOrSet<MockKeyU8> =
                OneOrSet::new_set((1..=3).map(MockKeyU8).collect()).unwrap();
            let ser: String = set.to_json().unwrap();
            let de: OneOrSet<MockKeyU8> = OneOrSet::from_json(&ser).unwrap();
            assert_eq!(ser, "[1,2,3]");
            assert_eq!(de, set);
        }

        // Empty input and empty arrays are rejected.
        {
            let empty: Result<OneOrSet<MockKeyU8>> = OneOrSet::from_json("");
            assert!(empty.is_err());
            let empty_set: Result<OneOrSet<MockKeyU8>> = OneOrSet::from_json("[]");
            assert!(empty_set.is_err());
            let empty_space: Result<OneOrSet<MockKeyU8>> = OneOrSet::from_json("[ ]");
            assert!(empty_space.is_err());
        }
    }
}