1use std::{
6 collections::{BTreeMap, HashMap, hash_map::Entry},
7 hash::Hash,
8 sync::Arc,
9};
10
11use iota_sdk_types::crypto::Intent;
12use iota_types::{
13 base_types::{AuthorityName, ConciseableName},
14 committee::{Committee, CommitteeTrait, StakeUnit},
15 crypto::{AuthorityQuorumSignInfo, AuthoritySignInfo, AuthoritySignInfoTrait},
16 error::{IotaError, IotaResult},
17 message_envelope::{Envelope, Message},
18};
19use serde::Serialize;
20use tracing::warn;
21use typed_store::TypedStoreError;
22
23#[derive(Debug)]
27pub struct StakeAggregator<S, const STRENGTH: bool> {
28 data: HashMap<AuthorityName, S>,
29 total_votes: StakeUnit,
30 committee: Arc<Committee>,
31}
32
33impl<S: Clone + Eq, const STRENGTH: bool> StakeAggregator<S, STRENGTH> {
40 pub fn new(committee: Arc<Committee>) -> Self {
41 Self {
42 data: Default::default(),
43 total_votes: Default::default(),
44 committee,
45 }
46 }
47
48 pub fn from_iter<I: Iterator<Item = Result<(AuthorityName, S), TypedStoreError>>>(
49 committee: Arc<Committee>,
50 data: I,
51 ) -> IotaResult<Self> {
52 let mut this = Self::new(committee);
53 for item in data {
54 let (authority, s) = item?;
55 this.insert_generic(authority, s);
56 }
57 Ok(this)
58 }
59
60 pub fn insert_generic(
67 &mut self,
68 authority: AuthorityName,
69 s: S,
70 ) -> InsertResult<&HashMap<AuthorityName, S>> {
71 match self.data.entry(authority) {
72 Entry::Occupied(oc) => {
73 return InsertResult::Failed {
74 error: IotaError::StakeAggregatorRepeatedSigner {
75 signer: authority,
76 conflicting_sig: oc.get() != &s,
77 },
78 };
79 }
80 Entry::Vacant(va) => {
81 va.insert(s);
82 }
83 }
84 let votes = self.committee.weight(&authority);
85 if votes > 0 {
86 self.total_votes += votes;
87 if self.total_votes >= self.committee.threshold::<STRENGTH>() {
88 InsertResult::QuorumReached(&self.data)
89 } else {
90 InsertResult::NotEnoughVotes {
91 bad_votes: 0,
92 bad_authorities: vec![],
93 }
94 }
95 } else {
96 InsertResult::Failed {
97 error: IotaError::InvalidAuthenticator,
98 }
99 }
100 }
101
102 pub fn contains_key(&self, authority: &AuthorityName) -> bool {
103 self.data.contains_key(authority)
104 }
105
106 pub fn keys(&self) -> impl Iterator<Item = &AuthorityName> {
107 self.data.keys()
108 }
109
110 pub fn committee(&self) -> &Committee {
111 &self.committee
112 }
113
114 pub fn total_votes(&self) -> StakeUnit {
115 self.total_votes
116 }
117
118 pub fn has_quorum(&self) -> bool {
119 self.total_votes >= self.committee.threshold::<STRENGTH>()
120 }
121
122 pub fn validator_sig_count(&self) -> usize {
123 self.data.len()
124 }
125}
126
127impl<const STRENGTH: bool> StakeAggregator<AuthoritySignInfo, STRENGTH> {
128 pub fn insert<T: Message + Serialize>(
133 &mut self,
134 envelope: Envelope<T, AuthoritySignInfo>,
135 ) -> InsertResult<AuthorityQuorumSignInfo<STRENGTH>> {
136 let (data, sig) = envelope.into_data_and_sig();
137 if self.committee.epoch != sig.epoch {
138 return InsertResult::Failed {
139 error: IotaError::WrongEpoch {
140 expected_epoch: self.committee.epoch,
141 actual_epoch: sig.epoch,
142 },
143 };
144 }
145 match self.insert_generic(sig.authority, sig) {
146 InsertResult::QuorumReached(_) => {
147 match AuthorityQuorumSignInfo::<STRENGTH>::new_from_auth_sign_infos(
148 self.data.values().cloned().collect(),
149 self.committee(),
150 ) {
151 Ok(aggregated) => {
152 match aggregated.verify_secure(
153 &data,
154 Intent::iota_app(T::SCOPE),
155 self.committee(),
156 ) {
157 Ok(_) => InsertResult::QuorumReached(aggregated),
160 Err(_) => {
161 let mut bad_votes = 0;
173 let mut bad_authorities = vec![];
174 for (name, sig) in &self.data.clone() {
175 if let Err(err) = sig.verify_secure(
176 &data,
177 Intent::iota_app(T::SCOPE),
178 self.committee(),
179 ) {
180 warn!(name=?name.concise(), "Bad stake from validator: {:?}", err);
187 self.data.remove(name);
188 let votes = self.committee.weight(name);
189 self.total_votes -= votes;
190 bad_votes += votes;
191 bad_authorities.push(*name);
192 }
193 }
194 InsertResult::NotEnoughVotes {
195 bad_votes,
196 bad_authorities,
197 }
198 }
199 }
200 }
201 Err(error) => InsertResult::Failed { error },
202 }
203 }
204 InsertResult::Failed { error } => InsertResult::Failed { error },
206 InsertResult::NotEnoughVotes {
207 bad_votes,
208 bad_authorities,
209 } => InsertResult::NotEnoughVotes {
210 bad_votes,
211 bad_authorities,
212 },
213 }
214 }
215}
216
217pub enum InsertResult<CertT> {
218 QuorumReached(CertT),
219 Failed {
220 error: IotaError,
221 },
222 NotEnoughVotes {
223 bad_votes: u64,
224 bad_authorities: Vec<AuthorityName>,
225 },
226}
227
228impl<CertT> InsertResult<CertT> {
229 pub fn is_quorum_reached(&self) -> bool {
230 matches!(self, Self::QuorumReached(..))
231 }
232}
233
234#[derive(Debug)]
240pub struct MultiStakeAggregator<K, V, const STRENGTH: bool> {
241 committee: Arc<Committee>,
242 stake_maps: HashMap<K, (V, StakeAggregator<AuthoritySignInfo, STRENGTH>)>,
243}
244
245impl<K, V, const STRENGTH: bool> MultiStakeAggregator<K, V, STRENGTH> {
246 pub fn new(committee: Arc<Committee>) -> Self {
247 Self {
248 committee,
249 stake_maps: Default::default(),
250 }
251 }
252
253 pub fn unique_key_count(&self) -> usize {
254 self.stake_maps.len()
255 }
256
257 pub fn total_votes(&self) -> StakeUnit {
258 self.stake_maps
259 .values()
260 .map(|(_, stake_aggregator)| stake_aggregator.total_votes())
261 .sum()
262 }
263}
264
265impl<K, V, const STRENGTH: bool> MultiStakeAggregator<K, V, STRENGTH>
266where
267 K: Hash + Eq,
268 V: Message + Serialize + Clone,
269{
270 pub fn insert(
271 &mut self,
272 k: K,
273 envelope: Envelope<V, AuthoritySignInfo>,
274 ) -> InsertResult<AuthorityQuorumSignInfo<STRENGTH>> {
275 if let Some(entry) = self.stake_maps.get_mut(&k) {
276 entry.1.insert(envelope)
277 } else {
278 let mut new_entry = StakeAggregator::new(self.committee.clone());
279 let result = new_entry.insert(envelope.clone());
280 if !matches!(result, InsertResult::Failed { .. }) {
281 self.stake_maps.insert(k, (envelope.into_data(), new_entry));
284 }
285 result
286 }
287 }
288}
289
290impl<K, V, const STRENGTH: bool> MultiStakeAggregator<K, V, STRENGTH>
291where
292 K: Clone + Ord,
293{
294 pub fn get_all_unique_values(&self) -> BTreeMap<K, (Vec<AuthorityName>, StakeUnit)> {
295 self.stake_maps
296 .iter()
297 .map(|(k, (_, s))| (k.clone(), (s.data.keys().copied().collect(), s.total_votes)))
298 .collect()
299 }
300}
301
302impl<K, V, const STRENGTH: bool> MultiStakeAggregator<K, V, STRENGTH>
303where
304 K: Hash + Eq,
305{
306 #[expect(dead_code)]
307 pub fn authorities_for_key(&self, k: &K) -> Option<impl Iterator<Item = &AuthorityName>> {
308 self.stake_maps.get(k).map(|(_, agg)| agg.keys())
309 }
310
311 pub fn uncommitted_stake(&self) -> StakeUnit {
314 self.committee.total_votes() - self.total_votes()
315 }
316
317 pub fn plurality_stake(&self) -> StakeUnit {
319 self.stake_maps
320 .values()
321 .map(|(_, agg)| agg.total_votes())
322 .max()
323 .unwrap_or_default()
324 }
325
326 pub fn quorum_unreachable(&self) -> bool {
329 self.uncommitted_stake() + self.plurality_stake() < self.committee.threshold::<STRENGTH>()
330 }
331}
332
333pub struct GenericMultiStakeAggregator<K, const STRENGTH: bool> {
337 committee: Arc<Committee>,
338 stake_maps: HashMap<K, StakeAggregator<(), STRENGTH>>,
339 votes_per_authority: HashMap<AuthorityName, u64>,
340}
341
342impl<K, const STRENGTH: bool> GenericMultiStakeAggregator<K, STRENGTH>
343where
344 K: Hash + Eq,
345{
346 pub fn new(committee: Arc<Committee>) -> Self {
347 Self {
348 committee,
349 stake_maps: Default::default(),
350 votes_per_authority: Default::default(),
351 }
352 }
353
354 pub fn insert(
355 &mut self,
356 authority: AuthorityName,
357 k: K,
358 ) -> InsertResult<&HashMap<AuthorityName, ()>> {
359 let agg = self
360 .stake_maps
361 .entry(k)
362 .or_insert_with(|| StakeAggregator::new(self.committee.clone()));
363
364 if !agg.contains_key(&authority) {
365 *self.votes_per_authority.entry(authority).or_default() += 1;
366 }
367
368 agg.insert_generic(authority, ())
369 }
370
371 pub fn has_quorum_for_key(&self, k: &K) -> bool {
372 if let Some(entry) = self.stake_maps.get(k) {
373 entry.has_quorum()
374 } else {
375 false
376 }
377 }
378
379 pub fn votes_for_authority(&self, authority: AuthorityName) -> u64 {
380 self.votes_per_authority
381 .get(&authority)
382 .copied()
383 .unwrap_or_default()
384 }
385}
386
#[test]
fn test_votes_per_authority() {
    let (committee, _) = Committee::new_simple_test_committee();
    let names: Vec<_> = committee.names().copied().collect();
    let (first, second) = (names[0], names[1]);

    let mut agg = GenericMultiStakeAggregator::<&str, true>::new(Arc::new(committee));

    // A first vote counts once.
    agg.insert(first, "key1");
    assert_eq!(agg.votes_for_authority(first), 1);

    // Voting again for the same key is not counted again.
    for _ in 0..2 {
        agg.insert(first, "key1");
    }
    assert_eq!(agg.votes_for_authority(first), 1);

    // An authority that has not voted has a count of zero.
    assert_eq!(agg.votes_for_authority(second), 0);

    // Counts are tracked independently per authority.
    agg.insert(second, "key2");
    assert_eq!(agg.votes_for_authority(second), 1);
    assert_eq!(agg.votes_for_authority(first), 1);

    // A vote for a new key increments the count.
    agg.insert(first, "key3");
    assert_eq!(agg.votes_for_authority(first), 2);
}