1use std::{sync::Arc, time::Duration};
6
7use tokio::{
8 sync::{
9 oneshot::{Receiver, Sender},
10 watch,
11 },
12 task::JoinHandle,
13 time::{Instant, sleep_until},
14};
15use tracing::{debug, warn};
16
17use crate::{
18 block::Round, context::Context, core::CoreSignalsReceivers, core_thread::CoreThreadDispatcher,
19};
20
21pub(crate) struct LeaderTimeoutTaskHandle {
22 handle: JoinHandle<()>,
23 stop: Sender<()>,
24}
25
26impl LeaderTimeoutTaskHandle {
27 pub async fn stop(self) {
28 self.stop.send(()).ok();
29 self.handle.await.ok();
30 }
31}
32
33pub(crate) struct LeaderTimeoutTask<D: CoreThreadDispatcher> {
34 dispatcher: Arc<D>,
35 new_round_receiver: watch::Receiver<Round>,
36 leader_timeout: Duration,
37 min_round_delay: Duration,
38 stop: Receiver<()>,
39}
40
41impl<D: CoreThreadDispatcher> LeaderTimeoutTask<D> {
42 pub fn start(
45 dispatcher: Arc<D>,
46 signals_receivers: &CoreSignalsReceivers,
47 context: Arc<Context>,
48 ) -> LeaderTimeoutTaskHandle {
49 let (stop_sender, stop) = tokio::sync::oneshot::channel();
50 let mut me = Self {
51 dispatcher,
52 stop,
53 new_round_receiver: signals_receivers.new_round_receiver(),
54 leader_timeout: context.parameters.leader_timeout,
55 min_round_delay: context.parameters.min_round_delay,
56 };
57 let handle = tokio::spawn(async move { me.run().await });
58
59 LeaderTimeoutTaskHandle {
60 handle,
61 stop: stop_sender,
62 }
63 }
64
65 async fn run(&mut self) {
72 let new_round = &mut self.new_round_receiver;
73 let mut leader_round: Round = *new_round.borrow_and_update();
74 let mut min_leader_round_timed_out = false;
75 let mut max_leader_round_timed_out = false;
76 let timer_start = Instant::now();
77 let min_leader_timeout = sleep_until(timer_start + self.min_round_delay);
78 let max_leader_timeout = sleep_until(timer_start + self.leader_timeout);
79
80 tokio::pin!(min_leader_timeout);
81 tokio::pin!(max_leader_timeout);
82
83 loop {
84 tokio::select! {
85 () = &mut min_leader_timeout, if !min_leader_round_timed_out => {
89 if let Err(err) = self.dispatcher.new_block(leader_round, false).await {
90 warn!("Error received while calling dispatcher, probably dispatcher is shutting down, will now exit: {err:?}");
91 return;
92 }
93 min_leader_round_timed_out = true;
94 },
95 () = &mut max_leader_timeout, if !max_leader_round_timed_out => {
102 if let Err(err) = self.dispatcher.new_block(leader_round, true).await {
103 warn!("Error received while calling dispatcher, probably dispatcher is shutting down, will now exit: {err:?}");
104 return;
105 }
106 max_leader_round_timed_out = true;
107 }
108
109 Ok(_) = new_round.changed() => {
111 leader_round = *new_round.borrow_and_update();
112 debug!("New round has been received {leader_round}, resetting timer");
113
114 min_leader_round_timed_out = false;
115 max_leader_round_timed_out = false;
116
117 let now = Instant::now();
118 min_leader_timeout
119 .as_mut()
120 .reset(now + self.min_round_delay);
121 max_leader_timeout
122 .as_mut()
123 .reset(now + self.leader_timeout);
124 },
125 _ = &mut self.stop => {
126 debug!("Stop signal has been received, now shutting down");
127 return;
128 }
129 }
130 }
131 }
132}
133
#[cfg(test)]
mod tests {
    use std::{
        collections::{BTreeMap, BTreeSet},
        sync::Arc,
        time::Duration,
    };

    use async_trait::async_trait;
    use consensus_config::{AuthorityIndex, Parameters};
    use parking_lot::Mutex;
    use tokio::time::{Instant, sleep};

    use crate::{
        block::{BlockRef, Round, VerifiedBlock},
        commit::CertifiedCommits,
        context::Context,
        core::CoreSignals,
        core_thread::{CoreError, CoreThreadDispatcher},
        leader_timeout::LeaderTimeoutTask,
        round_prober::QuorumRound,
    };

    /// Dispatcher double that records `new_block` calls. `LeaderTimeoutTask` only calls
    /// `new_block`, so every other method is left as `todo!()`.
    #[derive(Clone, Default)]
    struct MockCoreThreadDispatcher {
        /// `(round, force, time of call)` for each `new_block` invocation, oldest first.
        new_block_calls: Arc<Mutex<Vec<(Round, bool, Instant)>>>,
    }

    impl MockCoreThreadDispatcher {
        /// Returns every `new_block` call recorded so far and clears the log.
        async fn get_new_block_calls(&self) -> Vec<(Round, bool, Instant)> {
            std::mem::take(&mut *self.new_block_calls.lock())
        }
    }

