1use std::{sync::Arc, time::Duration};
6
7use tokio::{
8 sync::{
9 oneshot::{Receiver, Sender},
10 watch,
11 },
12 task::JoinHandle,
13 time::{Instant, sleep_until},
14};
15use tracing::{debug, warn};
16
17use crate::{
18 block::Round, context::Context, core::CoreSignalsReceivers, core_thread::CoreThreadDispatcher,
19};
20
/// Handle to the task spawned by [`LeaderTimeoutTask::start`].
///
/// Call [`LeaderTimeoutTaskHandle::stop`] to shut the task down and wait for it.
/// Dropping the handle also drops `stop`, which resolves the task's stop
/// receiver and likewise ends its loop.
pub(crate) struct LeaderTimeoutTaskHandle {
    /// Join handle of the spawned `run` loop; awaited by `stop`.
    handle: JoinHandle<()>,
    /// Sending half of the oneshot channel the task listens on for shutdown.
    stop: Sender<()>,
}
25
26impl LeaderTimeoutTaskHandle {
27 pub async fn stop(self) {
28 self.stop.send(()).ok();
29 self.handle.await.ok();
30 }
31}
32
/// Background task that, for the latest observed round, calls the dispatcher's
/// `new_block` once after `min_round_delay` (with `force = false`) and once after
/// `leader_timeout` (with `force = true`), unless a newer round arrives first.
pub(crate) struct LeaderTimeoutTask<D: CoreThreadDispatcher> {
    /// Target of the `new_block` calls issued when a timer fires.
    dispatcher: Arc<D>,
    /// Watch channel carrying the latest round; every change restarts both timers.
    new_round_receiver: watch::Receiver<Round>,
    /// Delay after which `new_block` is called with `force = true`.
    leader_timeout: Duration,
    /// Delay after which `new_block` is called with `force = false`.
    min_round_delay: Duration,
    /// Resolves when the handle sends the stop signal or drops its sender.
    stop: Receiver<()>,
}
40
41impl<D: CoreThreadDispatcher> LeaderTimeoutTask<D> {
42 pub fn start(
45 dispatcher: Arc<D>,
46 signals_receivers: &CoreSignalsReceivers,
47 context: Arc<Context>,
48 ) -> LeaderTimeoutTaskHandle {
49 let (stop_sender, stop) = tokio::sync::oneshot::channel();
50 let mut me = Self {
51 dispatcher,
52 stop,
53 new_round_receiver: signals_receivers.new_round_receiver(),
54 leader_timeout: context.parameters.leader_timeout,
55 min_round_delay: context.parameters.min_round_delay,
56 };
57 let handle = tokio::spawn(async move { me.run().await });
58
59 LeaderTimeoutTaskHandle {
60 handle,
61 stop: stop_sender,
62 }
63 }
64
65 async fn run(&mut self) {
72 let new_round = &mut self.new_round_receiver;
73 let mut leader_round: Round = *new_round.borrow_and_update();
74 let mut min_leader_round_timed_out = false;
75 let mut max_leader_round_timed_out = false;
76 let timer_start = Instant::now();
77 let min_leader_timeout = sleep_until(timer_start + self.min_round_delay);
78 let max_leader_timeout = sleep_until(timer_start + self.leader_timeout);
79
80 tokio::pin!(min_leader_timeout);
81 tokio::pin!(max_leader_timeout);
82
83 loop {
84 tokio::select! {
85 () = &mut min_leader_timeout, if !min_leader_round_timed_out => {
89 if let Err(err) = self.dispatcher.new_block(leader_round, false).await {
90 warn!("Error received while calling dispatcher, probably dispatcher is shutting down, will now exit: {err:?}");
91 return;
92 }
93 min_leader_round_timed_out = true;
94 },
95 () = &mut max_leader_timeout, if !max_leader_round_timed_out => {
102 if let Err(err) = self.dispatcher.new_block(leader_round, true).await {
103 warn!("Error received while calling dispatcher, probably dispatcher is shutting down, will now exit: {err:?}");
104 return;
105 }
106 max_leader_round_timed_out = true;
107 }
108
109 Ok(_) = new_round.changed() => {
111 leader_round = *new_round.borrow_and_update();
112 debug!("New round has been received {leader_round}, resetting timer");
113
114 min_leader_round_timed_out = false;
115 max_leader_round_timed_out = false;
116
117 let now = Instant::now();
118 min_leader_timeout
119 .as_mut()
120 .reset(now + self.min_round_delay);
121 max_leader_timeout
122 .as_mut()
123 .reset(now + self.leader_timeout);
124 },
125 _ = &mut self.stop => {
126 debug!("Stop signal has been received, now shutting down");
127 return;
128 }
129 }
130 }
131 }
132}
133
#[cfg(test)]
mod tests {
    use std::{collections::BTreeSet, sync::Arc, time::Duration};

