iota_network_stack/
codec.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{io::Read, marker::PhantomData};
6
7use bytes::{Buf, BufMut};
8use tonic::{
9    Status,
10    codec::{Codec, DecodeBuf, Decoder, EncodeBuf, Encoder},
11};
12
13#[derive(Debug)]
14pub struct BcsEncoder<T>(PhantomData<T>);
15
16impl<T: serde::Serialize> Encoder for BcsEncoder<T> {
17    type Item = T;
18    type Error = Status;
19
20    fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
21        bcs::serialize_into(&mut buf.writer(), &item).map_err(|e| Status::internal(e.to_string()))
22    }
23}
24
25#[derive(Debug)]
26pub struct BcsDecoder<U>(PhantomData<U>);
27
28impl<U: serde::de::DeserializeOwned> Decoder for BcsDecoder<U> {
29    type Item = U;
30    type Error = Status;
31
32    fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
33        if !buf.has_remaining() {
34            return Ok(None);
35        }
36
37        let chunk = buf.chunk();
38
39        let item: Self::Item =
40            bcs::from_bytes(chunk).map_err(|e| Status::internal(e.to_string()))?;
41        buf.advance(chunk.len());
42
43        Ok(Some(item))
44    }
45}
46
/// A [`Codec`] that implements `application/grpc+bcs` via the serde library.
///
/// `T` is the outgoing (encoded) message type and `U` the incoming (decoded)
/// message type.
pub struct BcsCodec<T, U>(PhantomData<(T, U)>);

// `Debug`, `Clone` and `Default` are written by hand instead of derived: a
// derive would add `T: Trait, U: Trait` bounds even though the codec holds
// only a `PhantomData` and never stores a `T` or a `U`.

impl<T, U> std::fmt::Debug for BcsCodec<T, U> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same output shape as `#[derive(Debug)]`; `PhantomData` is `Debug` for
        // every type parameter, so no bounds are needed.
        f.debug_tuple("BcsCodec").field(&self.0).finish()
    }
}

impl<T, U> Clone for BcsCodec<T, U> {
    fn clone(&self) -> Self {
        Self(PhantomData)
    }
}

impl<T, U> Default for BcsCodec<T, U> {
    fn default() -> Self {
        Self(PhantomData)
    }
}
56
57impl<T, U> Codec for BcsCodec<T, U>
58where
59    T: serde::Serialize + Send + 'static,
60    U: serde::de::DeserializeOwned + Send + 'static,
61{
62    type Encode = T;
63    type Decode = U;
64    type Encoder = BcsEncoder<T>;
65    type Decoder = BcsDecoder<U>;
66
67    fn encoder(&mut self) -> Self::Encoder {
68        BcsEncoder(PhantomData)
69    }
70
71    fn decoder(&mut self) -> Self::Decoder {
72        BcsDecoder(PhantomData)
73    }
74}
75
76#[derive(Debug)]
77pub struct BcsSnappyEncoder<T>(PhantomData<T>);
78
79impl<T: serde::Serialize> Encoder for BcsSnappyEncoder<T> {
80    type Item = T;
81    type Error = Status;
82
83    fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
84        let mut snappy_encoder = snap::write::FrameEncoder::new(buf.writer());
85        bcs::serialize_into(&mut snappy_encoder, &item).map_err(|e| Status::internal(e.to_string()))
86    }
87}
88
89#[derive(Debug)]
90pub struct BcsSnappyDecoder<U>(PhantomData<U>);
91
92impl<U: serde::de::DeserializeOwned> Decoder for BcsSnappyDecoder<U> {
93    type Item = U;
94    type Error = Status;
95
96    fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
97        let compressed_size = buf.remaining();
98        if compressed_size == 0 {
99            return Ok(None);
100        }
101        let mut snappy_decoder = snap::read::FrameDecoder::new(buf.reader());
102        let mut bytes = Vec::with_capacity(compressed_size);
103        snappy_decoder.read_to_end(&mut bytes)?;
104        let item =
105            bcs::from_bytes(bytes.as_slice()).map_err(|e| Status::internal(e.to_string()))?;
106        Ok(Some(item))
107    }
108}
109
/// A [`Codec`] that implements `bcs` encoding/decoding and snappy
/// compression/decompression via the serde library.
///
/// `T` is the outgoing (encoded) message type and `U` the incoming (decoded)
/// message type.
pub struct BcsSnappyCodec<T, U>(PhantomData<(T, U)>);

