1use std::{sync::Arc, time::Duration};
6
7use tokio::{
8 sync::{
9 oneshot::{Receiver, Sender},
10 watch,
11 },
12 task::JoinHandle,
13 time::{Instant, sleep_until},
14};
15use tracing::{debug, warn};
16
17use crate::{
18 block::Round, context::Context, core::CoreSignalsReceivers, core_thread::CoreThreadDispatcher,
19};
20
/// Handle to a spawned [`LeaderTimeoutTask`], returned by `LeaderTimeoutTask::start`.
pub(crate) struct LeaderTimeoutTaskHandle {
    /// Join handle of the spawned task running the timeout loop.
    handle: JoinHandle<()>,
    /// Sender half of the oneshot channel whose receiver the task selects on to exit.
    stop: Sender<()>,
}
25
26impl LeaderTimeoutTaskHandle {
27 pub async fn stop(self) {
28 self.stop.send(()).ok();
29 self.handle.await.ok();
30 }
31}
32
/// Background task that calls `new_block` on the dispatcher when the per-round timers expire.
pub(crate) struct LeaderTimeoutTask<D: CoreThreadDispatcher> {
    /// Dispatcher on which `new_block` is invoked when a timer fires.
    dispatcher: Arc<D>,
    /// Watch channel delivering the latest round; a change resets both timers.
    new_round_receiver: watch::Receiver<Round>,
    /// Delay after which `new_block(round, true)` is called.
    leader_timeout: Duration,
    /// Delay after which `new_block(round, false)` is called.
    min_round_delay: Duration,
    /// Receiver half of the stop channel; the loop returns once it resolves.
    stop: Receiver<()>,
}
40
41impl<D: CoreThreadDispatcher> LeaderTimeoutTask<D> {
42 pub fn start(
45 dispatcher: Arc<D>,
46 signals_receivers: &CoreSignalsReceivers,
47 context: Arc<Context>,
48 ) -> LeaderTimeoutTaskHandle {
49 let (stop_sender, stop) = tokio::sync::oneshot::channel();
50 let mut me = Self {
51 dispatcher,
52 stop,
53 new_round_receiver: signals_receivers.new_round_receiver(),
54 leader_timeout: context.parameters.leader_timeout,
55 min_round_delay: context.parameters.min_round_delay,
56 };
57 let handle = tokio::spawn(async move { me.run().await });
58
59 LeaderTimeoutTaskHandle {
60 handle,
61 stop: stop_sender,
62 }
63 }
64
65 async fn run(&mut self) {
72 let new_round = &mut self.new_round_receiver;
73 let mut leader_round: Round = *new_round.borrow_and_update();
74 let mut min_leader_round_timed_out = false;
75 let mut max_leader_round_timed_out = false;
76 let timer_start = Instant::now();
77 let min_leader_timeout = sleep_until(timer_start + self.min_round_delay);
78 let max_leader_timeout = sleep_until(timer_start + self.leader_timeout);
79
80 tokio::pin!(min_leader_timeout);
81 tokio::pin!(max_leader_timeout);
82
83 loop {
84 tokio::select! {
85 () = &mut min_leader_timeout, if !min_leader_round_timed_out => {
89 if let Err(err) = self.dispatcher.new_block(leader_round, false).await {
90 warn!("Error received while calling dispatcher, probably dispatcher is shutting down, will now exit: {err:?}");
91 return;
92 }
93 min_leader_round_timed_out = true;
94 },
95 () = &mut max_leader_timeout, if !max_leader_round_timed_out => {
102 if let Err(err) = self.dispatcher.new_block(leader_round, true).await {
103 warn!("Error received while calling dispatcher, probably dispatcher is shutting down, will now exit: {err:?}");
104 return;
105 }
106 max_leader_round_timed_out = true;
107 }
108
109 Ok(_) = new_round.changed() => {
111 leader_round = *new_round.borrow_and_update();
112 debug!("New round has been received {leader_round}, resetting timer");
113 let _span = tracing::trace_span!("new_consensus_round_received", round = ?leader_round).entered();
114
115 min_leader_round_timed_out = false;
116 max_leader_round_timed_out = false;
117
118 let now = Instant::now();
119 min_leader_timeout
120 .as_mut()
121 .reset(now + self.min_round_delay);
122 max_leader_timeout
123 .as_mut()
124 .reset(now + self.leader_timeout);
125 },
126 _ = &mut self.stop => {
127 debug!("Stop signal has been received, now shutting down");
128 return;
129 }
130 }
131 }
132 }
133}
134
135#[cfg(test)]
136mod tests {
137 use std::{
138 collections::{BTreeMap, BTreeSet},
139 sync::Arc,
140 time::Duration,
141 };
142
143 use async_trait::async_trait;
144 use consensus_config::{AuthorityIndex, Parameters};
145 use parking_lot::Mutex;
146 use tokio::time::{Instant, sleep};
147
148 use crate::{
149 block::{BlockRef, Round, VerifiedBlock},
150 commit::CertifiedCommits,
151 context::Context,
152 core::CoreSignals,
153 core_thread::{CoreError, CoreThreadDispatcher},
154 leader_timeout::LeaderTimeoutTask,
155 round_prober::QuorumRound,
156 };
157
    /// Test double for `CoreThreadDispatcher` that records every `new_block` call.
    #[derive(Clone, Default)]
    struct MockCoreThreadDispatcher {
        /// One `(round, force, time of call)` entry per `new_block` invocation, in call order.
        new_block_calls: Arc<Mutex<Vec<(Round, bool, Instant)>>>,
    }
162
163 impl MockCoreThreadDispatcher {
164 async fn get_new_block_calls(&self) -> Vec<(Round, bool, Instant)> {
165 let mut binding = self.new_block_calls.lock();
166 let all_calls = binding.drain(0..);
167 all_calls.into_iter().collect()
168 }
169 }
170
171 #[async_trait]
172 impl CoreThreadDispatcher for MockCoreThreadDispatcher {
173 async fn add_blocks(
174 &self,
175 _blocks: Vec<VerifiedBlock>,
176 ) -> Result<BTreeSet<BlockRef>, CoreError> {
177 todo!()
178 }
179
180 async fn add_certified_commits(
181 &self,
182 _commits: CertifiedCommits,
183 ) -> Result<BTreeSet<BlockRef>, CoreError> {
184 todo!()
185 }
186
187 async fn check_block_refs(
188 &self,
189 _block_refs: Vec<BlockRef>,
190 ) -> Result<BTreeSet<BlockRef>, CoreError> {
191 todo!()
192 }
193
194 async fn new_block(&self, round: Round, force: bool) -> Result<(), CoreError> {
195 self.new_block_calls
196 .lock()
197 .push((round, force, Instant::now()));
198 Ok(())
199 }
200
201 async fn get_missing_blocks(
202 &self,
203 ) -> Result<BTreeMap<BlockRef, BTreeSet<AuthorityIndex>>, CoreError> {
204 todo!()
205 }
206
207 fn set_quorum_subscribers_exists(&self, _exists: bool) -> Result<(), CoreError> {
208 todo!()
209 }
210
211 fn set_propagation_delay_and_quorum_rounds(
212 &self,
213 _delay: Round,
214 _received_quorum_rounds: Vec<QuorumRound>,
215 _accepted_quorum_rounds: Vec<QuorumRound>,
216 ) -> Result<(), CoreError> {
217 todo!()
218 }
219
220 fn set_last_known_proposed_round(&self, _round: Round) -> Result<(), CoreError> {
221 todo!()
222 }
223
224 fn highest_received_rounds(&self) -> Vec<Round> {
225 todo!()
226 }
227 }
228
229 #[tokio::test(flavor = "current_thread", start_paused = true)]
230 async fn basic_leader_timeout() {
231 let (context, _signers) = Context::new_for_test(4);
232 let dispatcher = Arc::new(MockCoreThreadDispatcher::default());
233 let leader_timeout = Duration::from_millis(500);
234 let min_round_delay = Duration::from_millis(50);
235 let parameters = Parameters {
236 leader_timeout,
237 min_round_delay,
238 ..Default::default()
239 };
240 let context = Arc::new(context.with_parameters(parameters));
241 let start = Instant::now();
242
243 let (mut signals, signal_receivers) = CoreSignals::new(context.clone());
244
245 let _handle = LeaderTimeoutTask::start(dispatcher.clone(), &signal_receivers, context);
247
248 signals.new_round(10);
250
251 sleep(2 * min_round_delay).await;
254 let all_calls = dispatcher.get_new_block_calls().await;
255 assert_eq!(all_calls.len(), 1);
256
257 let (round, force, timestamp) = all_calls[0];
258 assert_eq!(round, 10);
259 assert!(!force);
260 assert!(
261 min_round_delay <= timestamp - start,
262 "Leader timeout min setting {:?} should be less than actual time difference {:?}",
263 min_round_delay,
264 timestamp - start
265 );
266
267 sleep(2 * leader_timeout).await;
269 let all_calls = dispatcher.get_new_block_calls().await;
270 assert_eq!(all_calls.len(), 1);
271
272 let (round, force, timestamp) = all_calls[0];
273 assert_eq!(round, 10);
274 assert!(force);
275 assert!(
276 leader_timeout <= timestamp - start,
277 "Leader timeout setting {:?} should be less than actual time difference {:?}",
278 leader_timeout,
279 timestamp - start
280 );
281
282 sleep(2 * leader_timeout).await;
284 let all_calls = dispatcher.get_new_block_calls().await;
285
286 assert_eq!(all_calls.len(), 0);
287 }
288
289 #[tokio::test(flavor = "current_thread", start_paused = true)]
290 async fn multiple_leader_timeouts() {
291 let (context, _signers) = Context::new_for_test(4);
292 let dispatcher = Arc::new(MockCoreThreadDispatcher::default());
293 let leader_timeout = Duration::from_millis(500);
294 let min_round_delay = Duration::from_millis(50);
295 let parameters = Parameters {
296 leader_timeout,
297 min_round_delay,
298 ..Default::default()
299 };
300 let context = Arc::new(context.with_parameters(parameters));
301 let now = Instant::now();
302
303 let (mut signals, signal_receivers) = CoreSignals::new(context.clone());
304
305 let _handle = LeaderTimeoutTask::start(dispatcher.clone(), &signal_receivers, context);
307
308 signals.new_round(13);
311 sleep(min_round_delay / 2).await;
312 signals.new_round(14);
313 sleep(min_round_delay / 2).await;
314 signals.new_round(15);
315 sleep(2 * leader_timeout).await;
316
317 let all_calls = dispatcher.get_new_block_calls().await;
319 let (round, force, timestamp) = all_calls[0];
320 assert_eq!(round, 15);
321 assert!(!force);
322 assert!(min_round_delay < timestamp - now);
323
324 let (round, force, timestamp) = all_calls[1];
325 assert_eq!(round, 15);
326 assert!(force);
327 assert!(leader_timeout < timestamp - now);
328 }
329}