consensus_core/
stake_aggregator.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{collections::BTreeSet, marker::PhantomData};
6
7use consensus_config::{AuthorityIndex, Committee, Stake};
8
9pub(crate) trait CommitteeThreshold {
10    fn is_threshold(committee: &Committee, amount: Stake) -> bool;
11    fn threshold(committee: &Committee) -> Stake;
12}
13
/// Marker type whose [`CommitteeThreshold`] implementation forwards to the
/// committee's quorum checks.
pub(crate) struct QuorumThreshold;
15
/// Marker type whose [`CommitteeThreshold`] implementation forwards to the
/// committee's validity checks. Only compiled in test builds.
#[cfg(test)]
pub(crate) struct ValidityThreshold;
18
19impl CommitteeThreshold for QuorumThreshold {
20    fn is_threshold(committee: &Committee, amount: Stake) -> bool {
21        committee.reached_quorum(amount)
22    }
23    fn threshold(committee: &Committee) -> Stake {
24        committee.quorum_threshold()
25    }
26}
27
/// Validity flavour of the threshold rule (test builds only): both functions
/// forward to the committee's validity helpers.
#[cfg(test)]
impl CommitteeThreshold for ValidityThreshold {
    /// Returns the committee's validity threshold.
    fn threshold(committee: &Committee) -> Stake {
        committee.validity_threshold()
    }

    /// Reports whether `amount` of stake has reached validity in `committee`.
    fn is_threshold(committee: &Committee, amount: Stake) -> bool {
        committee.reached_validity(amount)
    }
}
37
38pub(crate) struct StakeAggregator<T> {
39    votes: BTreeSet<AuthorityIndex>,
40    stake: Stake,
41    _phantom: PhantomData<T>,
42}
43
44impl<T: CommitteeThreshold> StakeAggregator<T> {
45    pub(crate) fn new() -> Self {
46        Self {
47            votes: Default::default(),
48            stake: 0,
49            _phantom: Default::default(),
50        }
51    }
52
53    /// Adds a vote for the specified authority index to the aggregator. It is
54    /// guaranteed to count the vote only once for an authority. The method
55    /// returns true when the required threshold has been reached.
56    pub(crate) fn add(&mut self, vote: AuthorityIndex, committee: &Committee) -> bool {
57        if self.votes.insert(vote) {
58            self.stake += committee.stake(vote);
59        }
60        T::is_threshold(committee, self.stake)
61    }
62
63    pub(crate) fn stake(&self) -> Stake {
64        self.stake
65    }
66
67    pub(crate) fn reached_threshold(&self, committee: &Committee) -> bool {
68        T::is_threshold(committee, self.stake)
69    }
70
71    pub(crate) fn threshold(&self, committee: &Committee) -> Stake {
72        T::threshold(committee)
73    }
74
75    pub(crate) fn clear(&mut self) {
76        self.votes.clear();
77        self.stake = 0;
78    }
79}
80
#[cfg(test)]
mod tests {
    use consensus_config::{AuthorityIndex, local_committee_and_keys};

    use super::*;

    /// With four equally staked authorities, the quorum aggregator first
    /// reports success on the third distinct vote and keeps reporting it.
    #[test]
    fn test_aggregator_quorum_threshold() {
        let committee = local_committee_and_keys(0, vec![1, 1, 1, 1]).0;
        let mut aggregator = StakeAggregator::<QuorumThreshold>::new();

        // (authority index, expected return value of `add`)
        for (index, expected) in [(0, false), (1, false), (2, true), (3, true)] {
            let reached = aggregator.add(AuthorityIndex::new_for_test(index), &committee);
            assert_eq!(
                reached, expected,
                "unexpected result after the vote of authority {}",
                index
            );
        }
    }

    /// With four equally staked authorities, the validity aggregator reports
    /// success on the second distinct vote.
    #[test]
    fn test_aggregator_validity_threshold() {
        let committee = local_committee_and_keys(0, vec![1, 1, 1, 1]).0;
        let mut aggregator = StakeAggregator::<ValidityThreshold>::new();

        // (authority index, expected return value of `add`)
        for (index, expected) in [(0, false), (1, true)] {
            let reached = aggregator.add(AuthorityIndex::new_for_test(index), &committee);
            assert_eq!(
                reached, expected,
                "unexpected result after the vote of authority {}",
                index
            );
        }
    }

    /// After `clear`, the aggregator behaves like a fresh one: the same votes
    /// produce the same sequence of results.
    #[test]
    fn test_aggregator_clear() {
        let committee = local_committee_and_keys(0, vec![1, 1, 1, 1]).0;
        let mut aggregator = StakeAggregator::<ValidityThreshold>::new();

        // (authority index, expected return value of `add`)
        let votes = [(0, false), (1, true)];

        for (index, expected) in votes {
            let reached = aggregator.add(AuthorityIndex::new_for_test(index), &committee);
            assert_eq!(
                reached, expected,
                "unexpected result before clear for authority {}",
                index
            );
        }

        // clear the aggregator
        aggregator.clear();

        // now add them again
        for (index, expected) in votes {
            let reached = aggregator.add(AuthorityIndex::new_for_test(index), &committee);
            assert_eq!(
                reached, expected,
                "unexpected result after clear for authority {}",
                index
            );
        }
    }
}