    #[async_trait]
    impl CoreThreadDispatcher for MockCoreThreadDispatcher {
        async fn add_blocks(
            &self,
            _blocks: Vec<VerifiedBlock>,
        ) -> Result<BTreeSet<BlockRef>, CoreError> {
            todo!()
        }

        async fn add_certified_commits(
            &self,
            _commits: CertifiedCommits,
        ) -> Result<BTreeSet<BlockRef>, CoreError> {
            todo!()
        }

        async fn check_block_refs(
            &self,
            _block_refs: Vec<BlockRef>,
        ) -> Result<BTreeSet<BlockRef>, CoreError> {
            todo!()
        }

        async fn new_block(&self, round: Round, force: bool) -> Result<(), CoreError> {
            self.new_block_calls
                .lock()
                .push((round, force, Instant::now()));
            Ok(())
        }

        async fn get_missing_blocks(
            &self,
        ) -> Result<BTreeMap<BlockRef, BTreeSet<AuthorityIndex>>, CoreError> {
            todo!()
        }

        fn set_quorum_subscribers_exists(&self, _exists: bool) -> Result<(), CoreError> {
            todo!()
        }

        fn set_propagation_delay_and_quorum_rounds(
            &self,
            _delay: Round,
            _received_quorum_rounds: Vec<QuorumRound>,
            _accepted_quorum_rounds: Vec<QuorumRound>,
        ) -> Result<(), CoreError> {
            todo!()
        }

        fn set_last_known_proposed_round(&self, _round: Round) -> Result<(), CoreError> {
            todo!()
        }

        fn highest_received_rounds(&self) -> Vec<Round> {
            todo!()
        }
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn basic_leader_timeout() {
        let (context, _signers) = Context::new_for_test(4);
        let dispatcher = Arc::new(MockCoreThreadDispatcher::default());
        let leader_timeout = Duration::from_millis(500);
        let min_round_delay = Duration::from_millis(50);
        let parameters = Parameters {
            leader_timeout,
            min_round_delay,
            ..Default::default()
        };
        let context = Arc::new(context.with_parameters(parameters));
        let start = Instant::now();

        let (mut signals, signal_receivers) = CoreSignals::new(context.clone());

        // Keep the handle alive: dropping it closes the stop channel and ends the task.
        let _handle = LeaderTimeoutTask::start(dispatcher.clone(), &signal_receivers, context);

        signals.new_round(10);

        // After twice the min round delay, only the non-forced request should have been made.
        sleep(2 * min_round_delay).await;
        let all_calls = dispatcher.get_new_block_calls().await;
        assert_eq!(all_calls.len(), 1);

        let (round, force, timestamp) = all_calls[0];
        assert_eq!(round, 10);
        assert!(!force);
        assert!(
            min_round_delay <= timestamp - start,
            "Min round delay {:?} should not exceed the elapsed time {:?}",
            min_round_delay,
            timestamp - start
        );

        // After the leader timeout, exactly one forced request for the same round follows.
        sleep(2 * leader_timeout).await;
        let all_calls = dispatcher.get_new_block_calls().await;
        assert_eq!(all_calls.len(), 1);

        let (round, force, timestamp) = all_calls[0];
        assert_eq!(round, 10);
        assert!(force);
        assert!(
            leader_timeout <= timestamp - start,
            "Leader timeout {:?} should not exceed the elapsed time {:?}",
            leader_timeout,
            timestamp - start
        );

        // Both timers have fired for round 10 and no new round arrived, so nothing else happens.
        sleep(2 * leader_timeout).await;
        let all_calls = dispatcher.get_new_block_calls().await;
        assert!(all_calls.is_empty(), "unexpected new_block calls: {all_calls:?}");
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn multiple_leader_timeouts() {
        let (context, _signers) = Context::new_for_test(4);
        let dispatcher = Arc::new(MockCoreThreadDispatcher::default());
        let leader_timeout = Duration::from_millis(500);
        let min_round_delay = Duration::from_millis(50);
        let parameters = Parameters {
            leader_timeout,
            min_round_delay,
            ..Default::default()
        };
        let context = Arc::new(context.with_parameters(parameters));
        let start = Instant::now();

        let (mut signals, signal_receivers) = CoreSignals::new(context.clone());

        // Keep the handle alive: dropping it closes the stop channel and ends the task.
        let _handle = LeaderTimeoutTask::start(dispatcher.clone(), &signal_receivers, context);

        // Advance rounds faster than `min_round_delay`, so each new round re-arms the timers
        // before they fire; only round 15 should ever reach the dispatcher.
        signals.new_round(13);
        sleep(min_round_delay / 2).await;
        signals.new_round(14);
        sleep(min_round_delay / 2).await;
        signals.new_round(15);
        sleep(2 * leader_timeout).await;

        // Expect exactly two calls, both for round 15: the non-forced one, then the forced one.
        let all_calls = dispatcher.get_new_block_calls().await;
        assert_eq!(all_calls.len(), 2, "unexpected new_block calls: {all_calls:?}");

        let (round, force, timestamp) = all_calls[0];
        assert_eq!(round, 15);
        assert!(!force);
        assert!(min_round_delay < timestamp - start);

        let (round, force, timestamp) = all_calls[1];
        assert_eq!(round, 15);
        assert!(force);
        assert!(leader_timeout < timestamp - start);
    }
}