    use async_trait::async_trait;
    use consensus_config::Parameters;
    use parking_lot::Mutex;
    use tokio::time::{Instant, sleep};

    use crate::{
        block::{BlockRef, Round, VerifiedBlock},
        commit::CertifiedCommits,
        context::Context,
        core::CoreSignals,
        core_thread::{CoreError, CoreThreadDispatcher},
        leader_timeout::LeaderTimeoutTask,
        round_prober::QuorumRound,
    };

    /// Dispatcher stub that records every `new_block` call as
    /// `(round, force, tokio instant of the call)`; all other methods are `todo!()`.
    #[derive(Clone, Default)]
    struct MockCoreThreadDispatcher {
        new_block_calls: Arc<Mutex<Vec<(Round, bool, Instant)>>>,
    }

    impl MockCoreThreadDispatcher {
        /// Returns the `new_block` calls recorded so far and clears the log, so
        /// each call only sees what happened since the previous one.
        async fn get_new_block_calls(&self) -> Vec<(Round, bool, Instant)> {
            std::mem::take(&mut *self.new_block_calls.lock())
        }
    }

    #[async_trait]
    impl CoreThreadDispatcher for MockCoreThreadDispatcher {
        async fn add_blocks(
            &self,
            _blocks: Vec<VerifiedBlock>,
        ) -> Result<BTreeSet<BlockRef>, CoreError> {
            todo!()
        }

        async fn add_certified_commits(
            &self,
            _commits: CertifiedCommits,
        ) -> Result<BTreeSet<BlockRef>, CoreError> {
            todo!()
        }

        async fn check_block_refs(
            &self,
            _block_refs: Vec<BlockRef>,
        ) -> Result<BTreeSet<BlockRef>, CoreError> {
            todo!()
        }

        async fn new_block(&self, round: Round, force: bool) -> Result<(), CoreError> {
            self.new_block_calls
                .lock()
                .push((round, force, Instant::now()));
            Ok(())
        }

        async fn get_missing_blocks(&self) -> Result<BTreeSet<BlockRef>, CoreError> {
            todo!()
        }

        fn set_subscriber_exists(&self, _exists: bool) -> Result<(), CoreError> {
            todo!()
        }

        fn set_propagation_delay_and_quorum_rounds(
            &self,
            _delay: Round,
            _received_quorum_rounds: Vec<QuorumRound>,
            _accepted_quorum_rounds: Vec<QuorumRound>,
        ) -> Result<(), CoreError> {
            todo!()
        }

        fn set_last_known_proposed_round(&self, _round: Round) -> Result<(), CoreError> {
            todo!()
        }

