1use std::{
6 collections::{BTreeMap, HashMap, hash_map::Entry},
7 hash::Hash,
8 sync::Arc,
9};
10
11use iota_types::{
12 base_types::{AuthorityName, ConciseableName},
13 committee::{Committee, CommitteeTrait, StakeUnit},
14 crypto::{AuthorityQuorumSignInfo, AuthoritySignInfo, AuthoritySignInfoTrait},
15 error::IotaError,
16 message_envelope::{Envelope, Message},
17};
18use serde::Serialize;
19use shared_crypto::intent::Intent;
20use tracing::warn;
21
22#[derive(Debug)]
26pub struct StakeAggregator<S, const STRENGTH: bool> {
27 data: HashMap<AuthorityName, S>,
28 total_votes: StakeUnit,
29 committee: Arc<Committee>,
30}
31
32impl<S: Clone + Eq, const STRENGTH: bool> StakeAggregator<S, STRENGTH> {
39 pub fn new(committee: Arc<Committee>) -> Self {
40 Self {
41 data: Default::default(),
42 total_votes: Default::default(),
43 committee,
44 }
45 }
46
47 pub fn from_iter<I: Iterator<Item = (AuthorityName, S)>>(
48 committee: Arc<Committee>,
49 data: I,
50 ) -> Self {
51 let mut this = Self::new(committee);
52 for (authority, s) in data {
53 this.insert_generic(authority, s);
54 }
55 this
56 }
57
58 pub fn insert_generic(
65 &mut self,
66 authority: AuthorityName,
67 s: S,
68 ) -> InsertResult<&HashMap<AuthorityName, S>> {
69 match self.data.entry(authority) {
70 Entry::Occupied(oc) => {
71 return InsertResult::Failed {
72 error: IotaError::StakeAggregatorRepeatedSigner {
73 signer: authority,
74 conflicting_sig: oc.get() != &s,
75 },
76 };
77 }
78 Entry::Vacant(va) => {
79 va.insert(s);
80 }
81 }
82 let votes = self.committee.weight(&authority);
83 if votes > 0 {
84 self.total_votes += votes;
85 if self.total_votes >= self.committee.threshold::<STRENGTH>() {
86 InsertResult::QuorumReached(&self.data)
87 } else {
88 InsertResult::NotEnoughVotes {
89 bad_votes: 0,
90 bad_authorities: vec![],
91 }
92 }
93 } else {
94 InsertResult::Failed {
95 error: IotaError::InvalidAuthenticator,
96 }
97 }
98 }
99
100 pub fn contains_key(&self, authority: &AuthorityName) -> bool {
101 self.data.contains_key(authority)
102 }
103
104 pub fn keys(&self) -> impl Iterator<Item = &AuthorityName> {
105 self.data.keys()
106 }
107
108 pub fn committee(&self) -> &Committee {
109 &self.committee
110 }
111
112 pub fn total_votes(&self) -> StakeUnit {
113 self.total_votes
114 }
115
116 pub fn has_quorum(&self) -> bool {
117 self.total_votes >= self.committee.threshold::<STRENGTH>()
118 }
119
120 pub fn validator_sig_count(&self) -> usize {
121 self.data.len()
122 }
123}
124
125impl<const STRENGTH: bool> StakeAggregator<AuthoritySignInfo, STRENGTH> {
126 pub fn insert<T: Message + Serialize>(
131 &mut self,
132 envelope: Envelope<T, AuthoritySignInfo>,
133 ) -> InsertResult<AuthorityQuorumSignInfo<STRENGTH>> {
134 let (data, sig) = envelope.into_data_and_sig();
135 if self.committee.epoch != sig.epoch {
136 return InsertResult::Failed {
137 error: IotaError::WrongEpoch {
138 expected_epoch: self.committee.epoch,
139 actual_epoch: sig.epoch,
140 },
141 };
142 }
143 match self.insert_generic(sig.authority, sig) {
144 InsertResult::QuorumReached(_) => {
145 match AuthorityQuorumSignInfo::<STRENGTH>::new_from_auth_sign_infos(
146 self.data.values().cloned().collect(),
147 self.committee(),
148 ) {
149 Ok(aggregated) => {
150 match aggregated.verify_secure(
151 &data,
152 Intent::iota_app(T::SCOPE),
153 self.committee(),
154 ) {
155 Ok(_) => InsertResult::QuorumReached(aggregated),
158 Err(_) => {
159 let mut bad_votes = 0;
171 let mut bad_authorities = vec![];
172 for (name, sig) in &self.data.clone() {
173 if let Err(err) = sig.verify_secure(
174 &data,
175 Intent::iota_app(T::SCOPE),
176 self.committee(),
177 ) {
178 warn!(name=?name.concise(), "Bad stake from validator: {:?}", err);
185 self.data.remove(name);
186 let votes = self.committee.weight(name);
187 self.total_votes -= votes;
188 bad_votes += votes;
189 bad_authorities.push(*name);
190 }
191 }
192 InsertResult::NotEnoughVotes {
193 bad_votes,
194 bad_authorities,
195 }
196 }
197 }
198 }
199 Err(error) => InsertResult::Failed { error },
200 }
201 }
202 InsertResult::Failed { error } => InsertResult::Failed { error },
204 InsertResult::NotEnoughVotes {
205 bad_votes,
206 bad_authorities,
207 } => InsertResult::NotEnoughVotes {
208 bad_votes,
209 bad_authorities,
210 },
211 }
212 }
213}
214
215pub enum InsertResult<CertT> {
216 QuorumReached(CertT),
217 Failed {
218 error: IotaError,
219 },
220 NotEnoughVotes {
221 bad_votes: u64,
222 bad_authorities: Vec<AuthorityName>,
223 },
224}
225
226impl<CertT> InsertResult<CertT> {
227 pub fn is_quorum_reached(&self) -> bool {
228 matches!(self, Self::QuorumReached(..))
229 }
230}
231
232#[derive(Debug)]
238pub struct MultiStakeAggregator<K, V, const STRENGTH: bool> {
239 committee: Arc<Committee>,
240 stake_maps: HashMap<K, (V, StakeAggregator<AuthoritySignInfo, STRENGTH>)>,
241}
242
243impl<K, V, const STRENGTH: bool> MultiStakeAggregator<K, V, STRENGTH> {
244 pub fn new(committee: Arc<Committee>) -> Self {
245 Self {
246 committee,
247 stake_maps: Default::default(),
248 }
249 }
250
251 pub fn unique_key_count(&self) -> usize {
252 self.stake_maps.len()
253 }
254
255 pub fn total_votes(&self) -> StakeUnit {
256 self.stake_maps
257 .values()
258 .map(|(_, stake_aggregator)| stake_aggregator.total_votes())
259 .sum()
260 }
261}
262
263impl<K, V, const STRENGTH: bool> MultiStakeAggregator<K, V, STRENGTH>
264where
265 K: Hash + Eq,
266 V: Message + Serialize + Clone,
267{
268 pub fn insert(
269 &mut self,
270 k: K,
271 envelope: Envelope<V, AuthoritySignInfo>,
272 ) -> InsertResult<AuthorityQuorumSignInfo<STRENGTH>> {
273 if let Some(entry) = self.stake_maps.get_mut(&k) {
274 entry.1.insert(envelope)
275 } else {
276 let mut new_entry = StakeAggregator::new(self.committee.clone());
277 let result = new_entry.insert(envelope.clone());
278 if !matches!(result, InsertResult::Failed { .. }) {
279 self.stake_maps.insert(k, (envelope.into_data(), new_entry));
282 }
283 result
284 }
285 }
286}
287
288impl<K, V, const STRENGTH: bool> MultiStakeAggregator<K, V, STRENGTH>
289where
290 K: Clone + Ord,
291{
292 pub fn get_all_unique_values(&self) -> BTreeMap<K, (Vec<AuthorityName>, StakeUnit)> {
293 self.stake_maps
294 .iter()
295 .map(|(k, (_, s))| (k.clone(), (s.data.keys().copied().collect(), s.total_votes)))
296 .collect()
297 }
298}
299
300impl<K, V, const STRENGTH: bool> MultiStakeAggregator<K, V, STRENGTH>
301where
302 K: Hash + Eq,
303{
304 #[expect(dead_code)]
305 pub fn authorities_for_key(&self, k: &K) -> Option<impl Iterator<Item = &AuthorityName>> {
306 self.stake_maps.get(k).map(|(_, agg)| agg.keys())
307 }
308
309 pub fn uncommitted_stake(&self) -> StakeUnit {
312 self.committee.total_votes() - self.total_votes()
313 }
314
315 pub fn plurality_stake(&self) -> StakeUnit {
317 self.stake_maps
318 .values()
319 .map(|(_, agg)| agg.total_votes())
320 .max()
321 .unwrap_or_default()
322 }
323
324 pub fn quorum_unreachable(&self) -> bool {
327 self.uncommitted_stake() + self.plurality_stake() < self.committee.threshold::<STRENGTH>()
328 }
329}
330
331pub struct GenericMultiStakeAggregator<K, const STRENGTH: bool> {
335 committee: Arc<Committee>,
336 stake_maps: HashMap<K, StakeAggregator<(), STRENGTH>>,
337 votes_per_authority: HashMap<AuthorityName, u64>,
338}
339
340impl<K, const STRENGTH: bool> GenericMultiStakeAggregator<K, STRENGTH>
341where
342 K: Hash + Eq,
343{
344 pub fn new(committee: Arc<Committee>) -> Self {
345 Self {
346 committee,
347 stake_maps: Default::default(),
348 votes_per_authority: Default::default(),
349 }
350 }
351
352 pub fn insert(
353 &mut self,
354 authority: AuthorityName,
355 k: K,
356 ) -> InsertResult<&HashMap<AuthorityName, ()>> {
357 let agg = self
358 .stake_maps
359 .entry(k)
360 .or_insert_with(|| StakeAggregator::new(self.committee.clone()));
361
362 if !agg.contains_key(&authority) {
363 *self.votes_per_authority.entry(authority).or_default() += 1;
364 }
365
366 agg.insert_generic(authority, ())
367 }
368
369 pub fn has_quorum_for_key(&self, k: &K) -> bool {
370 if let Some(entry) = self.stake_maps.get(k) {
371 entry.has_quorum()
372 } else {
373 false
374 }
375 }
376
377 pub fn votes_for_authority(&self, authority: AuthorityName) -> u64 {
378 self.votes_per_authority
379 .get(&authority)
380 .copied()
381 .unwrap_or_default()
382 }
383}
384
#[test]
fn test_votes_per_authority() {
    let (committee, _) = Committee::new_simple_test_committee();
    let authorities: Vec<_> = committee.names().copied().collect();
    let mut agg: GenericMultiStakeAggregator<&str, true> =
        GenericMultiStakeAggregator::new(Arc::new(committee));

    let (first, second) = (authorities[0], authorities[1]);

    // The first vote for a key is counted.
    agg.insert(first, "key1");
    assert_eq!(agg.votes_for_authority(first), 1);

    // Voting again for the same key does not change the count.
    for _ in 0..2 {
        agg.insert(first, "key1");
    }
    assert_eq!(agg.votes_for_authority(first), 1);

    // An authority that has not voted yet has no votes.
    assert_eq!(agg.votes_for_authority(second), 0);

    // Counts are kept per authority.
    agg.insert(second, "key2");
    assert_eq!(agg.votes_for_authority(second), 1);
    assert_eq!(agg.votes_for_authority(first), 1);

    // A vote for a different key increments the count.
    agg.insert(first, "key3");
    assert_eq!(agg.votes_for_authority(first), 2);
}