1use std::{sync::Arc, time::Duration};
6
7use tokio::{
8    sync::{
9        oneshot::{Receiver, Sender},
10        watch,
11    },
12    task::JoinHandle,
13    time::{Instant, sleep_until},
14};
15use tracing::{debug, warn};
16
17use crate::{
18    block::Round, context::Context, core::CoreSignalsReceivers, core_thread::CoreThreadDispatcher,
19};
20
/// Handle to a running leader timeout task, as returned by `LeaderTimeoutTask::start`.
///
/// Call `stop` to shut the task down and wait for it to finish.
pub(crate) struct LeaderTimeoutTaskHandle {
    /// Join handle of the spawned task; awaited by `stop`.
    handle: JoinHandle<()>,
    /// Sending on (or dropping) this channel makes the task's loop return.
    stop: Sender<()>,
}
25
26impl LeaderTimeoutTaskHandle {
27    pub async fn stop(self) {
28        self.stop.send(()).ok();
29        self.handle.await.ok();
30    }
31}
32
/// Background task that times each round and asks the core to propose a block when the
/// round's timers expire.
///
/// For every observed round it calls `CoreThreadDispatcher::new_block` at most twice:
/// - once with `force = false` after `min_round_delay`;
/// - once with `force = true` after `leader_timeout`.
///
/// Both timers restart whenever a new round is received.
pub(crate) struct LeaderTimeoutTask<D: CoreThreadDispatcher> {
    /// Used to request block proposals from the core thread.
    dispatcher: Arc<D>,
    /// Delivers the latest round; every change restarts both timers.
    new_round_receiver: watch::Receiver<Round>,
    /// Delay after which a forced `new_block` call is made for the current round.
    leader_timeout: Duration,
    /// Delay after which a non-forced `new_block` call is made for the current round.
    min_round_delay: Duration,
    /// Resolves (with a value or because the sender was dropped) when the task must exit.
    stop: Receiver<()>,
}
40
41impl<D: CoreThreadDispatcher> LeaderTimeoutTask<D> {
42    pub fn start(
45        dispatcher: Arc<D>,
46        signals_receivers: &CoreSignalsReceivers,
47        context: Arc<Context>,
48    ) -> LeaderTimeoutTaskHandle {
49        let (stop_sender, stop) = tokio::sync::oneshot::channel();
50        let mut me = Self {
51            dispatcher,
52            stop,
53            new_round_receiver: signals_receivers.new_round_receiver(),
54            leader_timeout: context.parameters.leader_timeout,
55            min_round_delay: context.parameters.min_round_delay,
56        };
57        let handle = tokio::spawn(async move { me.run().await });
58
59        LeaderTimeoutTaskHandle {
60            handle,
61            stop: stop_sender,
62        }
63    }
64
65    async fn run(&mut self) {
72        let new_round = &mut self.new_round_receiver;
73        let mut leader_round: Round = *new_round.borrow_and_update();
74        let mut min_leader_round_timed_out = false;
75        let mut max_leader_round_timed_out = false;
76        let timer_start = Instant::now();
77        let min_leader_timeout = sleep_until(timer_start + self.min_round_delay);
78        let max_leader_timeout = sleep_until(timer_start + self.leader_timeout);
79
80        tokio::pin!(min_leader_timeout);
81        tokio::pin!(max_leader_timeout);
82
83        loop {
84            tokio::select! {
85                () = &mut min_leader_timeout, if !min_leader_round_timed_out => {
89                    if let Err(err) = self.dispatcher.new_block(leader_round, false).await {
90                        warn!("Error received while calling dispatcher, probably dispatcher is shutting down, will now exit: {err:?}");
91                        return;
92                    }
93                    min_leader_round_timed_out = true;
94                },
95                () = &mut max_leader_timeout, if !max_leader_round_timed_out => {
102                    if let Err(err) = self.dispatcher.new_block(leader_round, true).await {
103                        warn!("Error received while calling dispatcher, probably dispatcher is shutting down, will now exit: {err:?}");
104                        return;
105                    }
106                    max_leader_round_timed_out = true;
107                }
108
109                Ok(_) = new_round.changed() => {
111                    leader_round = *new_round.borrow_and_update();
112                    debug!("New round has been received {leader_round}, resetting timer");
113                    let _span = tracing::trace_span!("new_consensus_round_received", round = ?leader_round).entered();
114
115                    min_leader_round_timed_out = false;
116                    max_leader_round_timed_out = false;
117
118                    let now = Instant::now();
119                    min_leader_timeout
120                    .as_mut()
121                    .reset(now + self.min_round_delay);
122                    max_leader_timeout
123                    .as_mut()
124                    .reset(now + self.leader_timeout);
125                },
126                _ = &mut self.stop => {
127                    debug!("Stop signal has been received, now shutting down");
128                    return;
129                }
130            }
131        }
132    }
133}
134
#[cfg(test)]
mod tests {
    use std::{
        collections::{BTreeMap, BTreeSet},
        sync::Arc,
        time::Duration,
    };

    use async_trait::async_trait;
    use consensus_config::{AuthorityIndex, Parameters};
    use parking_lot::Mutex;
    use tokio::time::{Instant, sleep};

    use crate::{
        block::{BlockRef, Round, VerifiedBlock},
        commit::CertifiedCommits,
        context::Context,
        core::CoreSignals,
        core_thread::{CoreError, CoreThreadDispatcher},
        leader_timeout::LeaderTimeoutTask,
        round_prober::QuorumRound,
    };

    /// Dispatcher stub that records every `new_block` call; all other methods are `todo!()`.
    #[derive(Clone, Default)]
    struct MockCoreThreadDispatcher {
        /// `(round, force, time of call)` for each `new_block` invocation, in call order.
        new_block_calls: Arc<Mutex<Vec<(Round, bool, Instant)>>>,
    }