        fn highest_received_rounds(&self) -> Vec<Round> {
            todo!()
        }
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn basic_leader_timeout() {
        let (context, _signers) = Context::new_for_test(4);
        let dispatcher = Arc::new(MockCoreThreadDispatcher::default());
        let leader_timeout = Duration::from_millis(500);
        let min_round_delay = Duration::from_millis(50);
        let parameters = Parameters {
            leader_timeout,
            min_round_delay,
            ..Default::default()
        };
        let context = Arc::new(context.with_parameters(parameters));
        let start = Instant::now();

        let (mut signals, signal_receivers) = CoreSignals::new(context.clone());

        let _handle = LeaderTimeoutTask::start(dispatcher.clone(), &signal_receivers, context);

        signals.new_round(10);

        // After the min round delay only the non-forced request has been issued.
        sleep(2 * min_round_delay).await;
        let all_calls = dispatcher.get_new_block_calls().await;
        assert_eq!(all_calls.len(), 1);

        let (round, force, timestamp) = all_calls[0];
        assert_eq!(round, 10);
        assert!(!force);
        assert!(
            min_round_delay <= timestamp - start,
            "Min round delay {:?} should not exceed the actual time difference {:?}",
            min_round_delay,
            timestamp - start
        );

        // After the leader timeout the forced request follows.
        sleep(2 * leader_timeout).await;
        let all_calls = dispatcher.get_new_block_calls().await;
        assert_eq!(all_calls.len(), 1);

        let (round, force, timestamp) = all_calls[0];
        assert_eq!(round, 10);
        assert!(force);
        assert!(
            leader_timeout <= timestamp - start,
            "Leader timeout setting {:?} should not exceed the actual time difference {:?}",
            leader_timeout,
            timestamp - start
        );

        // Both timers have fired for round 10, so nothing else is requested.
        sleep(2 * leader_timeout).await;
        let all_calls = dispatcher.get_new_block_calls().await;

        assert_eq!(all_calls.len(), 0);
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn multiple_leader_timeouts() {
        let (context, _signers) = Context::new_for_test(4);
        let dispatcher = Arc::new(MockCoreThreadDispatcher::default());
        let leader_timeout = Duration::from_millis(500);
        let min_round_delay = Duration::from_millis(50);
        let parameters = Parameters {
            leader_timeout,
            min_round_delay,
            ..Default::default()
        };
        let context = Arc::new(context.with_parameters(parameters));
        let start = Instant::now();

        let (mut signals, signal_receivers) = CoreSignals::new(context.clone());

        let _handle = LeaderTimeoutTask::start(dispatcher.clone(), &signal_receivers, context);

        // Rounds 13 and 14 are superseded before their timers can fire.
        signals.new_round(13);
        sleep(min_round_delay / 2).await;
        signals.new_round(14);
        sleep(min_round_delay / 2).await;
        signals.new_round(15);
        sleep(2 * leader_timeout).await;

        let all_calls = dispatcher.get_new_block_calls().await;
        // Exactly one non-forced and one forced request, both for the last round.
        assert_eq!(all_calls.len(), 2, "unexpected calls: {all_calls:?}");

        let (round, force, timestamp) = all_calls[0];
        assert_eq!(round, 15);
        assert!(!force);
        assert!(min_round_delay < timestamp - start);

        let (round, force, timestamp) = all_calls[1];
        assert_eq!(round, 15);
        assert!(force);
        assert!(leader_timeout < timestamp - start);
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn stop_terminates_task() {
        let (context, _signers) = Context::new_for_test(4);
        let dispatcher = Arc::new(MockCoreThreadDispatcher::default());
        let leader_timeout = Duration::from_millis(500);
        let parameters = Parameters {
            leader_timeout,
            min_round_delay: Duration::from_millis(50),
            ..Default::default()
        };
        let context = Arc::new(context.with_parameters(parameters));

        let (_signals, signal_receivers) = CoreSignals::new(context.clone());

        let handle = LeaderTimeoutTask::start(dispatcher.clone(), &signal_receivers, context);
        handle.stop().await;

        // Had the task kept running, both timers for the initial round would have
        // fired within this window.
        sleep(2 * leader_timeout).await;
        let all_calls = dispatcher.get_new_block_calls().await;
        assert!(all_calls.is_empty(), "unexpected calls: {all_calls:?}");
    }
}