consensus_core/network/tonic_network.rs

// Copyright (c) Mysten Labs, Inc.
// Modifications Copyright (c) 2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

use std::{
    collections::BTreeMap,
    net::{SocketAddr, SocketAddrV4, SocketAddrV6},
    pin::Pin,
    sync::Arc,
    time::{Duration, Instant},
};

use async_trait::async_trait;
use bytes::Bytes;
use consensus_config::{AuthorityIndex, NetworkKeyPair, NetworkPublicKey};
use futures::{Stream, StreamExt as _, stream};
use iota_http::ServerHandle;
use iota_network_stack::{
    Multiaddr,
    callback::{CallbackLayer, MakeCallbackHandler, ResponseHandler},
    multiaddr::Protocol,
};
use parking_lot::RwLock;
use tokio_stream::{Iter, iter};
use tonic::{Request, Response, Streaming, codec::CompressionEncoding};
use tower_http::trace::{DefaultMakeSpan, DefaultOnFailure, TraceLayer};
use tracing::{debug, error, info, trace, warn};

use super::{
    BlockStream, ExtendedSerializedBlock, NetworkClient, NetworkManager, NetworkService,
    metrics_layer::{MetricsCallbackMaker, MetricsResponseCallback, SizedRequest, SizedResponse},
    tonic_gen::{
        consensus_service_client::ConsensusServiceClient,
        consensus_service_server::ConsensusService,
    },
    tonic_tls::create_rustls_client_config,
};
use crate::{
    CommitIndex, Round,
    block::{BlockRef, VerifiedBlock},
    commit::CommitRange,
    context::Context,
    error::{ConsensusError, ConsensusResult},
    network::{
        tonic_gen::consensus_service_server::ConsensusServiceServer,
        tonic_tls::create_rustls_server_config,
    },
};

// Maximum byte size of a single fetch_blocks() response.
// TODO: put max RPC response size in protocol config.
const MAX_FETCH_RESPONSE_BYTES: usize = 4 * 1024 * 1024;

// Maximum total bytes fetched in a single fetch_blocks() call, after combining
// the responses.
const MAX_TOTAL_FETCHED_BYTES: usize = 128 * 1024 * 1024;
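// Note: with these limits, one fetch_blocks() call accepts roughly
// MAX_TOTAL_FETCHED_BYTES / MAX_FETCH_RESPONSE_BYTES = 32 full-sized streamed
// responses before the total cap terminates the stream.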

// Implements Tonic RPC client for Consensus.
pub(crate) struct TonicClient {
    context: Arc<Context>,
    network_keypair: NetworkKeyPair,
    channel_pool: Arc<ChannelPool>,
}

impl TonicClient {
    pub(crate) fn new(context: Arc<Context>, network_keypair: NetworkKeyPair) -> Self {
        Self {
            context: context.clone(),
            network_keypair,
            channel_pool: Arc::new(ChannelPool::new(context)),
        }
    }

    async fn get_client(
        &self,
        peer: AuthorityIndex,
        timeout: Duration,
    ) -> ConsensusResult<ConsensusServiceClient<Channel>> {
        let config = &self.context.parameters.tonic;
        let channel = self
            .channel_pool
            .get_channel(self.network_keypair.clone(), peer, timeout)
            .await?;
        let mut client = ConsensusServiceClient::new(channel)
            .max_encoding_message_size(config.message_size_limit)
            .max_decoding_message_size(config.message_size_limit);

        if self.context.protocol_config.consensus_zstd_compression() {
            client = client
                .send_compressed(CompressionEncoding::Zstd)
                .accept_compressed(CompressionEncoding::Zstd);
        }
        Ok(client)
    }
}

// TODO: make sure callsites do not send requests to the authority's own index,
// and return an error otherwise.
#[async_trait]
impl NetworkClient for TonicClient {
    const SUPPORT_STREAMING: bool = true;

    async fn send_block(
        &self,
        peer: AuthorityIndex,
        block: &VerifiedBlock,
        timeout: Duration,
    ) -> ConsensusResult<()> {
        let mut client = self.get_client(peer, timeout).await?;
        let mut request = Request::new(SendBlockRequest {
            block: block.serialized().clone(),
        });
        request.set_timeout(timeout);
        client
            .send_block(request)
            .await
            .map_err(|e| ConsensusError::NetworkRequest(format!("send_block failed: {e:?}")))?;
        Ok(())
    }

    async fn subscribe_blocks(
        &self,
        peer: AuthorityIndex,
        last_received: Round,
        timeout: Duration,
    ) -> ConsensusResult<BlockStream> {
        let mut client = self.get_client(peer, timeout).await?;
        // TODO: add sampled block acknowledgments for latency measurements.
        let request = Request::new(stream::once(async move {
            SubscribeBlocksRequest {
                last_received_round: last_received,
            }
        }));
        let response = client.subscribe_blocks(request).await.map_err(|e| {
            ConsensusError::NetworkRequest(format!("subscribe_blocks failed: {e:?}"))
        })?;
        let stream = response
            .into_inner()
            .take_while(|b| futures::future::ready(b.is_ok()))
            .filter_map(move |b| async move {
                match b {
                    Ok(response) => Some(ExtendedSerializedBlock {
                        block: response.block,
                        excluded_ancestors: response.excluded_ancestors,
                    }),
                    Err(e) => {
                        debug!("Network error received from {}: {e:?}", peer);
                        None
                    }
                }
            });
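        // Enforce a minimum gap of min_round_delay / 2 between delivered
        // blocks, bounding how fast a single peer can push blocks to us.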
        let rate_limited_stream =
            tokio_stream::StreamExt::throttle(stream, self.context.parameters.min_round_delay / 2)
                .boxed();
        Ok(rate_limited_stream)
    }

    async fn fetch_blocks(
        &self,
        peer: AuthorityIndex,
        block_refs: Vec<BlockRef>,
        highest_accepted_rounds: Vec<Round>,
        timeout: Duration,
    ) -> ConsensusResult<Vec<Bytes>> {
        let mut client = self.get_client(peer, timeout).await?;
        let mut request = Request::new(FetchBlocksRequest {
            block_refs: block_refs
                .iter()
                .filter_map(|r| match bcs::to_bytes(r) {
                    Ok(serialized) => Some(serialized),
                    Err(e) => {
                        debug!("Failed to serialize block ref {:?}: {e:?}", r);
                        None
                    }
                })
                .collect(),
            highest_accepted_rounds,
        });
        request.set_timeout(timeout);
        let mut stream = client
            .fetch_blocks(request)
            .await
            .map_err(|e| {
                if e.code() == tonic::Code::DeadlineExceeded {
                    ConsensusError::NetworkRequestTimeout(format!("fetch_blocks failed: {e:?}"))
                } else {
                    ConsensusError::NetworkRequest(format!("fetch_blocks failed: {e:?}"))
                }
            })?
            .into_inner();
        let mut blocks = vec![];
        let mut total_fetched_bytes = 0;
        loop {
            match stream.message().await {
                Ok(Some(response)) => {
                    for b in &response.blocks {
                        total_fetched_bytes += b.len();
                    }
                    blocks.extend(response.blocks);
                    if total_fetched_bytes > MAX_TOTAL_FETCHED_BYTES {
                        info!(
                            "fetch_blocks() fetched bytes exceeded limit: {} > {}, terminating stream.",
                            total_fetched_bytes, MAX_TOTAL_FETCHED_BYTES,
                        );
                        break;
                    }
                }
                Ok(None) => {
                    break;
                }
                Err(e) => {
                    if blocks.is_empty() {
                        if e.code() == tonic::Code::DeadlineExceeded {
                            return Err(ConsensusError::NetworkRequestTimeout(format!(
                                "fetch_blocks failed mid-stream: {e:?}"
                            )));
                        }
                        return Err(ConsensusError::NetworkRequest(format!(
                            "fetch_blocks failed mid-stream: {e:?}"
                        )));
                    } else {
                        warn!("fetch_blocks failed mid-stream: {e:?}");
                        break;
                    }
                }
            }
        }
        Ok(blocks)
    }

    async fn fetch_commits(
        &self,
        peer: AuthorityIndex,
        commit_range: CommitRange,
        timeout: Duration,
    ) -> ConsensusResult<(Vec<Bytes>, Vec<Bytes>)> {
        let mut client = self.get_client(peer, timeout).await?;
        let mut request = Request::new(FetchCommitsRequest {
            start: commit_range.start(),
            end: commit_range.end(),
        });
        request.set_timeout(timeout);
        let response = client
            .fetch_commits(request)
            .await
            .map_err(|e| ConsensusError::NetworkRequest(format!("fetch_commits failed: {e:?}")))?;
        let response = response.into_inner();
        Ok((response.commits, response.certifier_blocks))
    }

    async fn fetch_latest_blocks(
        &self,
        peer: AuthorityIndex,
        authorities: Vec<AuthorityIndex>,
        timeout: Duration,
    ) -> ConsensusResult<Vec<Bytes>> {
        let mut client = self.get_client(peer, timeout).await?;
        let mut request = Request::new(FetchLatestBlocksRequest {
            authorities: authorities
                .iter()
                .map(|authority| authority.value() as u32)
                .collect(),
        });
        request.set_timeout(timeout);
        let mut stream = client
            .fetch_latest_blocks(request)
            .await
            .map_err(|e| {
                if e.code() == tonic::Code::DeadlineExceeded {
                    ConsensusError::NetworkRequestTimeout(format!(
                        "fetch_latest_blocks failed: {e:?}"
                    ))
                } else {
                    ConsensusError::NetworkRequest(format!("fetch_latest_blocks failed: {e:?}"))
                }
            })?
            .into_inner();
        let mut blocks = vec![];
        let mut total_fetched_bytes = 0;
        loop {
            match stream.message().await {
                Ok(Some(response)) => {
                    for b in &response.blocks {
                        total_fetched_bytes += b.len();
                    }
                    blocks.extend(response.blocks);
                    if total_fetched_bytes > MAX_TOTAL_FETCHED_BYTES {
                        info!(
                            "fetch_latest_blocks() fetched bytes exceeded limit: {} > {}, terminating stream.",
                            total_fetched_bytes, MAX_TOTAL_FETCHED_BYTES,
                        );
                        break;
                    }
                }
                Ok(None) => {
                    break;
                }
                Err(e) => {
                    if blocks.is_empty() {
                        if e.code() == tonic::Code::DeadlineExceeded {
                            return Err(ConsensusError::NetworkRequestTimeout(format!(
                                "fetch_latest_blocks failed mid-stream: {e:?}"
                            )));
                        }
                        return Err(ConsensusError::NetworkRequest(format!(
                            "fetch_latest_blocks failed mid-stream: {e:?}"
                        )));
                    } else {
                        warn!("fetch_latest_blocks failed mid-stream: {e:?}");
                        break;
                    }
                }
            }
        }
        Ok(blocks)
    }

    async fn get_latest_rounds(
        &self,
        peer: AuthorityIndex,
        timeout: Duration,
    ) -> ConsensusResult<(Vec<Round>, Vec<Round>)> {
        let mut client = self.get_client(peer, timeout).await?;
        let mut request = Request::new(GetLatestRoundsRequest {});
        request.set_timeout(timeout);
        let response = client.get_latest_rounds(request).await.map_err(|e| {
            ConsensusError::NetworkRequest(format!("get_latest_rounds failed: {e:?}"))
        })?;
        let response = response.into_inner();
        Ok((response.highest_received, response.highest_accepted))
    }
}

// Tonic channel wrapped with layers.
type Channel = iota_network_stack::callback::Callback<
    tower_http::trace::Trace<
        tonic_rustls::Channel,
        tower_http::classify::SharedClassifier<tower_http::classify::GrpcErrorsAsFailures>,
    >,
    MetricsCallbackMaker,
>;

/// Manages a pool of connections to peers to avoid constantly reconnecting,
/// which can be expensive.
struct ChannelPool {
    context: Arc<Context>,
    // Size is limited by known authorities in the committee.
    channels: RwLock<BTreeMap<AuthorityIndex, Channel>>,
}

impl ChannelPool {
    fn new(context: Arc<Context>) -> Self {
        Self {
            context,
            channels: RwLock::new(BTreeMap::new()),
        }
    }

    async fn get_channel(
        &self,
        network_keypair: NetworkKeyPair,
        peer: AuthorityIndex,
        timeout: Duration,
    ) -> ConsensusResult<Channel> {
        {
            let channels = self.channels.read();
            if let Some(channel) = channels.get(&peer) {
                return Ok(channel.clone());
            }
        }

        let authority = self.context.committee.authority(peer);
        let address = to_host_port_str(&authority.address).map_err(|e| {
            ConsensusError::NetworkConfig(format!("Cannot convert address to host:port: {e:?}"))
        })?;
        let address = format!("https://{address}");
        let config = &self.context.parameters.tonic;
        let buffer_size = config.connection_buffer_size;
        let client_tls_config = create_rustls_client_config(&self.context, network_keypair, peer);
        let endpoint = tonic_rustls::Channel::from_shared(address.clone())
            .unwrap()
            .connect_timeout(timeout)
            .initial_connection_window_size(Some(buffer_size as u32))
            .initial_stream_window_size(Some(buffer_size as u32 / 2))
            .keep_alive_while_idle(true)
            .keep_alive_timeout(config.keepalive_interval)
            .http2_keep_alive_interval(config.keepalive_interval)
            // tcp keepalive is probably unnecessary and is unsupported by msim.
            .user_agent("mysticeti")
            .unwrap()
            .tls_config(client_tls_config)
            .unwrap();

        let deadline = tokio::time::Instant::now() + timeout;
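        // Retry the connection once per second until the deadline elapses, so
        // a transiently unreachable peer does not immediately fail the call.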
        let channel = loop {
            trace!("Connecting to endpoint at {address}");
            match endpoint.connect().await {
                Ok(channel) => break channel,
                Err(e) => {
                    debug!("Failed to connect to endpoint at {address}: {e:?}");
                    if tokio::time::Instant::now() >= deadline {
                        return Err(ConsensusError::NetworkClientConnection(format!(
                            "Timed out connecting to endpoint at {address}: {e:?}"
                        )));
                    }
                    tokio::time::sleep(Duration::from_secs(1)).await;
                }
            }
        };
        trace!("Connected to {address}");

        let channel = tower::ServiceBuilder::new()
            .layer(CallbackLayer::new(MetricsCallbackMaker::new(
                self.context.metrics.network_metrics.outbound.clone(),
                self.context.parameters.tonic.excessive_message_size,
            )))
            .layer(
                TraceLayer::new_for_grpc()
                    .make_span_with(DefaultMakeSpan::new().level(tracing::Level::TRACE))
                    .on_failure(DefaultOnFailure::new().level(tracing::Level::DEBUG)),
            )
            .service(channel);

        let mut channels = self.channels.write();
        // There should not be many concurrent attempts at connecting to the same peer.
        let channel = channels.entry(peer).or_insert(channel);
        Ok(channel.clone())
    }
}

/// Proxies Tonic requests to NetworkService with actual handler implementation.
struct TonicServiceProxy<S: NetworkService> {
    context: Arc<Context>,
    service: Arc<S>,
}

impl<S: NetworkService> TonicServiceProxy<S> {
    fn new(context: Arc<Context>, service: Arc<S>) -> Self {
        Self { context, service }
    }
}

#[async_trait]
impl<S: NetworkService> ConsensusService for TonicServiceProxy<S> {
    async fn send_block(
        &self,
        request: Request<SendBlockRequest>,
    ) -> Result<Response<SendBlockResponse>, tonic::Status> {
        let Some(peer_index) = request
            .extensions()
            .get::<PeerInfo>()
            .map(|p| p.authority_index)
        else {
            return Err(tonic::Status::internal("PeerInfo not found"));
        };
        let block = request.into_inner().block;
        let block = ExtendedSerializedBlock {
            block,
            excluded_ancestors: vec![],
        };
        self.service
            .handle_send_block(peer_index, block)
            .await
            .map_err(|e| tonic::Status::invalid_argument(format!("{e:?}")))?;
        Ok(Response::new(SendBlockResponse {}))
    }

    type SubscribeBlocksStream =
        Pin<Box<dyn Stream<Item = Result<SubscribeBlocksResponse, tonic::Status>> + Send>>;

    async fn subscribe_blocks(
        &self,
        request: Request<Streaming<SubscribeBlocksRequest>>,
    ) -> Result<Response<Self::SubscribeBlocksStream>, tonic::Status> {
        let Some(peer_index) = request
            .extensions()
            .get::<PeerInfo>()
            .map(|p| p.authority_index)
        else {
            return Err(tonic::Status::internal("PeerInfo not found"));
        };
        let mut request_stream = request.into_inner();
        let first_request = match request_stream.next().await {
            Some(Ok(r)) => r,
            Some(Err(e)) => {
                debug!(
                    "subscribe_blocks() request from {} failed: {e:?}",
                    peer_index
                );
                return Err(tonic::Status::invalid_argument("Request error"));
            }
            None => {
                return Err(tonic::Status::invalid_argument("Missing request"));
            }
        };
        let stream = self
            .service
            .handle_subscribe_blocks(peer_index, first_request.last_received_round)
            .await
            .map_err(|e| tonic::Status::internal(format!("{e:?}")))?
            .map(|block| {
                Ok(SubscribeBlocksResponse {
                    block: block.block,
                    excluded_ancestors: block.excluded_ancestors,
                })
            });
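        // Mirror the client-side throttle: emit at most one response per
        // min_round_delay / 2 on the outbound subscription stream.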
        let rate_limited_stream =
            tokio_stream::StreamExt::throttle(stream, self.context.parameters.min_round_delay / 2)
                .boxed();
        Ok(Response::new(rate_limited_stream))
    }

    type FetchBlocksStream = Iter<std::vec::IntoIter<Result<FetchBlocksResponse, tonic::Status>>>;

    async fn fetch_blocks(
        &self,
        request: Request<FetchBlocksRequest>,
    ) -> Result<Response<Self::FetchBlocksStream>, tonic::Status> {
        let Some(peer_index) = request
            .extensions()
            .get::<PeerInfo>()
            .map(|p| p.authority_index)
        else {
            return Err(tonic::Status::internal("PeerInfo not found"));
        };
        let inner = request.into_inner();
        let block_refs = inner
            .block_refs
            .into_iter()
            .filter_map(|serialized| match bcs::from_bytes(&serialized) {
                Ok(r) => Some(r),
                Err(e) => {
                    debug!("Failed to deserialize block ref {:?}: {e:?}", serialized);
                    None
                }
            })
            .collect();
        let highest_accepted_rounds = inner.highest_accepted_rounds;
        let blocks = self
            .service
            .handle_fetch_blocks(peer_index, block_refs, highest_accepted_rounds)
            .await
            .map_err(|e| tonic::Status::internal(format!("{e:?}")))?;
        let responses: std::vec::IntoIter<Result<FetchBlocksResponse, tonic::Status>> =
            chunk_blocks(blocks, MAX_FETCH_RESPONSE_BYTES)
                .into_iter()
                .map(|blocks| Ok(FetchBlocksResponse { blocks }))
                .collect::<Vec<_>>()
                .into_iter();
        let stream = iter(responses);
        Ok(Response::new(stream))
    }

    async fn fetch_commits(
        &self,
        request: Request<FetchCommitsRequest>,
    ) -> Result<Response<FetchCommitsResponse>, tonic::Status> {
        let Some(peer_index) = request
            .extensions()
            .get::<PeerInfo>()
            .map(|p| p.authority_index)
        else {
            return Err(tonic::Status::internal("PeerInfo not found"));
        };
        let request = request.into_inner();
        let (commits, certifier_blocks) = self
            .service
            .handle_fetch_commits(peer_index, (request.start..=request.end).into())
            .await
            .map_err(|e| tonic::Status::internal(format!("{e:?}")))?;
        let commits = commits
            .into_iter()
            .map(|c| c.serialized().clone())
            .collect();
        let certifier_blocks = certifier_blocks
            .into_iter()
            .map(|b| b.serialized().clone())
            .collect();
        Ok(Response::new(FetchCommitsResponse {
            commits,
            certifier_blocks,
        }))
    }

    type FetchLatestBlocksStream =
        Iter<std::vec::IntoIter<Result<FetchLatestBlocksResponse, tonic::Status>>>;

    async fn fetch_latest_blocks(
        &self,
        request: Request<FetchLatestBlocksRequest>,
    ) -> Result<Response<Self::FetchLatestBlocksStream>, tonic::Status> {
        let Some(peer_index) = request
            .extensions()
            .get::<PeerInfo>()
            .map(|p| p.authority_index)
        else {
            return Err(tonic::Status::internal("PeerInfo not found"));
        };
        let inner = request.into_inner();

        // Convert the authority indexes and validate them.
        let mut authorities = vec![];
        for authority in inner.authorities.into_iter() {
            let Some(authority) = self
                .context
                .committee
                .to_authority_index(authority as usize)
            else {
                return Err(tonic::Status::internal(format!(
                    "Invalid authority index provided {authority}"
                )));
            };
            authorities.push(authority);
        }

        let blocks = self
            .service
            .handle_fetch_latest_blocks(peer_index, authorities)
            .await
            .map_err(|e| tonic::Status::internal(format!("{e:?}")))?;
        let responses: std::vec::IntoIter<Result<FetchLatestBlocksResponse, tonic::Status>> =
            chunk_blocks(blocks, MAX_FETCH_RESPONSE_BYTES)
                .into_iter()
                .map(|blocks| Ok(FetchLatestBlocksResponse { blocks }))
                .collect::<Vec<_>>()
                .into_iter();
        let stream = iter(responses);
        Ok(Response::new(stream))
    }

    async fn get_latest_rounds(
        &self,
        request: Request<GetLatestRoundsRequest>,
    ) -> Result<Response<GetLatestRoundsResponse>, tonic::Status> {
        let Some(peer_index) = request
            .extensions()
            .get::<PeerInfo>()
            .map(|p| p.authority_index)
        else {
            return Err(tonic::Status::internal("PeerInfo not found"));
        };
        let (highest_received, highest_accepted) = self
            .service
            .handle_get_latest_rounds(peer_index)
            .await
            .map_err(|e| tonic::Status::internal(format!("{e:?}")))?;
        Ok(Response::new(GetLatestRoundsResponse {
            highest_received,
            highest_accepted,
        }))
    }
}

/// Manages the lifecycle of Tonic network client and service. Typical usage
/// during initialization:
/// 1. Create a new `TonicManager`.
/// 2. Take `TonicClient` from `TonicManager::client()`.
/// 3. Create consensus components.
/// 4. Create `TonicService` for consensus service handler.
/// 5. Install `TonicService` to `TonicManager` with
///    `TonicManager::install_service()`.
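///
/// A minimal usage sketch of the steps above (illustrative only; `service`
/// stands in for any `NetworkService` implementation):
///
/// ```ignore
/// let mut manager = TonicManager::new(context.clone(), network_keypair);
/// let client = manager.client();
/// // ... create consensus components that use `client` ...
/// manager.install_service(service).await;
/// ```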
pub(crate) struct TonicManager {
    context: Arc<Context>,
    network_keypair: NetworkKeyPair,
    client: Arc<TonicClient>,
    server: Option<ServerHandle>,
}

impl TonicManager {
    pub(crate) fn new(context: Arc<Context>, network_keypair: NetworkKeyPair) -> Self {
        Self {
            context: context.clone(),
            network_keypair: network_keypair.clone(),
            client: Arc::new(TonicClient::new(context, network_keypair)),
            server: None,
        }
    }
}

impl<S: NetworkService> NetworkManager<S> for TonicManager {
    type Client = TonicClient;

    fn new(context: Arc<Context>, network_keypair: NetworkKeyPair) -> Self {
        TonicManager::new(context, network_keypair)
    }

    fn client(&self) -> Arc<Self::Client> {
        self.client.clone()
    }

    async fn install_service(&mut self, service: Arc<S>) {
        self.context
            .metrics
            .network_metrics
            .network_type
            .with_label_values(&["tonic"])
            .set(1);

        info!("Starting tonic service");

        let authority = self.context.committee.authority(self.context.own_index);
        // By default, bind to the unspecified address to allow the actual address to be
        // assigned. But bind to localhost if it is requested.
        let own_address = if authority.address.is_localhost_ip() {
            authority.address.clone()
        } else {
            authority.address.with_zero_ip()
        };
        let own_address = to_socket_addr(&own_address).unwrap();
        let service = TonicServiceProxy::new(self.context.clone(), service);
        let config = &self.context.parameters.tonic;

        let connections_info = Arc::new(ConnectionsInfo::new(self.context.clone()));
        let layers = tower::ServiceBuilder::new()
            // Add a layer to extract a peer's PeerInfo from their TLS certs
            .map_request(move |mut request: http::Request<_>| {
                if let Some(peer_certificates) =
                    request.extensions().get::<iota_http::PeerCertificates>()
                {
                    if let Some(peer_info) =
                        peer_info_from_certs(&connections_info, peer_certificates)
                    {
                        request.extensions_mut().insert(peer_info);
                    }
                }
                request
            })
            .layer(CallbackLayer::new(MetricsCallbackMaker::new(
                self.context.metrics.network_metrics.inbound.clone(),
                self.context.parameters.tonic.excessive_message_size,
            )))
            .layer(
                TraceLayer::new_for_grpc()
                    .make_span_with(DefaultMakeSpan::new().level(tracing::Level::TRACE))
                    .on_failure(DefaultOnFailure::new().level(tracing::Level::DEBUG)),
            )
            .layer_fn(|service| iota_network_stack::grpc_timeout::GrpcTimeout::new(service, None));

        let mut consensus_service_server = ConsensusServiceServer::new(service)
            .max_encoding_message_size(config.message_size_limit)
            .max_decoding_message_size(config.message_size_limit);

        if self.context.protocol_config.consensus_zstd_compression() {
            consensus_service_server = consensus_service_server
                .send_compressed(CompressionEncoding::Zstd)
                .accept_compressed(CompressionEncoding::Zstd);
        }

        let consensus_service = tonic::service::Routes::new(consensus_service_server)
            .into_axum_router()
            .route_layer(layers);

        let tls_server_config =
            create_rustls_server_config(&self.context, self.network_keypair.clone());

        // Calculate some metrics around send/recv buffer sizes for the current
        // machine/OS.
        #[cfg(not(msim))]
        {
            let tcp_connection_metrics =
                &self.context.metrics.network_metrics.tcp_connection_metrics;

            // Try creating an ephemeral port to test the highest allowed send and recv
            // buffer sizes. Buffer sizes are not set explicitly on the socket
            // used for real traffic, to allow the OS to set appropriate values.
            {
                let ephemeral_addr = SocketAddr::new(own_address.ip(), 0);
                let ephemeral_socket = create_socket(&ephemeral_addr);
                tcp_connection_metrics
                    .socket_send_buffer_size
                    .set(ephemeral_socket.send_buffer_size().unwrap_or(0) as i64);
                tcp_connection_metrics
                    .socket_recv_buffer_size
                    .set(ephemeral_socket.recv_buffer_size().unwrap_or(0) as i64);

                if let Err(e) = ephemeral_socket.set_send_buffer_size(32 << 20) {
                    info!("Failed to set send buffer size: {e:?}");
                }
                if let Err(e) = ephemeral_socket.set_recv_buffer_size(32 << 20) {
                    info!("Failed to set recv buffer size: {e:?}");
                }
                if ephemeral_socket.bind(ephemeral_addr).is_ok() {
                    tcp_connection_metrics
                        .socket_send_buffer_max_size
                        .set(ephemeral_socket.send_buffer_size().unwrap_or(0) as i64);
                    tcp_connection_metrics
                        .socket_recv_buffer_max_size
                        .set(ephemeral_socket.recv_buffer_size().unwrap_or(0) as i64);
                };
            }
        }

        let http_config = iota_http::Config::default()
            .tcp_nodelay(true)
            .initial_connection_window_size(64 << 20)
            .initial_stream_window_size(32 << 20)
            .http2_keepalive_interval(Some(config.keepalive_interval))
            .http2_keepalive_timeout(Some(config.keepalive_interval))
            .accept_http1(false);

        // Create the server.
        //
        // During simtest crash/restart tests, an older consensus instance may
        // still be bound to the TCP port of `own_address` and not yet have
        // relinquished it. Instead of crashing when the address is in use,
        // retry for a reasonable period of time before giving up.
        let deadline = Instant::now() + Duration::from_secs(20);
        let server = loop {
            match iota_http::Builder::new()
                .config(http_config.clone())
                .tls_config(tls_server_config.clone())
                .serve(own_address, consensus_service.clone())
            {
                Ok(server) => break server,
                Err(err) => {
                    warn!("Error starting consensus server: {err:?}");
                    if Instant::now() > deadline {
                        panic!("Failed to start consensus server within required deadline");
                    }
                    tokio::time::sleep(Duration::from_secs(1)).await;
                }
            }
        };

        info!("Server started at: {own_address}");
        self.server = Some(server);
    }

    async fn stop(&mut self) {
        if let Some(server) = self.server.take() {
            server.shutdown().await;
        }

        self.context
            .metrics
            .network_metrics
            .network_type
            .with_label_values(&["tonic"])
            .set(0);
    }
}

// Ensure that any active network server is shut down when the TonicManager is
// dropped.
impl Drop for TonicManager {
    fn drop(&mut self) {
        if let Some(server) = self.server.as_ref() {
            server.trigger_shutdown();
        }
    }
}

// TODO: improve iota-http to allow for providing a MakeService so that this can
// be done once per connection
fn peer_info_from_certs(
    connections_info: &ConnectionsInfo,
    peer_certificates: &iota_http::PeerCertificates,
) -> Option<PeerInfo> {
    let certs = peer_certificates.peer_certs();

    if certs.len() != 1 {
        trace!(
            "Unexpected number of certificates from TLS stream: {}",
            certs.len()
        );
        return None;
    }
    trace!("Received {} certificates", certs.len());
    let public_key = iota_tls::public_key_from_certificate(&certs[0])
        .map_err(|e| {
            trace!("Failed to extract public key from certificate: {e:?}");
            e
        })
        .ok()?;
    let client_public_key = NetworkPublicKey::new(public_key);
    let Some(authority_index) = connections_info.authority_index(&client_public_key) else {
        error!("Failed to find the authority with public key {client_public_key:?}");
        return None;
    };
    Some(PeerInfo { authority_index })
}

/// Attempts to convert a multiaddr of the form `/[ip4,ip6,dns]/{}/udp/{port}`
/// into a host:port string.
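/// For example, `/ip4/127.0.0.1/udp/9000` converts to `"127.0.0.1:9000"`.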
fn to_host_port_str(addr: &Multiaddr) -> Result<String, String> {
    let mut iter = addr.iter();

    match (iter.next(), iter.next()) {
        (Some(Protocol::Ip4(ipaddr)), Some(Protocol::Udp(port))) => {
            Ok(format!("{}:{}", ipaddr, port))
        }
        (Some(Protocol::Ip6(ipaddr)), Some(Protocol::Udp(port))) => {
            Ok(format!("{}:{}", ipaddr, port))
        }
        (Some(Protocol::Dns(hostname)), Some(Protocol::Udp(port))) => {
            Ok(format!("{}:{}", hostname, port))
        }

        _ => Err(format!("unsupported multiaddr: {addr}")),
    }
}

/// Attempts to convert a multiaddr of the form `/[ip4,ip6]/{}/[udp,tcp]/{port}`
/// into a SocketAddr value.
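/// For example, `/ip4/127.0.0.1/tcp/9000` converts to `127.0.0.1:9000`.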
pub fn to_socket_addr(addr: &Multiaddr) -> Result<SocketAddr, String> {
    let mut iter = addr.iter();

    match (iter.next(), iter.next()) {
        (Some(Protocol::Ip4(ipaddr)), Some(Protocol::Udp(port)))
        | (Some(Protocol::Ip4(ipaddr)), Some(Protocol::Tcp(port))) => {
            Ok(SocketAddr::V4(SocketAddrV4::new(ipaddr, port)))
        }

        (Some(Protocol::Ip6(ipaddr)), Some(Protocol::Udp(port)))
        | (Some(Protocol::Ip6(ipaddr)), Some(Protocol::Tcp(port))) => {
            Ok(SocketAddr::V6(SocketAddrV6::new(ipaddr, port, 0, 0)))
        }

        _ => Err(format!("unsupported multiaddr: {addr}")),
    }
}

#[cfg(not(msim))]
fn create_socket(address: &SocketAddr) -> tokio::net::TcpSocket {
    let socket = if address.is_ipv4() {
        tokio::net::TcpSocket::new_v4()
    } else if address.is_ipv6() {
        tokio::net::TcpSocket::new_v6()
    } else {
        panic!("Invalid own address: {address:?}");
    }
    .unwrap_or_else(|e| panic!("Cannot create TCP socket: {e:?}"));
    if let Err(e) = socket.set_nodelay(true) {
        info!("Failed to set TCP_NODELAY: {e:?}");
    }
    if let Err(e) = socket.set_reuseaddr(true) {
        info!("Failed to set SO_REUSEADDR: {e:?}");
    }
    socket
}

/// Looks up authority index by authority public key.
///
/// TODO: Add connection monitoring, and keep track of connected peers.
/// TODO: Maybe merge with connection_monitor.rs
struct ConnectionsInfo {
    authority_key_to_index: BTreeMap<NetworkPublicKey, AuthorityIndex>,
}

impl ConnectionsInfo {
    fn new(context: Arc<Context>) -> Self {
        let authority_key_to_index = context
            .committee
            .authorities()
            .map(|(index, authority)| (authority.network_key.clone(), index))
            .collect();
        Self {
            authority_key_to_index,
        }
    }

    fn authority_index(&self, key: &NetworkPublicKey) -> Option<AuthorityIndex> {
        self.authority_key_to_index.get(key).copied()
    }
}

/// Information about the client peer, set per connection.
#[derive(Clone, Debug)]
struct PeerInfo {
    authority_index: AuthorityIndex,
}

// Adapt MetricsCallbackMaker and MetricsResponseCallback to http.

impl SizedRequest for http::request::Parts {
    fn size(&self) -> usize {
        // TODO: implement this.
        0
    }

    fn route(&self) -> String {
        let path = self.uri.path();
        path.rsplit_once('/')
            .map(|(_, route)| route)
            .unwrap_or("unknown")
            .to_string()
    }
}

impl SizedResponse for http::response::Parts {
    fn size(&self) -> usize {
        // TODO: implement this.
        0
    }

    fn error_type(&self) -> Option<String> {
        if self.status.is_success() {
            None
        } else {
            Some(self.status.to_string())
        }
    }
}

impl MakeCallbackHandler for MetricsCallbackMaker {
    type Handler = MetricsResponseCallback;

    fn make_handler(&self, request: &http::request::Parts) -> Self::Handler {
        self.handle_request(request)
    }
}

impl ResponseHandler for MetricsResponseCallback {
    fn on_response(self, response: &http::response::Parts) {
        self.on_response(response)
    }

    fn on_error<E>(self, err: &E) {
        self.on_error(err)
    }
}

/// Network message types.
#[derive(Clone, prost::Message)]
pub(crate) struct SendBlockRequest {
    // Serialized SignedBlock.
    #[prost(bytes = "bytes", tag = "1")]
    block: Bytes,
}

#[derive(Clone, prost::Message)]
pub(crate) struct SendBlockResponse {}

#[derive(Clone, prost::Message)]
pub(crate) struct SubscribeBlocksRequest {
    #[prost(uint32, tag = "1")]
    last_received_round: Round,
}

#[derive(Clone, prost::Message)]
pub(crate) struct SubscribeBlocksResponse {
    #[prost(bytes = "bytes", tag = "1")]
    block: Bytes,
    // Serialized BlockRefs that are excluded from the block's ancestors.
    #[prost(bytes = "vec", repeated, tag = "2")]
    excluded_ancestors: Vec<Vec<u8>>,
}

#[derive(Clone, prost::Message)]
pub(crate) struct FetchBlocksRequest {
    #[prost(bytes = "vec", repeated, tag = "1")]
    block_refs: Vec<Vec<u8>>,
    // The highest accepted round per authority. The vector's length should
    // equal the committee size.
    #[prost(uint32, repeated, tag = "2")]
    highest_accepted_rounds: Vec<Round>,
}

#[derive(Clone, prost::Message)]
pub(crate) struct FetchBlocksResponse {
    // The requested blocks, each a serialized SignedBlock.
    #[prost(bytes = "bytes", repeated, tag = "1")]
    blocks: Vec<Bytes>,
}

#[derive(Clone, prost::Message)]
pub(crate) struct FetchCommitsRequest {
    #[prost(uint32, tag = "1")]
    start: CommitIndex,
    #[prost(uint32, tag = "2")]
    end: CommitIndex,
}

#[derive(Clone, prost::Message)]
pub(crate) struct FetchCommitsResponse {
    // Serialized consecutive Commits.
    #[prost(bytes = "bytes", repeated, tag = "1")]
    commits: Vec<Bytes>,
    // Serialized SignedBlocks that certify the last commit above.
    #[prost(bytes = "bytes", repeated, tag = "2")]
    certifier_blocks: Vec<Bytes>,
}

#[derive(Clone, prost::Message)]
pub(crate) struct FetchLatestBlocksRequest {
    #[prost(uint32, repeated, tag = "1")]
    authorities: Vec<u32>,
}

#[derive(Clone, prost::Message)]
pub(crate) struct FetchLatestBlocksResponse {
    // The requested blocks, each a serialized SignedBlock.
    #[prost(bytes = "bytes", repeated, tag = "1")]
    blocks: Vec<Bytes>,
}

#[derive(Clone, prost::Message)]
pub(crate) struct GetLatestRoundsRequest {}

#[derive(Clone, prost::Message)]
pub(crate) struct GetLatestRoundsResponse {
    // Highest received round per authority.
    #[prost(uint32, repeated, tag = "1")]
    highest_received: Vec<u32>,
    // Highest accepted round per authority.
    #[prost(uint32, repeated, tag = "2")]
    highest_accepted: Vec<u32>,
}

fn chunk_blocks(blocks: Vec<Bytes>, chunk_limit: usize) -> Vec<Vec<Bytes>> {
    let mut chunks = vec![];
    let mut chunk = vec![];
    let mut chunk_size = 0;
    for block in blocks {
        let block_size = block.len();
        if !chunk.is_empty() && chunk_size + block_size > chunk_limit {
            chunks.push(chunk);
            chunk = vec![];
            chunk_size = 0;
        }
        chunk.push(block);
        chunk_size += block_size;
    }
    if !chunk.is_empty() {
        chunks.push(chunk);
    }
    chunks
}
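
// A minimal sanity test for chunk_blocks(), added for illustration. It checks
// the two invariants the function maintains: each chunk stays within the limit
// unless a single block alone exceeds it, and block order is preserved.
#[cfg(test)]
mod chunk_blocks_tests {
    use bytes::Bytes;

    use super::chunk_blocks;

    #[test]
    fn chunks_respect_limit_and_order() {
        // Blocks of 3, 4, and 5 bytes with a 7-byte limit should split into
        // [3, 4] and [5].
        let blocks = vec![
            Bytes::from(vec![0u8; 3]),
            Bytes::from(vec![1u8; 4]),
            Bytes::from(vec![2u8; 5]),
        ];
        let chunks = chunk_blocks(blocks.clone(), 7);
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].len(), 2);
        assert_eq!(chunks[1].len(), 1);
        // Flattening the chunks must reproduce the original block order.
        let flattened: Vec<Bytes> = chunks.into_iter().flatten().collect();
        assert_eq!(flattened, blocks);

        // A single block larger than the limit still forms its own chunk.
        let oversized = vec![Bytes::from(vec![3u8; 10])];
        let chunks = chunk_blocks(oversized, 7);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].len(), 1);
    }
}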