iota_tls/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5mod acceptor;
6mod certgen;
7mod verifier;
8
9use std::sync::Arc;
10
11pub use acceptor::{TlsAcceptor, TlsConnectionInfo};
12pub use certgen::SelfSignedCertificate;
13use fastcrypto::ed25519::{Ed25519PrivateKey, Ed25519PublicKey};
14pub use rustls;
15use rustls::ClientConfig;
16use tokio_rustls::rustls::ServerConfig;
17pub use verifier::{
18    AllowAll, AllowPublicKeys, Allower, ClientCertVerifier, ServerCertVerifier,
19    public_key_from_certificate,
20};
21
22pub const IOTA_VALIDATOR_SERVER_NAME: &str = "iota";
23
24pub fn create_rustls_server_config(
25    private_key: Ed25519PrivateKey,
26    server_name: String,
27) -> ServerConfig {
28    // TODO: refactor to use key bytes
29    let self_signed_cert = SelfSignedCertificate::new(private_key, server_name.as_str());
30    let tls_cert = self_signed_cert.rustls_certificate();
31    let tls_private_key = self_signed_cert.rustls_private_key();
32    let mut tls_config = rustls::ServerConfig::builder_with_provider(Arc::new(
33        rustls::crypto::ring::default_provider(),
34    ))
35    .with_protocol_versions(&[&rustls::version::TLS13])
36    .unwrap_or_else(|e| panic!("Failed to create TLS server config: {e:?}"))
37    .with_no_client_auth()
38    .with_single_cert(vec![tls_cert], tls_private_key)
39    .unwrap_or_else(|e| panic!("Failed to create TLS server config: {e:?}"));
40    tls_config.alpn_protocols = vec![b"h2".to_vec()];
41    tls_config
42}
43
44/// Create a TLS server config which requires mTLS, eg the client to also
45/// provide a cert and be verified by the server based on the provided policy
46pub fn create_rustls_server_config_with_client_verifier<A: Allower + 'static>(
47    private_key: Ed25519PrivateKey,
48    server_name: String,
49    allower: A,
50) -> ServerConfig {
51    let verifier = ClientCertVerifier::new(allower, server_name.clone());
52    // TODO: refactor to use key bytes
53    let self_signed_cert = SelfSignedCertificate::new(private_key, server_name.as_str());
54    let tls_cert = self_signed_cert.rustls_certificate();
55    let tls_private_key = self_signed_cert.rustls_private_key();
56    let mut tls_config = verifier
57        .rustls_server_config(vec![tls_cert], tls_private_key)
58        .unwrap_or_else(|e| panic!("Failed to create TLS server config: {e:?}"));
59    tls_config.alpn_protocols = vec![b"h2".to_vec()];
60    tls_config
61}
62
63pub fn create_rustls_client_config(
64    target_public_key: Ed25519PublicKey,
65    server_name: String,
66    client_key: Option<Ed25519PrivateKey>, // optional self-signed cert for client verification
67) -> ClientConfig {
68    let tls_config = ServerCertVerifier::new(target_public_key, server_name.clone());
69    let tls_config = if let Some(private_key) = client_key {
70        let self_signed_cert = SelfSignedCertificate::new(private_key, server_name.as_str());
71        let tls_cert = self_signed_cert.rustls_certificate();
72        let tls_private_key = self_signed_cert.rustls_private_key();
73        tls_config.rustls_client_config_with_client_auth(vec![tls_cert], tls_private_key)
74    } else {
75        tls_config.rustls_client_config_with_no_client_auth()
76    }
77    .unwrap_or_else(|e| panic!("Failed to create TLS client config: {e:?}"));
78    tls_config
79}
80
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use fastcrypto::{ed25519::Ed25519KeyPair, traits::KeyPair};
    use rustls::{
        client::danger::ServerCertVerifier as _,
        pki_types::{ServerName, UnixTime},
        server::danger::ClientCertVerifier as _,
    };

    use super::*;

    /// Fails the calling test unless `err` is a `rustls::Error::General`.
    #[track_caller]
    fn assert_general_error(err: rustls::Error) {
        assert!(
            matches!(err, rustls::Error::General(_)),
            "Actual error: {err:?}"
        );
    }

    /// Fails the calling test unless `err` reports a certificate that is not
    /// valid for the expected name.
    #[track_caller]
    fn assert_not_valid_for_name(err: rustls::Error) {
        assert_eq!(
            err,
            rustls::Error::InvalidCertificate(rustls::CertificateError::NotValidForName),
            "Actual error: {err:?}"
        );
    }

    #[test]
    fn verify_allowall() {
        let mut rng = rand::thread_rng();
        let bob_keypair = Ed25519KeyPair::generate(&mut rng);
        let alice_keypair = Ed25519KeyPair::generate(&mut rng);
        let bob_cert =
            SelfSignedCertificate::new(bob_keypair.private(), IOTA_VALIDATOR_SERVER_NAME);
        let alice_cert =
            SelfSignedCertificate::new(alice_keypair.private(), IOTA_VALIDATOR_SERVER_NAME);

        let verifier = ClientCertVerifier::new(AllowAll, IOTA_VALIDATOR_SERVER_NAME.to_string());

        // Under `AllowAll` both bob and alice must pass validation.
        for cert in [&bob_cert, &alice_cert] {
            verifier
                .verify_client_cert(&cert.rustls_certificate(), &[], UnixTime::now())
                .unwrap();
        }
    }

    #[test]
    fn verify_server_cert() {
        let mut rng = rand::thread_rng();
        let trusted_keypair = Ed25519KeyPair::generate(&mut rng);
        let untrusted_keypair = Ed25519KeyPair::generate(&mut rng);
        let trusted_public_key = trusted_keypair.public().to_owned();
        let trusted_cert =
            SelfSignedCertificate::new(trusted_keypair.private(), IOTA_VALIDATOR_SERVER_NAME);
        let untrusted_cert =
            SelfSignedCertificate::new(untrusted_keypair.private(), IOTA_VALIDATOR_SERVER_NAME);

        let verifier =
            ServerCertVerifier::new(trusted_public_key, IOTA_VALIDATOR_SERVER_NAME.to_string());
        let requested_name = ServerName::try_from("example.com").unwrap();

        // A certificate made from the expected key passes validation.
        verifier
            .verify_server_cert(
                &trusted_cert.rustls_certificate(),
                &[],
                &requested_name,
                &[],
                UnixTime::now(),
            )
            .unwrap();

        // A certificate made from any other key does not.
        let err = verifier
            .verify_server_cert(
                &untrusted_cert.rustls_certificate(),
                &[],
                &requested_name,
                &[],
                UnixTime::now(),
            )
            .unwrap_err();
        assert_general_error(err);
    }

    #[test]
    fn verify_hashset() {
        let mut rng = rand::thread_rng();
        let listed_keypair = Ed25519KeyPair::generate(&mut rng);
        let unlisted_keypair = Ed25519KeyPair::generate(&mut rng);

        let listed_keys = BTreeSet::from([listed_keypair.public().to_owned()]);
        let listed_cert =
            SelfSignedCertificate::new(listed_keypair.private(), IOTA_VALIDATOR_SERVER_NAME);
        let unlisted_cert =
            SelfSignedCertificate::new(unlisted_keypair.private(), IOTA_VALIDATOR_SERVER_NAME);

        let allowlist = AllowPublicKeys::new(listed_keys);
        let verifier =
            ClientCertVerifier::new(allowlist.clone(), IOTA_VALIDATOR_SERVER_NAME.to_string());

        // A key present in the allowlist passes validation.
        verifier
            .verify_client_cert(&listed_cert.rustls_certificate(), &[], UnixTime::now())
            .unwrap();

        // A key absent from the allowlist fails validation.
        let err = verifier
            .verify_client_cert(&unlisted_cert.rustls_certificate(), &[], UnixTime::now())
            .unwrap_err();
        assert_general_error(err);

        // Once the allowlist is emptied, the previously accepted key fails too.
        allowlist.update(BTreeSet::new());
        let err = verifier
            .verify_client_cert(&listed_cert.rustls_certificate(), &[], UnixTime::now())
            .unwrap_err();
        assert_general_error(err);
    }

    #[test]
    fn invalid_server_name() {
        let mut rng = rand::thread_rng();
        let keypair = Ed25519KeyPair::generate(&mut rng);
        let public_key = keypair.public().to_owned();
        let misnamed_cert = SelfSignedCertificate::new(keypair.private(), "not-iota");

        // Client side: the key is allowed, but the name in the certificate is
        // not the required "iota".
        let allowlist = AllowPublicKeys::new(BTreeSet::from([public_key.clone()]));
        let client_verifier =
            ClientCertVerifier::new(allowlist.clone(), IOTA_VALIDATOR_SERVER_NAME.to_string());
        let err = client_verifier
            .verify_client_cert(&misnamed_cert.rustls_certificate(), &[], UnixTime::now())
            .unwrap_err();
        assert_not_valid_for_name(err);

        // Server side: the key matches, but the name in the certificate is
        // not the required "iota".
        let server_verifier =
            ServerCertVerifier::new(public_key, IOTA_VALIDATOR_SERVER_NAME.to_string());
        let requested_name = ServerName::try_from("example.com").unwrap();
        let err = server_verifier
            .verify_server_cert(
                &misnamed_cert.rustls_certificate(),
                &[],
                &requested_name,
                &[],
                UnixTime::now(),
            )
            .unwrap_err();
        assert_not_valid_for_name(err);
    }

    #[tokio::test]
    async fn axum_acceptor() {
        /// Replies with the public key recorded for the TLS connection.
        async fn handler(tls_info: axum::Extension<TlsConnectionInfo>) -> String {
            tls_info.public_key().unwrap().to_string()
        }

        let mut rng = rand::thread_rng();
        let client_keypair = Ed25519KeyPair::generate(&mut rng);
        let client_public_key = client_keypair.public().to_owned();
        let client_cert =
            SelfSignedCertificate::new(client_keypair.private(), IOTA_VALIDATOR_SERVER_NAME);
        let server_keypair = Ed25519KeyPair::generate(&mut rng);
        let server_cert = SelfSignedCertificate::new(server_keypair.private(), "localhost");

        // HTTPS-only client that trusts the server's self-signed certificate
        // and identifies itself with its own.
        let client = reqwest::Client::builder()
            .add_root_certificate(server_cert.reqwest_certificate())
            .identity(client_cert.reqwest_identity())
            .https_only(true)
            .build()
            .unwrap();

        // The server starts with an empty allowlist.
        let allowlist = AllowPublicKeys::new(BTreeSet::new());
        let client_verifier =
            ClientCertVerifier::new(allowlist.clone(), IOTA_VALIDATOR_SERVER_NAME.to_string());
        let tls_config = client_verifier
            .rustls_server_config(
                vec![server_cert.rustls_certificate()],
                server_cert.rustls_private_key(),
            )
            .unwrap();

        let app = axum::Router::new().route("/", axum::routing::get(handler));
        let listener = std::net::TcpListener::bind("localhost:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        let acceptor = TlsAcceptor::new(tls_config);
        let _server = tokio::spawn(async move {
            axum_server::Server::from_tcp(listener)
                .acceptor(acceptor)
                .serve(app.into_make_service())
                .await
                .unwrap()
        });

        let server_url = format!("https://localhost:{port}");

        // The first request is rejected: the client's key is not in the allowlist.
        client.get(&server_url).send().await.unwrap_err();

        // After the client's public key is added to the allowlist, the request
        // succeeds and the body echoes that key.
        allowlist.update(BTreeSet::from([client_public_key.clone()]));

        let response = client.get(&server_url).send().await.unwrap();
        let body = response.text().await.unwrap();
        assert_eq!(client_public_key.to_string(), body);
    }
}