// `Debug`, `Clone` and `Default` are written by hand instead of derived: a
// derive would add `T: Trait, U: Trait` bounds even though the codec holds
// only a `PhantomData` and never stores a `T` or a `U`.

impl<T, U> std::fmt::Debug for BcsSnappyCodec<T, U> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same output shape as `#[derive(Debug)]`; `PhantomData` is `Debug` for
        // every type parameter, so no bounds are needed.
        f.debug_tuple("BcsSnappyCodec").field(&self.0).finish()
    }
}

impl<T, U> Clone for BcsSnappyCodec<T, U> {
    fn clone(&self) -> Self {
        Self(PhantomData)
    }
}

impl<T, U> Default for BcsSnappyCodec<T, U> {
    fn default() -> Self {
        Self(PhantomData)
    }
}
120
121impl<T, U> Codec for BcsSnappyCodec<T, U>
122where
123    T: serde::Serialize + Send + 'static,
124    U: serde::de::DeserializeOwned + Send + 'static,
125{
126    type Encode = T;
127    type Decode = U;
128    type Encoder = BcsSnappyEncoder<T>;
129    type Decoder = BcsSnappyDecoder<U>;
130
131    fn encoder(&mut self) -> Self::Encoder {
132        BcsSnappyEncoder(PhantomData)
133    }
134
135    fn decoder(&mut self) -> Self::Decoder {
136        BcsSnappyDecoder(PhantomData)
137    }
138}
139
140// Anemo variant of BCS codec using Snappy for compression.
141pub mod anemo {
142    use std::{io::Read, marker::PhantomData};
143
144    use ::anemo::rpc::codec::{Codec, Decoder, Encoder};
145    use bytes::Buf;
146
147    #[derive(Debug)]
148    pub struct BcsSnappyEncoder<T>(PhantomData<T>);
149
150    impl<T: serde::Serialize> Encoder for BcsSnappyEncoder<T> {
151        type Item = T;
152        type Error = bcs::Error;
153
154        fn encode(&mut self, item: Self::Item) -> Result<bytes::Bytes, Self::Error> {
155            let mut buf = Vec::<u8>::new();
156            let mut snappy_encoder = snap::write::FrameEncoder::new(&mut buf);
157            bcs::serialize_into(&mut snappy_encoder, &item)?;
158            drop(snappy_encoder);
159            Ok(buf.into())
160        }
161    }
162
163    #[derive(Debug)]
164    pub struct BcsSnappyDecoder<U>(PhantomData<U>);
165
166    impl<U: serde::de::DeserializeOwned> Decoder for BcsSnappyDecoder<U> {
167        type Item = U;
168        type Error = bcs::Error;
169
170        fn decode(&mut self, buf: bytes::Bytes) -> Result<Self::Item, Self::Error> {
171            let compressed_size = buf.len();
172            let mut snappy_decoder = snap::read::FrameDecoder::new(buf.reader()).take(1 << 30);
173            let mut bytes = Vec::with_capacity(compressed_size);
174            snappy_decoder.read_to_end(&mut bytes)?;
175            bcs::from_bytes(bytes.as_slice())
176        }
177    }
178
179    /// A [`Codec`] that implements `bcs` encoding/decoding via the serde
180    /// library.
181    #[derive(Debug, Clone)]
182    pub struct BcsSnappyCodec<T, U>(PhantomData<(T, U)>);
183
184    impl<T, U> Default for BcsSnappyCodec<T, U> {
185        fn default() -> Self {
186            Self(PhantomData)
187        }
188    }
189
190    impl<T, U> Codec for BcsSnappyCodec<T, U>
191    where
192        T: serde::Serialize + Send + 'static,
193        U: serde::de::DeserializeOwned + Send + 'static,
194    {
195        type Encode = T;
196        type Decode = U;
197        type Encoder = BcsSnappyEncoder<T>;
198        type Decoder = BcsSnappyDecoder<U>;
199
200        fn encoder(&mut self) -> Self::Encoder {
201            BcsSnappyEncoder(PhantomData)
202        }
203
204        fn decoder(&mut self) -> Self::Decoder {
205            BcsSnappyDecoder(PhantomData)
206        }
207
208        fn format_name(&self) -> &'static str {
209            "bcs"
210        }
211    }
212}