iota_network/state_sync/server.rs

// Copyright (c) Mysten Labs, Inc.
// Modifications Copyright (c) 2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    sync::{Arc, RwLock},
7    task::{Context, Poll},
8};
9
10use anemo::{Request, Response, Result, rpc::Status, types::response::StatusCode};
11use dashmap::DashMap;
12use futures::future::BoxFuture;
13use iota_types::{
14    digests::{CheckpointContentsDigest, CheckpointDigest},
15    messages_checkpoint::{
16        CertifiedCheckpointSummary as Checkpoint, CheckpointSequenceNumber, FullCheckpointContents,
17        VerifiedCheckpoint,
18    },
19    storage::WriteStore,
20};
21use serde::{Deserialize, Serialize};
22use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc};
23
24use super::{PeerHeights, StateSync, StateSyncMessage};
25
26#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
27pub enum GetCheckpointSummaryRequest {
28    Latest,
29    ByDigest(CheckpointDigest),
30    BySequenceNumber(CheckpointSequenceNumber),
31}
32
33#[derive(Clone, Debug, Serialize, Deserialize)]
34pub struct GetCheckpointAvailabilityResponse {
35    pub(crate) highest_synced_checkpoint: Checkpoint,
36    pub(crate) lowest_available_checkpoint: CheckpointSequenceNumber,
37}
38
39pub(super) struct Server<S> {
40    pub(super) store: S,
41    pub(super) peer_heights: Arc<RwLock<PeerHeights>>,
42    pub(super) sender: mpsc::WeakSender<StateSyncMessage>,
43}
44
45#[anemo::async_trait]
46impl<S> StateSync for Server<S>
47where
48    S: WriteStore + Send + Sync + 'static,
49{
50    /// Pushes a checkpoint summary to the server.
51    /// If the checkpoint is higher than the highest verified checkpoint, it
52    /// will notify the event loop to potentially sync it.
53    async fn push_checkpoint_summary(
54        &self,
55        request: Request<Checkpoint>,
56    ) -> Result<Response<()>, Status> {
57        let peer_id = request
58            .peer_id()
59            .copied()
60            .ok_or_else(|| Status::internal("unable to query sender's PeerId"))?;
61
62        let checkpoint = request.into_inner();
63        if !self
64            .peer_heights
65            .write()
66            .unwrap()
67            .update_peer_info(peer_id, checkpoint.clone(), None)
68        {
69            return Ok(Response::new(()));
70        }
71
72        let highest_verified_checkpoint = *self
73            .store
74            .get_highest_verified_checkpoint()
75            .map_err(|e| Status::internal(e.to_string()))?
76            .sequence_number();
77
78        // If this checkpoint is higher than our highest verified checkpoint notify the
79        // event loop to potentially sync it
80        if *checkpoint.sequence_number() > highest_verified_checkpoint {
81            if let Some(sender) = self.sender.upgrade() {
82                sender.send(StateSyncMessage::StartSyncJob).await.unwrap();
83            }
84        }
85
86        Ok(Response::new(()))
87    }
88
89    /// Gets a checkpoint summary by digest or sequence number, or get the
90    /// latest one.
91    async fn get_checkpoint_summary(
92        &self,
93        request: Request<GetCheckpointSummaryRequest>,
94    ) -> Result<Response<Option<Checkpoint>>, Status> {
95        let checkpoint = match request.inner() {
96            GetCheckpointSummaryRequest::Latest => {
97                self.store.get_highest_synced_checkpoint().map(Some)
98            }
99            GetCheckpointSummaryRequest::ByDigest(digest) => {
100                self.store.get_checkpoint_by_digest(digest)
101            }
102            GetCheckpointSummaryRequest::BySequenceNumber(sequence_number) => self
103                .store
104                .get_checkpoint_by_sequence_number(*sequence_number),
105        }
106        .map_err(|e| Status::internal(e.to_string()))?
107        .map(VerifiedCheckpoint::into_inner);
108
109        Ok(Response::new(checkpoint))
110    }
111
112    /// Gets the highest synced checkpoint and the lowest available checkpoint
113    /// of the node.
114    async fn get_checkpoint_availability(
115        &self,
116        _request: Request<()>,
117    ) -> Result<Response<GetCheckpointAvailabilityResponse>, Status> {
118        let highest_synced_checkpoint = self
119            .store
120            .get_highest_synced_checkpoint()
121            .map_err(|e| Status::internal(e.to_string()))
122            .map(VerifiedCheckpoint::into_inner)?;
123        let lowest_available_checkpoint = self
124            .store
125            .get_lowest_available_checkpoint()
126            .map_err(|e| Status::internal(e.to_string()))?;
127
128        Ok(Response::new(GetCheckpointAvailabilityResponse {
129            highest_synced_checkpoint,
130            lowest_available_checkpoint,
131        }))
132    }
133
134    /// Gets the contents of a checkpoint.
135    async fn get_checkpoint_contents(
136        &self,
137        request: Request<CheckpointContentsDigest>,
138    ) -> Result<Response<Option<FullCheckpointContents>>, Status> {
139        let contents = self
140            .store
141            .get_full_checkpoint_contents(request.inner())
142            .map_err(|e| Status::internal(e.to_string()))?;
143        Ok(Response::new(contents))
144    }
145}
146
147/// [`Layer`] for adding a per-checkpoint limit to the number of inflight
148/// GetCheckpointContent requests.
149#[derive(Clone)]
150pub(super) struct CheckpointContentsDownloadLimitLayer {
151    inflight_per_checkpoint: Arc<DashMap<CheckpointContentsDigest, Arc<Semaphore>>>,
152    max_inflight_per_checkpoint: usize,
153}
154
155impl CheckpointContentsDownloadLimitLayer {
156    pub(super) fn new(max_inflight_per_checkpoint: usize) -> Self {
157        Self {
158            inflight_per_checkpoint: Arc::new(DashMap::new()),
159            max_inflight_per_checkpoint,
160        }
161    }
162
163    pub(super) fn maybe_prune_map(&self) {
164        const PRUNE_THRESHOLD: usize = 5000;
165        if self.inflight_per_checkpoint.len() >= PRUNE_THRESHOLD {
166            self.inflight_per_checkpoint.retain(|_, semaphore| {
167                semaphore.available_permits() < self.max_inflight_per_checkpoint
168            });
169        }
170    }
171}
172
173impl<S> tower::layer::Layer<S> for CheckpointContentsDownloadLimitLayer {
174    type Service = CheckpointContentsDownloadLimit<S>;
175
176    fn layer(&self, inner: S) -> Self::Service {
177        CheckpointContentsDownloadLimit {
178            inner,
179            inflight_per_checkpoint: self.inflight_per_checkpoint.clone(),
180            max_inflight_per_checkpoint: self.max_inflight_per_checkpoint,
181        }
182    }
183}
184
185/// Middleware for adding a per-checkpoint limit to the number of inflight
186/// GetCheckpointContent requests.
187#[derive(Clone)]
188pub(super) struct CheckpointContentsDownloadLimit<S> {
189    inner: S,
190    inflight_per_checkpoint: Arc<DashMap<CheckpointContentsDigest, Arc<Semaphore>>>,
191    max_inflight_per_checkpoint: usize,
192}
193
194impl<S> tower::Service<Request<CheckpointContentsDigest>> for CheckpointContentsDownloadLimit<S>
195where
196    S: tower::Service<
197            Request<CheckpointContentsDigest>,
198            Response = Response<Option<FullCheckpointContents>>,
199            Error = Status,
200        >
201        + 'static
202        + Clone
203        + Send,
204    <S as tower::Service<Request<CheckpointContentsDigest>>>::Future: Send,
205    Request<CheckpointContentsDigest>: 'static + Send + Sync,
206{
207    type Response = Response<Option<FullCheckpointContents>>;
208    type Error = S::Error;
209    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
210
211    #[inline]
212    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
213        self.inner.poll_ready(cx)
214    }
215
216    fn call(&mut self, req: Request<CheckpointContentsDigest>) -> Self::Future {
217        let inflight_per_checkpoint = self.inflight_per_checkpoint.clone();
218        let max_inflight_per_checkpoint = self.max_inflight_per_checkpoint;
219        let mut inner = self.inner.clone();
220
221        let fut = async move {
222            let semaphore = {
223                let semaphore_entry = inflight_per_checkpoint
224                    .entry(*req.body())
225                    .or_insert_with(|| Arc::new(Semaphore::new(max_inflight_per_checkpoint)));
226                semaphore_entry.value().clone()
227            };
228            let permit = semaphore.try_acquire_owned().map_err(|e| match e {
229                tokio::sync::TryAcquireError::Closed => {
230                    anemo::rpc::Status::new(StatusCode::InternalServerError)
231                }
232                tokio::sync::TryAcquireError::NoPermits => {
233                    anemo::rpc::Status::new(StatusCode::TooManyRequests)
234                }
235            })?;
236
237            struct SemaphoreExtension(#[expect(unused)] OwnedSemaphorePermit);
238            inner.call(req).await.map(move |mut response| {
239                // Insert permit as extension so it's not dropped until the response is sent.
240                response
241                    .extensions_mut()
242                    .insert(Arc::new(SemaphoreExtension(permit)));
243                response
244            })
245        };
246        Box::pin(fut)
247    }
248}