consensus_core/network/tonic_network.rs

// Copyright (c) Mysten Labs, Inc.
// Modifications Copyright (c) 2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

use std::{
    collections::BTreeMap,
    net::{SocketAddr, SocketAddrV4, SocketAddrV6},
    pin::Pin,
    sync::Arc,
    time::{Duration, Instant},
};

use async_trait::async_trait;
use bytes::Bytes;
use consensus_config::{AuthorityIndex, NetworkKeyPair, NetworkPublicKey};
use futures::{Stream, StreamExt as _, stream};
use iota_http::ServerHandle;
use iota_network_stack::{
    Multiaddr,
    callback::{CallbackLayer, MakeCallbackHandler, ResponseHandler},
    multiaddr::Protocol,
};
use iota_tls::AllowPublicKeys;
use parking_lot::RwLock;
use tokio_stream::{Iter, iter};
use tonic::{Request, Response, Streaming, codec::CompressionEncoding};
use tower_http::trace::{DefaultMakeSpan, DefaultOnFailure, TraceLayer};
use tracing::{debug, error, info, trace, warn};

use super::{
    BlockStream, ExtendedSerializedBlock, NetworkClient, NetworkManager, NetworkService,
    metrics_layer::{MetricsCallbackMaker, MetricsResponseCallback, SizedRequest, SizedResponse},
    tonic_gen::{
        consensus_service_client::ConsensusServiceClient,
        consensus_service_server::ConsensusService,
    },
};
use crate::{
    CommitIndex, Round,
    block::{BlockRef, VerifiedBlock},
    commit::CommitRange,
    context::Context,
    error::{ConsensusError, ConsensusResult},
    network::{
        tonic_gen::consensus_service_server::ConsensusServiceServer,
        tonic_tls::certificate_server_name,
    },
};
// Maximum size in bytes of a single fetch_blocks() response.
// TODO: put max RPC response size in protocol config.
const MAX_FETCH_RESPONSE_BYTES: usize = 4 * 1024 * 1024;

// Maximum total bytes fetched in a single fetch_blocks() call, after combining
// the responses.
const MAX_TOTAL_FETCHED_BYTES: usize = 128 * 1024 * 1024;

// Implements Tonic RPC client for Consensus.
pub(crate) struct TonicClient {
    context: Arc<Context>,
    network_keypair: NetworkKeyPair,
    channel_pool: Arc<ChannelPool>,
}

impl TonicClient {
    pub(crate) fn new(context: Arc<Context>, network_keypair: NetworkKeyPair) -> Self {
        Self {
            context: context.clone(),
            network_keypair,
            channel_pool: Arc::new(ChannelPool::new(context)),
        }
    }

    async fn get_client(
        &self,
        peer: AuthorityIndex,
        timeout: Duration,
    ) -> ConsensusResult<ConsensusServiceClient<Channel>> {
        let config = &self.context.parameters.tonic;
        let channel = self
            .channel_pool
            .get_channel(self.network_keypair.clone(), peer, timeout)
            .await?;
        let mut client = ConsensusServiceClient::new(channel)
            .max_encoding_message_size(config.message_size_limit)
            .max_decoding_message_size(config.message_size_limit);

        if self.context.protocol_config.consensus_zstd_compression() {
            client = client
                .send_compressed(CompressionEncoding::Zstd)
                .accept_compressed(CompressionEncoding::Zstd);
        }
        Ok(client)
    }
}
// TODO: make sure callsites do not send requests to their own index, and
// return an error otherwise.
#[async_trait]
impl NetworkClient for TonicClient {
    const SUPPORT_STREAMING: bool = true;

    async fn send_block(
        &self,
        peer: AuthorityIndex,
        block: &VerifiedBlock,
        timeout: Duration,
    ) -> ConsensusResult<()> {
        let mut client = self.get_client(peer, timeout).await?;
        let mut request = Request::new(SendBlockRequest {
            block: block.serialized().clone(),
        });
        request.set_timeout(timeout);
        client
            .send_block(request)
            .await
            .map_err(|e| ConsensusError::NetworkRequest(format!("send_block failed: {e:?}")))?;
        Ok(())
    }

    async fn subscribe_blocks(
        &self,
        peer: AuthorityIndex,
        last_received: Round,
        timeout: Duration,
    ) -> ConsensusResult<BlockStream> {
        let mut client = self.get_client(peer, timeout).await?;
        // TODO: add sampled block acknowledgments for latency measurements.
        let request = Request::new(stream::once(async move {
            SubscribeBlocksRequest {
                last_received_round: last_received,
            }
        }));
        let response = client.subscribe_blocks(request).await.map_err(|e| {
            ConsensusError::NetworkRequest(format!("subscribe_blocks failed: {e:?}"))
        })?;
        let stream = response
            .into_inner()
            .take_while(|b| futures::future::ready(b.is_ok()))
            .filter_map(move |b| async move {
                match b {
                    Ok(response) => Some(ExtendedSerializedBlock {
                        block: response.block,
                        excluded_ancestors: response.excluded_ancestors,
                    }),
                    Err(e) => {
                        debug!("Network error received from {}: {e:?}", peer);
                        None
                    }
                }
            });
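        // Throttle the subscription so it yields at most one block per
        // `min_round_delay / 2`, limiting how fast a peer can push blocks
        // through this stream.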
        let rate_limited_stream =
            tokio_stream::StreamExt::throttle(stream, self.context.parameters.min_round_delay / 2)
                .boxed();
        Ok(rate_limited_stream)
    }

    async fn fetch_blocks(
        &self,
        peer: AuthorityIndex,
        block_refs: Vec<BlockRef>,
        highest_accepted_rounds: Vec<Round>,
        timeout: Duration,
    ) -> ConsensusResult<Vec<Bytes>> {
        let mut client = self.get_client(peer, timeout).await?;
        let mut request = Request::new(FetchBlocksRequest {
            block_refs: block_refs
                .iter()
                .filter_map(|r| match bcs::to_bytes(r) {
                    Ok(serialized) => Some(serialized),
                    Err(e) => {
                        debug!("Failed to serialize block ref {:?}: {e:?}", r);
                        None
                    }
                })
                .collect(),
            highest_accepted_rounds,
        });
        request.set_timeout(timeout);
        let mut stream = client
            .fetch_blocks(request)
            .await
            .map_err(|e| {
                if e.code() == tonic::Code::DeadlineExceeded {
                    ConsensusError::NetworkRequestTimeout(format!("fetch_blocks failed: {e:?}"))
                } else {
                    ConsensusError::NetworkRequest(format!("fetch_blocks failed: {e:?}"))
                }
            })?
            .into_inner();
        let mut blocks = vec![];
        let mut total_fetched_bytes = 0;
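        // Drain the response stream until it ends, the total size cap is hit,
        // or an error occurs. Blocks received before a mid-stream error are
        // still returned to the caller.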
        loop {
            match stream.message().await {
                Ok(Some(response)) => {
                    for b in &response.blocks {
                        total_fetched_bytes += b.len();
                    }
                    blocks.extend(response.blocks);
                    if total_fetched_bytes > MAX_TOTAL_FETCHED_BYTES {
                        info!(
                            "fetch_blocks() fetched bytes exceeded limit: {} > {}, terminating stream.",
                            total_fetched_bytes, MAX_TOTAL_FETCHED_BYTES,
                        );
                        break;
                    }
                }
                Ok(None) => {
                    break;
                }
                Err(e) => {
                    if blocks.is_empty() {
                        if e.code() == tonic::Code::DeadlineExceeded {
                            return Err(ConsensusError::NetworkRequestTimeout(format!(
                                "fetch_blocks failed mid-stream: {e:?}"
                            )));
                        }
                        return Err(ConsensusError::NetworkRequest(format!(
                            "fetch_blocks failed mid-stream: {e:?}"
                        )));
                    } else {
                        warn!("fetch_blocks failed mid-stream: {e:?}");
                        break;
                    }
                }
            }
        }
        Ok(blocks)
    }

    async fn fetch_commits(
        &self,
        peer: AuthorityIndex,
        commit_range: CommitRange,
        timeout: Duration,
    ) -> ConsensusResult<(Vec<Bytes>, Vec<Bytes>)> {
        let mut client = self.get_client(peer, timeout).await?;
        let mut request = Request::new(FetchCommitsRequest {
            start: commit_range.start(),
            end: commit_range.end(),
        });
        request.set_timeout(timeout);
        let response = client
            .fetch_commits(request)
            .await
            .map_err(|e| ConsensusError::NetworkRequest(format!("fetch_commits failed: {e:?}")))?;
        let response = response.into_inner();
        Ok((response.commits, response.certifier_blocks))
    }

    async fn fetch_latest_blocks(
        &self,
        peer: AuthorityIndex,
        authorities: Vec<AuthorityIndex>,
        timeout: Duration,
    ) -> ConsensusResult<Vec<Bytes>> {
        let mut client = self.get_client(peer, timeout).await?;
        let mut request = Request::new(FetchLatestBlocksRequest {
            authorities: authorities
                .iter()
                .map(|authority| authority.value() as u32)
                .collect(),
        });
        request.set_timeout(timeout);
        let mut stream = client
            .fetch_latest_blocks(request)
            .await
            .map_err(|e| {
                if e.code() == tonic::Code::DeadlineExceeded {
                    ConsensusError::NetworkRequestTimeout(format!(
                        "fetch_latest_blocks failed: {e:?}"
                    ))
                } else {
                    ConsensusError::NetworkRequest(format!("fetch_latest_blocks failed: {e:?}"))
                }
            })?
            .into_inner();
        let mut blocks = vec![];
        let mut total_fetched_bytes = 0;
        loop {
            match stream.message().await {
                Ok(Some(response)) => {
                    for b in &response.blocks {
                        total_fetched_bytes += b.len();
                    }
                    blocks.extend(response.blocks);
                    if total_fetched_bytes > MAX_TOTAL_FETCHED_BYTES {
                        info!(
                            "fetch_latest_blocks() fetched bytes exceeded limit: {} > {}, terminating stream.",
                            total_fetched_bytes, MAX_TOTAL_FETCHED_BYTES,
                        );
                        break;
                    }
                }
                Ok(None) => {
                    break;
                }
                Err(e) => {
                    if blocks.is_empty() {
                        if e.code() == tonic::Code::DeadlineExceeded {
                            return Err(ConsensusError::NetworkRequestTimeout(format!(
                                "fetch_latest_blocks failed mid-stream: {e:?}"
                            )));
                        }
                        return Err(ConsensusError::NetworkRequest(format!(
                            "fetch_latest_blocks failed mid-stream: {e:?}"
                        )));
                    } else {
                        warn!("fetch_latest_blocks failed mid-stream: {e:?}");
                        break;
                    }
                }
            }
        }
        Ok(blocks)
    }

    async fn get_latest_rounds(
        &self,
        peer: AuthorityIndex,
        timeout: Duration,
    ) -> ConsensusResult<(Vec<Round>, Vec<Round>)> {
        let mut client = self.get_client(peer, timeout).await?;
        let mut request = Request::new(GetLatestRoundsRequest {});
        request.set_timeout(timeout);
        let response = client.get_latest_rounds(request).await.map_err(|e| {
            ConsensusError::NetworkRequest(format!("get_latest_rounds failed: {e:?}"))
        })?;
        let response = response.into_inner();
        Ok((response.highest_received, response.highest_accepted))
    }
}

// Tonic channel wrapped with layers.
type Channel = iota_network_stack::callback::Callback<
    tower_http::trace::Trace<
        tonic_rustls::Channel,
        tower_http::classify::SharedClassifier<tower_http::classify::GrpcErrorsAsFailures>,
    >,
    MetricsCallbackMaker,
>;

/// Manages a pool of connections to peers to avoid constantly reconnecting,
/// which can be expensive.
struct ChannelPool {
    context: Arc<Context>,
    // Size is limited by known authorities in the committee.
    channels: RwLock<BTreeMap<AuthorityIndex, Channel>>,
}

impl ChannelPool {
    fn new(context: Arc<Context>) -> Self {
        Self {
            context,
            channels: RwLock::new(BTreeMap::new()),
        }
    }

    async fn get_channel(
        &self,
        network_keypair: NetworkKeyPair,
        peer: AuthorityIndex,
        timeout: Duration,
    ) -> ConsensusResult<Channel> {
        {
            let channels = self.channels.read();
            if let Some(channel) = channels.get(&peer) {
                return Ok(channel.clone());
            }
        }

        let authority = self.context.committee.authority(peer);
        let address = to_host_port_str(&authority.address).map_err(|e| {
            ConsensusError::NetworkConfig(format!("Cannot convert address to host:port: {e:?}"))
        })?;
        let address = format!("https://{address}");
        let config = &self.context.parameters.tonic;
        let buffer_size = config.connection_buffer_size;
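        // Mutual TLS: only accept a server certificate matching the peer's
        // committee network key, and present our own network key as the
        // client identity.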
        let client_tls_config = iota_tls::create_rustls_client_config(
            self.context
                .committee
                .authority(peer)
                .network_key
                .clone()
                .into_inner(),
            certificate_server_name(&self.context),
            Some(network_keypair.private_key().into_inner()),
        );
        let endpoint = tonic_rustls::Channel::from_shared(address.clone())
            .unwrap()
            .connect_timeout(timeout)
            .initial_connection_window_size(Some(buffer_size as u32))
            .initial_stream_window_size(Some(buffer_size as u32 / 2))
            .keep_alive_while_idle(true)
            .keep_alive_timeout(config.keepalive_interval)
            .http2_keep_alive_interval(config.keepalive_interval)
            // tcp keepalive is probably unnecessary and is unsupported by msim.
            .user_agent("mysticeti")
            .unwrap()
            .tls_config(client_tls_config)
            .unwrap();

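        // Retry connecting until the timeout elapses, with one-second pauses,
        // so a peer that is briefly unreachable does not immediately fail the
        // caller.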
        let deadline = tokio::time::Instant::now() + timeout;
        let channel = loop {
            trace!("Connecting to endpoint at {address}");
            match endpoint.connect().await {
                Ok(channel) => break channel,
                Err(e) => {
                    debug!("Failed to connect to endpoint at {address}: {e:?}");
                    if tokio::time::Instant::now() >= deadline {
                        return Err(ConsensusError::NetworkClientConnection(format!(
                            "Timed out connecting to endpoint at {address}: {e:?}"
                        )));
                    }
                    tokio::time::sleep(Duration::from_secs(1)).await;
                }
            }
        };
        trace!("Connected to {address}");

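        // Wrap the raw channel with the outbound layers described by the
        // `Channel` type alias: a metrics callback (outermost) that records
        // request/response sizes, and a gRPC trace layer that logs spans at
        // TRACE and failures at DEBUG.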
        let channel = tower::ServiceBuilder::new()
            .layer(CallbackLayer::new(MetricsCallbackMaker::new(
                self.context.metrics.network_metrics.outbound.clone(),
                self.context.parameters.tonic.excessive_message_size,
            )))
            .layer(
                TraceLayer::new_for_grpc()
                    .make_span_with(DefaultMakeSpan::new().level(tracing::Level::TRACE))
                    .on_failure(DefaultOnFailure::new().level(tracing::Level::DEBUG)),
            )
            .service(channel);

        let mut channels = self.channels.write();
        // There should not be many concurrent attempts at connecting to the same peer.
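        // If another task connected concurrently, `or_insert` keeps the
        // existing channel and this one is dropped; the clone below returns
        // whichever channel won.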
        let channel = channels.entry(peer).or_insert(channel);
        Ok(channel.clone())
    }
}

/// Proxies Tonic requests to the `NetworkService` that holds the actual
/// handler implementation.
struct TonicServiceProxy<S: NetworkService> {
    context: Arc<Context>,
    service: Arc<S>,
}

impl<S: NetworkService> TonicServiceProxy<S> {
    fn new(context: Arc<Context>, service: Arc<S>) -> Self {
        Self { context, service }
    }
}

#[async_trait]
impl<S: NetworkService> ConsensusService for TonicServiceProxy<S> {
    async fn send_block(
        &self,
        request: Request<SendBlockRequest>,
    ) -> Result<Response<SendBlockResponse>, tonic::Status> {
        let Some(peer_index) = request
            .extensions()
            .get::<PeerInfo>()
            .map(|p| p.authority_index)
        else {
            return Err(tonic::Status::internal("PeerInfo not found"));
        };
        let block = request.into_inner().block;
        let block = ExtendedSerializedBlock {
            block,
            excluded_ancestors: vec![],
        };
        self.service
            .handle_send_block(peer_index, block)
            .await
            .map_err(|e| tonic::Status::invalid_argument(format!("{e:?}")))?;
        Ok(Response::new(SendBlockResponse {}))
    }

    type SubscribeBlocksStream =
        Pin<Box<dyn Stream<Item = Result<SubscribeBlocksResponse, tonic::Status>> + Send>>;

    async fn subscribe_blocks(
        &self,
        request: Request<Streaming<SubscribeBlocksRequest>>,
    ) -> Result<Response<Self::SubscribeBlocksStream>, tonic::Status> {
        let Some(peer_index) = request
            .extensions()
            .get::<PeerInfo>()
            .map(|p| p.authority_index)
        else {
            return Err(tonic::Status::internal("PeerInfo not found"));
        };
        let mut request_stream = request.into_inner();
        let first_request = match request_stream.next().await {
            Some(Ok(r)) => r,
            Some(Err(e)) => {
                debug!(
                    "subscribe_blocks() request from {} failed: {e:?}",
                    peer_index
                );
                return Err(tonic::Status::invalid_argument("Request error"));
            }
            None => {
                return Err(tonic::Status::invalid_argument("Missing request"));
            }
        };
        let stream = self
            .service
            .handle_subscribe_blocks(peer_index, first_request.last_received_round)
            .await
            .map_err(|e| tonic::Status::internal(format!("{e:?}")))?
            .map(|block| {
                Ok(SubscribeBlocksResponse {
                    block: block.block,
                    excluded_ancestors: block.excluded_ancestors,
                })
            });
        let rate_limited_stream =
            tokio_stream::StreamExt::throttle(stream, self.context.parameters.min_round_delay / 2)
                .boxed();
        Ok(Response::new(rate_limited_stream))
    }

    type FetchBlocksStream = Iter<std::vec::IntoIter<Result<FetchBlocksResponse, tonic::Status>>>;

    async fn fetch_blocks(
        &self,
        request: Request<FetchBlocksRequest>,
    ) -> Result<Response<Self::FetchBlocksStream>, tonic::Status> {
        let Some(peer_index) = request
            .extensions()
            .get::<PeerInfo>()
            .map(|p| p.authority_index)
        else {
            return Err(tonic::Status::internal("PeerInfo not found"));
        };
        let inner = request.into_inner();
        let block_refs = inner
            .block_refs
            .into_iter()
            .filter_map(|serialized| match bcs::from_bytes(&serialized) {
                Ok(r) => Some(r),
                Err(e) => {
                    debug!("Failed to deserialize block ref {:?}: {e:?}", serialized);
                    None
                }
            })
            .collect();
        let highest_accepted_rounds = inner.highest_accepted_rounds;
        let blocks = self
            .service
            .handle_fetch_blocks(peer_index, block_refs, highest_accepted_rounds)
            .await
            .map_err(|e| tonic::Status::internal(format!("{e:?}")))?;
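        // Split the serialized blocks into chunks of at most
        // MAX_FETCH_RESPONSE_BYTES, so each streamed response stays within
        // the RPC message size limit.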
        let responses: std::vec::IntoIter<Result<FetchBlocksResponse, tonic::Status>> =
            chunk_blocks(blocks, MAX_FETCH_RESPONSE_BYTES)
                .into_iter()
                .map(|blocks| Ok(FetchBlocksResponse { blocks }))
                .collect::<Vec<_>>()
                .into_iter();
        let stream = iter(responses);
        Ok(Response::new(stream))
    }

    async fn fetch_commits(
        &self,
        request: Request<FetchCommitsRequest>,
    ) -> Result<Response<FetchCommitsResponse>, tonic::Status> {
        let Some(peer_index) = request
            .extensions()
            .get::<PeerInfo>()
            .map(|p| p.authority_index)
        else {
            return Err(tonic::Status::internal("PeerInfo not found"));
        };
        let request = request.into_inner();
        let (commits, certifier_blocks) = self
            .service
            .handle_fetch_commits(peer_index, (request.start..=request.end).into())
            .await
            .map_err(|e| tonic::Status::internal(format!("{e:?}")))?;
        let commits = commits
            .into_iter()
            .map(|c| c.serialized().clone())
            .collect();
        let certifier_blocks = certifier_blocks
            .into_iter()
            .map(|b| b.serialized().clone())
            .collect();
        Ok(Response::new(FetchCommitsResponse {
            commits,
            certifier_blocks,
        }))
    }

    type FetchLatestBlocksStream =
        Iter<std::vec::IntoIter<Result<FetchLatestBlocksResponse, tonic::Status>>>;

    async fn fetch_latest_blocks(
        &self,
        request: Request<FetchLatestBlocksRequest>,
    ) -> Result<Response<Self::FetchLatestBlocksStream>, tonic::Status> {
        let Some(peer_index) = request
            .extensions()
            .get::<PeerInfo>()
            .map(|p| p.authority_index)
        else {
            return Err(tonic::Status::internal("PeerInfo not found"));
        };
        let inner = request.into_inner();

        // Convert the authority indexes and validate them
        let mut authorities = vec![];
        for authority in inner.authorities.into_iter() {
            let Some(authority) = self
                .context
                .committee
                .to_authority_index(authority as usize)
            else {
                return Err(tonic::Status::internal(format!(
                    "Invalid authority index provided {authority}"
                )));
            };
            authorities.push(authority);
        }

        let blocks = self
            .service
            .handle_fetch_latest_blocks(peer_index, authorities)
            .await
            .map_err(|e| tonic::Status::internal(format!("{e:?}")))?;
        let responses: std::vec::IntoIter<Result<FetchLatestBlocksResponse, tonic::Status>> =
            chunk_blocks(blocks, MAX_FETCH_RESPONSE_BYTES)
                .into_iter()
                .map(|blocks| Ok(FetchLatestBlocksResponse { blocks }))
                .collect::<Vec<_>>()
                .into_iter();
        let stream = iter(responses);
        Ok(Response::new(stream))
    }

    async fn get_latest_rounds(
        &self,
        request: Request<GetLatestRoundsRequest>,
    ) -> Result<Response<GetLatestRoundsResponse>, tonic::Status> {
        let Some(peer_index) = request
            .extensions()
            .get::<PeerInfo>()
            .map(|p| p.authority_index)
        else {
            return Err(tonic::Status::internal("PeerInfo not found"));
        };
        let (highest_received, highest_accepted) = self
            .service
            .handle_get_latest_rounds(peer_index)
            .await
            .map_err(|e| tonic::Status::internal(format!("{e:?}")))?;
        Ok(Response::new(GetLatestRoundsResponse {
            highest_received,
            highest_accepted,
        }))
    }
}

/// Manages the lifecycle of Tonic network client and service. Typical usage
/// during initialization:
/// 1. Create a new `TonicManager`.
/// 2. Take `TonicClient` from `TonicManager::client()`.
/// 3. Create consensus components.
/// 4. Create `TonicService` for consensus service handler.
/// 5. Install `TonicService` to `TonicManager` with
///    `TonicManager::install_service()`.
pub(crate) struct TonicManager {
    context: Arc<Context>,
    network_keypair: NetworkKeyPair,
    client: Arc<TonicClient>,
    server: Option<ServerHandle>,
}

impl TonicManager {
    pub(crate) fn new(context: Arc<Context>, network_keypair: NetworkKeyPair) -> Self {
        Self {
            context: context.clone(),
            network_keypair: network_keypair.clone(),
            client: Arc::new(TonicClient::new(context, network_keypair)),
            server: None,
        }
    }
}

impl<S: NetworkService> NetworkManager<S> for TonicManager {
    type Client = TonicClient;

    fn new(context: Arc<Context>, network_keypair: NetworkKeyPair) -> Self {
        TonicManager::new(context, network_keypair)
    }

    fn client(&self) -> Arc<Self::Client> {
        self.client.clone()
    }

    async fn install_service(&mut self, service: Arc<S>) {
        self.context
            .metrics
            .network_metrics
            .network_type
            .with_label_values(&["tonic"])
            .set(1);

        info!("Starting tonic service");

        let authority = self.context.committee.authority(self.context.own_index);
        // By default, bind to the unspecified address so the actual address
        // can be assigned, but bind to localhost when it is explicitly
        // requested.
        let own_address = if authority.address.is_localhost_ip() {
            authority.address.clone()
        } else {
            authority.address.with_zero_ip()
        };
        let own_address = to_socket_addr(&own_address).unwrap();
        let service = TonicServiceProxy::new(self.context.clone(), service);
        let config = &self.context.parameters.tonic;

        let connections_info = Arc::new(ConnectionsInfo::new(self.context.clone()));
        let layers = tower::ServiceBuilder::new()
            // Add a layer to extract a peer's PeerInfo from their TLS certs
            .map_request(move |mut request: http::Request<_>| {
                if let Some(peer_certificates) =
                    request.extensions().get::<iota_http::PeerCertificates>()
                {
                    if let Some(peer_info) =
                        peer_info_from_certs(&connections_info, peer_certificates)
                    {
                        request.extensions_mut().insert(peer_info);
                    }
                }
                request
            })
            .layer(CallbackLayer::new(MetricsCallbackMaker::new(
                self.context.metrics.network_metrics.inbound.clone(),
                self.context.parameters.tonic.excessive_message_size,
            )))
            .layer(
                TraceLayer::new_for_grpc()
                    .make_span_with(DefaultMakeSpan::new().level(tracing::Level::TRACE))
                    .on_failure(DefaultOnFailure::new().level(tracing::Level::DEBUG)),
            )
            .layer_fn(|service| iota_network_stack::grpc_timeout::GrpcTimeout::new(service, None));

        let mut consensus_service_server = ConsensusServiceServer::new(service)
            .max_encoding_message_size(config.message_size_limit)
            .max_decoding_message_size(config.message_size_limit);

        if self.context.protocol_config.consensus_zstd_compression() {
            consensus_service_server = consensus_service_server
                .send_compressed(CompressionEncoding::Zstd)
                .accept_compressed(CompressionEncoding::Zstd);
        }

        let consensus_service = tonic::service::Routes::new(consensus_service_server)
            .into_axum_router()
            .route_layer(layers);

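        // Restrict inbound TLS to committee members: client certificates are
        // accepted only when they match one of the committee's network keys.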
        let tls_server_config = iota_tls::create_rustls_server_config(
            self.network_keypair.clone().private_key().into_inner(),
            certificate_server_name(&self.context),
            AllowPublicKeys::new(
                self.context
                    .committee
                    .authorities()
                    .map(|(_i, a)| a.network_key.clone().into_inner())
                    .collect(),
            ),
        );

        // Calculate some metrics around send/recv buffer sizes for the current
        // machine/OS
        #[cfg(not(msim))]
        {
            let tcp_connection_metrics =
                &self.context.metrics.network_metrics.tcp_connection_metrics;

            // Try creating an ephemeral port to test the highest allowed send and recv
            // buffer sizes. Buffer sizes are not set explicitly on the socket
            // used for real traffic, to allow the OS to set appropriate values.
            {
                let ephemeral_addr = SocketAddr::new(own_address.ip(), 0);
                let ephemeral_socket = create_socket(&ephemeral_addr);
                tcp_connection_metrics
                    .socket_send_buffer_size
                    .set(ephemeral_socket.send_buffer_size().unwrap_or(0) as i64);
                tcp_connection_metrics
                    .socket_recv_buffer_size
                    .set(ephemeral_socket.recv_buffer_size().unwrap_or(0) as i64);

                if let Err(e) = ephemeral_socket.set_send_buffer_size(32 << 20) {
                    info!("Failed to set send buffer size: {e:?}");
                }
                if let Err(e) = ephemeral_socket.set_recv_buffer_size(32 << 20) {
                    info!("Failed to set recv buffer size: {e:?}");
                }
                if ephemeral_socket.bind(ephemeral_addr).is_ok() {
                    tcp_connection_metrics
                        .socket_send_buffer_max_size
                        .set(ephemeral_socket.send_buffer_size().unwrap_or(0) as i64);
                    tcp_connection_metrics
                        .socket_recv_buffer_max_size
                        .set(ephemeral_socket.recv_buffer_size().unwrap_or(0) as i64);
                };
            }
        }

        let http_config = iota_http::Config::default()
            .tcp_nodelay(true)
            .initial_connection_window_size(64 << 20)
            .initial_stream_window_size(32 << 20)
            .http2_keepalive_interval(Some(config.keepalive_interval))
            .http2_keepalive_timeout(Some(config.keepalive_interval))
            .accept_http1(false);

        // Create server.
        //
        // During simtest crash/restart tests there may be an older instance of
        // consensus still bound to the TCP port of `own_address` that has not
        // finished relinquishing control of the port yet. So instead of
        // crashing when the address is in use, retry for a short, reasonable
        // period of time before giving up.
        let deadline = Instant::now() + Duration::from_secs(20);
        let server = loop {
            match iota_http::Builder::new()
                .config(http_config.clone())
                .tls_config(tls_server_config.clone())
                .serve(own_address, consensus_service.clone())
            {
                Ok(server) => break server,
                Err(err) => {
                    warn!("Error starting consensus server: {err:?}");
                    if Instant::now() > deadline {
                        panic!("Failed to start consensus server within required deadline");
                    }
                    tokio::time::sleep(Duration::from_secs(1)).await;
                }
            }
        };

        info!("Server started at: {own_address}");
        self.server = Some(server);
    }

    async fn stop(&mut self) {
        if let Some(server) = self.server.take() {
            server.shutdown().await;
        }

        self.context
            .metrics
            .network_metrics
            .network_type
            .with_label_values(&["tonic"])
            .set(0);
    }
}

// Ensure that any active network is shut down when the TonicManager is
// dropped.
impl Drop for TonicManager {
    fn drop(&mut self) {
        if let Some(server) = self.server.as_ref() {
            server.trigger_shutdown();
        }
    }
}

// TODO: improve iota-http to allow for providing a MakeService so that this can
// be done once per connection
fn peer_info_from_certs(
    connections_info: &ConnectionsInfo,
    peer_certificates: &iota_http::PeerCertificates,
) -> Option<PeerInfo> {
    let certs = peer_certificates.peer_certs();

    if certs.len() != 1 {
        trace!(
            "Unexpected number of certificates from TLS stream: {}",
            certs.len()
        );
        return None;
    }
    trace!("Received {} certificates", certs.len());
    let public_key = iota_tls::public_key_from_certificate(&certs[0])
        .map_err(|e| {
            trace!("Failed to extract public key from certificate: {e:?}");
            e
        })
        .ok()?;
    let client_public_key = NetworkPublicKey::new(public_key);
    let Some(authority_index) = connections_info.authority_index(&client_public_key) else {
        error!("Failed to find the authority with public key {client_public_key:?}");
        return None;
    };
    Some(PeerInfo { authority_index })
}

/// Attempts to convert a multiaddr of the form `/[ip4,ip6,dns]/{}/udp/{port}`
/// into a host:port string.
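/// For example (illustrative value, not from the original file):
/// `/ip4/127.0.0.1/udp/9000` converts to `"127.0.0.1:9000"`.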
fn to_host_port_str(addr: &Multiaddr) -> Result<String, String> {
    let mut iter = addr.iter();

    match (iter.next(), iter.next()) {
        (Some(Protocol::Ip4(ipaddr)), Some(Protocol::Udp(port))) => Ok(format!("{ipaddr}:{port}")),
        (Some(Protocol::Ip6(ipaddr)), Some(Protocol::Udp(port))) => Ok(format!("{ipaddr}:{port}")),
        (Some(Protocol::Dns(hostname)), Some(Protocol::Udp(port))) => {
            Ok(format!("{hostname}:{port}"))
        }

        _ => Err(format!("unsupported multiaddr: {addr}")),
    }
}

/// Attempts to convert a multiaddr of the form `/[ip4,ip6]/{}/[udp,tcp]/{port}`
/// into a SocketAddr value.
pub fn to_socket_addr(addr: &Multiaddr) -> Result<SocketAddr, String> {
    let mut iter = addr.iter();

    match (iter.next(), iter.next()) {
        (Some(Protocol::Ip4(ipaddr)), Some(Protocol::Udp(port)))
        | (Some(Protocol::Ip4(ipaddr)), Some(Protocol::Tcp(port))) => {
            Ok(SocketAddr::V4(SocketAddrV4::new(ipaddr, port)))
        }

        (Some(Protocol::Ip6(ipaddr)), Some(Protocol::Udp(port)))
        | (Some(Protocol::Ip6(ipaddr)), Some(Protocol::Tcp(port))) => {
            Ok(SocketAddr::V6(SocketAddrV6::new(ipaddr, port, 0, 0)))
        }

        _ => Err(format!("unsupported multiaddr: {addr}")),
    }
}

#[cfg(not(msim))]
fn create_socket(address: &SocketAddr) -> tokio::net::TcpSocket {
    let socket = if address.is_ipv4() {
        tokio::net::TcpSocket::new_v4()
    } else if address.is_ipv6() {
        tokio::net::TcpSocket::new_v6()
    } else {
        panic!("Invalid own address: {address:?}");
    }
    .unwrap_or_else(|e| panic!("Cannot create TCP socket: {e:?}"));
    if let Err(e) = socket.set_nodelay(true) {
        info!("Failed to set TCP_NODELAY: {e:?}");
    }
    if let Err(e) = socket.set_reuseaddr(true) {
        info!("Failed to set SO_REUSEADDR: {e:?}");
    }
    socket
}

/// Looks up authority index by authority public key.
///
/// TODO: Add connection monitoring, and keep track of connected peers.
/// TODO: Maybe merge with connection_monitor.rs
struct ConnectionsInfo {
    authority_key_to_index: BTreeMap<NetworkPublicKey, AuthorityIndex>,
}

impl ConnectionsInfo {
    fn new(context: Arc<Context>) -> Self {
        let authority_key_to_index = context
            .committee
            .authorities()
            .map(|(index, authority)| (authority.network_key.clone(), index))
            .collect();
        Self {
            authority_key_to_index,
        }
    }

    fn authority_index(&self, key: &NetworkPublicKey) -> Option<AuthorityIndex> {
        self.authority_key_to_index.get(key).copied()
    }
}

/// Information about the client peer, set per connection.
#[derive(Clone, Debug)]
struct PeerInfo {
    authority_index: AuthorityIndex,
}

// Adapt MetricsCallbackMaker and MetricsResponseCallback to http.

impl SizedRequest for http::request::Parts {
    fn size(&self) -> usize {
        // TODO: implement this.
        0
    }

    fn route(&self) -> String {
        let path = self.uri.path();
        path.rsplit_once('/')
            .map(|(_, route)| route)
            .unwrap_or("unknown")
            .to_string()
    }
}

impl SizedResponse for http::response::Parts {
    fn size(&self) -> usize {
        // TODO: implement this.
        0
    }

    fn error_type(&self) -> Option<String> {
        if self.status.is_success() {
            None
        } else {
            Some(self.status.to_string())
        }
    }
}

impl MakeCallbackHandler for MetricsCallbackMaker {
    type Handler = MetricsResponseCallback;

    fn make_handler(&self, request: &http::request::Parts) -> Self::Handler {
        self.handle_request(request)
    }
}

impl ResponseHandler for MetricsResponseCallback {
    fn on_response(self, response: &http::response::Parts) {
        self.on_response(response)
    }

    fn on_error<E>(self, err: &E) {
        self.on_error(err)
    }
}

/// Network message types.
#[derive(Clone, prost::Message)]
pub(crate) struct SendBlockRequest {
    // Serialized SignedBlock.
    #[prost(bytes = "bytes", tag = "1")]
    block: Bytes,
}

#[derive(Clone, prost::Message)]
pub(crate) struct SendBlockResponse {}

#[derive(Clone, prost::Message)]
pub(crate) struct SubscribeBlocksRequest {
    #[prost(uint32, tag = "1")]
    last_received_round: Round,
}

#[derive(Clone, prost::Message)]
pub(crate) struct SubscribeBlocksResponse {
    #[prost(bytes = "bytes", tag = "1")]
    block: Bytes,
    // Serialized BlockRefs that are excluded from the block's ancestors.
    #[prost(bytes = "vec", repeated, tag = "2")]
    excluded_ancestors: Vec<Vec<u8>>,
}

#[derive(Clone, prost::Message)]
pub(crate) struct FetchBlocksRequest {
    #[prost(bytes = "vec", repeated, tag = "1")]
    block_refs: Vec<Vec<u8>>,
    // The highest accepted round per authority. The vector represents the round for each authority
    // and its length should be the same as the committee size.
    #[prost(uint32, repeated, tag = "2")]
    highest_accepted_rounds: Vec<Round>,
}

#[derive(Clone, prost::Message)]
pub(crate) struct FetchBlocksResponse {
    // The requested blocks, each a serialized SignedBlock.
    #[prost(bytes = "bytes", repeated, tag = "1")]
    blocks: Vec<Bytes>,
}

#[derive(Clone, prost::Message)]
pub(crate) struct FetchCommitsRequest {
    #[prost(uint32, tag = "1")]
    start: CommitIndex,
    #[prost(uint32, tag = "2")]
    end: CommitIndex,
}

#[derive(Clone, prost::Message)]
pub(crate) struct FetchCommitsResponse {
    // Serialized consecutive Commits.
    #[prost(bytes = "bytes", repeated, tag = "1")]
    commits: Vec<Bytes>,
    // Serialized SignedBlocks that certify the last commit above.
    #[prost(bytes = "bytes", repeated, tag = "2")]
    certifier_blocks: Vec<Bytes>,
}

#[derive(Clone, prost::Message)]
pub(crate) struct FetchLatestBlocksRequest {
    #[prost(uint32, repeated, tag = "1")]
    authorities: Vec<u32>,
}

#[derive(Clone, prost::Message)]
pub(crate) struct FetchLatestBlocksResponse {
    // The requested blocks, each a serialized SignedBlock.
    #[prost(bytes = "bytes", repeated, tag = "1")]
    blocks: Vec<Bytes>,
}

#[derive(Clone, prost::Message)]
pub(crate) struct GetLatestRoundsRequest {}

#[derive(Clone, prost::Message)]
pub(crate) struct GetLatestRoundsResponse {
    // Highest received round per authority.
    #[prost(uint32, repeated, tag = "1")]
    highest_received: Vec<u32>,
    // Highest accepted round per authority.
    #[prost(uint32, repeated, tag = "2")]
    highest_accepted: Vec<u32>,
}

fn chunk_blocks(blocks: Vec<Bytes>, chunk_limit: usize) -> Vec<Vec<Bytes>> {
    let mut chunks = vec![];
    let mut chunk = vec![];
    let mut chunk_size = 0;
    for block in blocks {
        let block_size = block.len();
        if !chunk.is_empty() && chunk_size + block_size > chunk_limit {
            chunks.push(chunk);
            chunk = vec![];
            chunk_size = 0;
        }
        chunk.push(block);
        chunk_size += block_size;
    }
    if !chunk.is_empty() {
        chunks.push(chunk);
    }
    chunks
}
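
// A minimal sketch (not in the original file) illustrating the expected
// behavior of `chunk_blocks`: blocks are greedily packed into chunks, and a
// chunk is closed once adding the next block would exceed `chunk_limit`. A
// single block larger than the limit still gets its own chunk.
#[cfg(test)]
mod chunk_blocks_tests {
    use super::*;

    #[test]
    fn splits_blocks_at_chunk_limit() {
        // Three 6-byte blocks with a 10-byte limit: no two blocks fit
        // together, so each lands in its own chunk.
        let blocks = vec![
            Bytes::from(vec![0u8; 6]),
            Bytes::from(vec![1u8; 6]),
            Bytes::from(vec![2u8; 6]),
        ];
        let chunks = chunk_blocks(blocks, 10);
        assert_eq!(chunks.len(), 3);

        // A block larger than the limit is still emitted, alone in its chunk.
        let chunks = chunk_blocks(vec![Bytes::from(vec![0u8; 20])], 10);
        assert_eq!(chunks.len(), 1);
    }
}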