    impl MockCoreThreadDispatcher {
        /// Removes and returns every `new_block` call recorded so far.
        fn get_new_block_calls(&self) -> Vec<(Round, bool, Instant)> {
            self.new_block_calls.lock().drain(..).collect()
        }
    }

    #[async_trait]
    impl CoreThreadDispatcher for MockCoreThreadDispatcher {
        async fn add_blocks(
            &self,
            _blocks: Vec<VerifiedBlock>,
        ) -> Result<BTreeSet<BlockRef>, CoreError> {
            todo!()
        }

        async fn add_certified_commits(
            &self,
            _commits: CertifiedCommits,
        ) -> Result<BTreeSet<BlockRef>, CoreError> {
            todo!()
        }

        async fn check_block_refs(
            &self,
            _block_refs: Vec<BlockRef>,
        ) -> Result<BTreeSet<BlockRef>, CoreError> {
            todo!()
        }

        async fn new_block(&self, round: Round, force: bool) -> Result<(), CoreError> {
            self.new_block_calls
                .lock()
                .push((round, force, Instant::now()));
            Ok(())
        }

        async fn get_missing_blocks(
            &self,
        ) -> Result<BTreeMap<BlockRef, BTreeSet<AuthorityIndex>>, CoreError> {
            todo!()
        }

        fn set_quorum_subscribers_exists(&self, _exists: bool) -> Result<(), CoreError> {
            todo!()
        }

        fn set_propagation_delay_and_quorum_rounds(
            &self,
            _delay: Round,
            _received_quorum_rounds: Vec<QuorumRound>,
            _accepted_quorum_rounds: Vec<QuorumRound>,
        ) -> Result<(), CoreError> {
            todo!()
        }

        fn set_last_known_proposed_round(&self, _round: Round) -> Result<(), CoreError> {
            todo!()
        }

        fn highest_received_rounds(&self) -> Vec<Round> {
            todo!()
        }
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn basic_leader_timeout() {
        let (context, _signers) = Context::new_for_test(4);
        let dispatcher = Arc::new(MockCoreThreadDispatcher::default());
        let leader_timeout = Duration::from_millis(500);
        let min_round_delay = Duration::from_millis(50);
        let parameters = Parameters {
            leader_timeout,
            min_round_delay,
            ..Default::default()
        };
        let context = Arc::new(context.with_parameters(parameters));
        let start = Instant::now();

        let (mut signals, signal_receivers) = CoreSignals::new(context.clone());

        let _handle = LeaderTimeoutTask::start(dispatcher.clone(), &signal_receivers, context);

        signals.new_round(10);

        // Past the min round delay, only the non-forced call should have been made.
        sleep(2 * min_round_delay).await;
        let all_calls = dispatcher.get_new_block_calls();
        assert_eq!(all_calls.len(), 1);

        let (round, force, timestamp) = all_calls[0];
        assert_eq!(round, 10);
        assert!(!force);
        assert!(
            min_round_delay <= timestamp - start,
            "Leader timeout min setting {:?} should be at most the actual time difference {:?}",
            min_round_delay,
            timestamp - start
        );

        // Past the leader timeout, the forced call should have been made.
        sleep(2 * leader_timeout).await;
        let all_calls = dispatcher.get_new_block_calls();
        assert_eq!(all_calls.len(), 1);

        let (round, force, timestamp) = all_calls[0];
        assert_eq!(round, 10);
        assert!(force);
        assert!(
            leader_timeout <= timestamp - start,
            "Leader timeout setting {:?} should be at most the actual time difference {:?}",
            leader_timeout,
            timestamp - start
        );

        // Once both timers fired for this round, no further calls are expected.
        sleep(2 * leader_timeout).await;
        let all_calls = dispatcher.get_new_block_calls();
        assert!(all_calls.is_empty(), "unexpected calls: {all_calls:?}");
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn multiple_leader_timeouts() {
        let (context, _signers) = Context::new_for_test(4);
        let dispatcher = Arc::new(MockCoreThreadDispatcher::default());
        let leader_timeout = Duration::from_millis(500);
        let min_round_delay = Duration::from_millis(50);
        let parameters = Parameters {
            leader_timeout,
            min_round_delay,
            ..Default::default()
        };
        let context = Arc::new(context.with_parameters(parameters));
        let start = Instant::now();

        let (mut signals, signal_receivers) = CoreSignals::new(context.clone());

        let _handle = LeaderTimeoutTask::start(dispatcher.clone(), &signal_receivers, context);

        // Rounds advance faster than min_round_delay. Each new round resets the timers, so
        // nothing should fire for rounds 13 and 14.
        signals.new_round(13);
        sleep(min_round_delay / 2).await;
        signals.new_round(14);
        sleep(min_round_delay / 2).await;
        signals.new_round(15);
        // The task cannot observe round 15 earlier than this point, so its timers expire no
        // sooner than this offset plus the respective delay.
        let last_round_change = Instant::now() - start;
        sleep(2 * leader_timeout).await;

        // Exactly one non-forced and one forced call, both for the latest round.
        let all_calls = dispatcher.get_new_block_calls();
        assert_eq!(all_calls.len(), 2, "unexpected calls: {all_calls:?}");

        let (round, force, timestamp) = all_calls[0];
        assert_eq!(round, 15);
        assert!(!force);
        assert!(last_round_change + min_round_delay <= timestamp - start);

        let (round, force, timestamp) = all_calls[1];
        assert_eq!(round, 15);
        assert!(force);
        assert!(last_round_change + leader_timeout <= timestamp - start);
    }
}