iota_network/state_sync/
server.rs1use std::{
6 sync::{Arc, RwLock},
7 task::{Context, Poll},
8};
9
10use anemo::{Request, Response, Result, rpc::Status, types::response::StatusCode};
11use dashmap::DashMap;
12use futures::future::BoxFuture;
13use iota_types::{
14 digests::{CheckpointContentsDigest, CheckpointDigest},
15 messages_checkpoint::{
16 CertifiedCheckpointSummary as Checkpoint, CheckpointSequenceNumber, FullCheckpointContents,
17 VerifiedCheckpoint,
18 },
19 storage::WriteStore,
20};
21use serde::{Deserialize, Serialize};
22use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc};
23
24use super::{PeerHeights, StateSync, StateSyncMessage};
25
26#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
27pub enum GetCheckpointSummaryRequest {
28 Latest,
29 ByDigest(CheckpointDigest),
30 BySequenceNumber(CheckpointSequenceNumber),
31}
32
33#[derive(Clone, Debug, Serialize, Deserialize)]
34pub struct GetCheckpointAvailabilityResponse {
35 pub(crate) highest_synced_checkpoint: Checkpoint,
36 pub(crate) lowest_available_checkpoint: CheckpointSequenceNumber,
37}
38
39pub(super) struct Server<S> {
40 pub(super) store: S,
41 pub(super) peer_heights: Arc<RwLock<PeerHeights>>,
42 pub(super) sender: mpsc::WeakSender<StateSyncMessage>,
43}
44
45#[anemo::async_trait]
46impl<S> StateSync for Server<S>
47where
48 S: WriteStore + Send + Sync + 'static,
49{
50 async fn push_checkpoint_summary(
54 &self,
55 request: Request<Checkpoint>,
56 ) -> Result<Response<()>, Status> {
57 let peer_id = request
58 .peer_id()
59 .copied()
60 .ok_or_else(|| Status::internal("unable to query sender's PeerId"))?;
61
62 let checkpoint = request.into_inner();
63 if !self
64 .peer_heights
65 .write()
66 .unwrap()
67 .update_peer_info(peer_id, checkpoint.clone(), None)
68 {
69 return Ok(Response::new(()));
70 }
71
72 let highest_verified_checkpoint = *self
73 .store
74 .get_highest_verified_checkpoint()
75 .map_err(|e| Status::internal(e.to_string()))?
76 .sequence_number();
77
78 if *checkpoint.sequence_number() > highest_verified_checkpoint {
81 if let Some(sender) = self.sender.upgrade() {
82 sender.send(StateSyncMessage::StartSyncJob).await.unwrap();
83 }
84 }
85
86 Ok(Response::new(()))
87 }
88
89 async fn get_checkpoint_summary(
92 &self,
93 request: Request<GetCheckpointSummaryRequest>,
94 ) -> Result<Response<Option<Checkpoint>>, Status> {
95 let checkpoint = match request.inner() {
96 GetCheckpointSummaryRequest::Latest => {
97 self.store.get_highest_synced_checkpoint().map(Some)
98 }
99 GetCheckpointSummaryRequest::ByDigest(digest) => {
100 self.store.get_checkpoint_by_digest(digest)
101 }
102 GetCheckpointSummaryRequest::BySequenceNumber(sequence_number) => self
103 .store
104 .get_checkpoint_by_sequence_number(*sequence_number),
105 }
106 .map_err(|e| Status::internal(e.to_string()))?
107 .map(VerifiedCheckpoint::into_inner);
108
109 Ok(Response::new(checkpoint))
110 }
111
112 async fn get_checkpoint_availability(
115 &self,
116 _request: Request<()>,
117 ) -> Result<Response<GetCheckpointAvailabilityResponse>, Status> {
118 let highest_synced_checkpoint = self
119 .store
120 .get_highest_synced_checkpoint()
121 .map_err(|e| Status::internal(e.to_string()))
122 .map(VerifiedCheckpoint::into_inner)?;
123 let lowest_available_checkpoint = self
124 .store
125 .get_lowest_available_checkpoint()
126 .map_err(|e| Status::internal(e.to_string()))?;
127
128 Ok(Response::new(GetCheckpointAvailabilityResponse {
129 highest_synced_checkpoint,
130 lowest_available_checkpoint,
131 }))
132 }
133
134 async fn get_checkpoint_contents(
136 &self,
137 request: Request<CheckpointContentsDigest>,
138 ) -> Result<Response<Option<FullCheckpointContents>>, Status> {
139 let contents = self
140 .store
141 .get_full_checkpoint_contents(request.inner())
142 .map_err(|e| Status::internal(e.to_string()))?;
143 Ok(Response::new(contents))
144 }
145}
146
147#[derive(Clone)]
150pub(super) struct CheckpointContentsDownloadLimitLayer {
151 inflight_per_checkpoint: Arc<DashMap<CheckpointContentsDigest, Arc<Semaphore>>>,
152 max_inflight_per_checkpoint: usize,
153}
154
155impl CheckpointContentsDownloadLimitLayer {
156 pub(super) fn new(max_inflight_per_checkpoint: usize) -> Self {
157 Self {
158 inflight_per_checkpoint: Arc::new(DashMap::new()),
159 max_inflight_per_checkpoint,
160 }
161 }
162
163 pub(super) fn maybe_prune_map(&self) {
164 const PRUNE_THRESHOLD: usize = 5000;
165 if self.inflight_per_checkpoint.len() >= PRUNE_THRESHOLD {
166 self.inflight_per_checkpoint.retain(|_, semaphore| {
167 semaphore.available_permits() < self.max_inflight_per_checkpoint
168 });
169 }
170 }
171}
172
173impl<S> tower::layer::Layer<S> for CheckpointContentsDownloadLimitLayer {
174 type Service = CheckpointContentsDownloadLimit<S>;
175
176 fn layer(&self, inner: S) -> Self::Service {
177 CheckpointContentsDownloadLimit {
178 inner,
179 inflight_per_checkpoint: self.inflight_per_checkpoint.clone(),
180 max_inflight_per_checkpoint: self.max_inflight_per_checkpoint,
181 }
182 }
183}
184
185#[derive(Clone)]
188pub(super) struct CheckpointContentsDownloadLimit<S> {
189 inner: S,
190 inflight_per_checkpoint: Arc<DashMap<CheckpointContentsDigest, Arc<Semaphore>>>,
191 max_inflight_per_checkpoint: usize,
192}
193
194impl<S> tower::Service<Request<CheckpointContentsDigest>> for CheckpointContentsDownloadLimit<S>
195where
196 S: tower::Service<
197 Request<CheckpointContentsDigest>,
198 Response = Response<Option<FullCheckpointContents>>,
199 Error = Status,
200 >
201 + 'static
202 + Clone
203 + Send,
204 <S as tower::Service<Request<CheckpointContentsDigest>>>::Future: Send,
205 Request<CheckpointContentsDigest>: 'static + Send + Sync,
206{
207 type Response = Response<Option<FullCheckpointContents>>;
208 type Error = S::Error;
209 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
210
211 #[inline]
212 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
213 self.inner.poll_ready(cx)
214 }
215
216 fn call(&mut self, req: Request<CheckpointContentsDigest>) -> Self::Future {
217 let inflight_per_checkpoint = self.inflight_per_checkpoint.clone();
218 let max_inflight_per_checkpoint = self.max_inflight_per_checkpoint;
219 let mut inner = self.inner.clone();
220
221 let fut = async move {
222 let semaphore = {
223 let semaphore_entry = inflight_per_checkpoint
224 .entry(*req.body())
225 .or_insert_with(|| Arc::new(Semaphore::new(max_inflight_per_checkpoint)));
226 semaphore_entry.value().clone()
227 };
228 let permit = semaphore.try_acquire_owned().map_err(|e| match e {
229 tokio::sync::TryAcquireError::Closed => {
230 anemo::rpc::Status::new(StatusCode::InternalServerError)
231 }
232 tokio::sync::TryAcquireError::NoPermits => {
233 anemo::rpc::Status::new(StatusCode::TooManyRequests)
234 }
235 })?;
236
237 struct SemaphoreExtension(#[expect(unused)] OwnedSemaphorePermit);
238 inner.call(req).await.map(move |mut response| {
239 response
241 .extensions_mut()
242 .insert(Arc::new(SemaphoreExtension(permit)));
243 response
244 })
245 };
246 Box::pin(fut)
247 }
248}