iota_types/
zk_login_authenticator.rs

1// Copyright (c) 2021, Facebook, Inc. and its affiliates
2// Copyright (c) Mysten Labs, Inc.
3// Modifications Copyright (c) 2024 IOTA Stiftung
4// SPDX-License-Identifier: Apache-2.0
5
6use std::{
7    hash::{Hash, Hasher},
8    sync::Arc,
9};
10
11use fastcrypto::{error::FastCryptoError, traits::ToFromBytes};
12use fastcrypto_zkp::bn254::{
13    zk_login::{JWK, JwkId, ZkLoginInputs},
14    zk_login_api::{ZkLoginEnv, verify_zk_login},
15};
16use once_cell::sync::OnceCell;
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19use shared_crypto::intent::IntentMessage;
20
21use crate::{
22    base_types::{EpochId, IotaAddress},
23    crypto::{DefaultHash, IotaSignature, PublicKey, Signature, SignatureScheme},
24    digests::ZKLoginInputsDigest,
25    error::{IotaError, IotaResult},
26    signature::{AuthenticatorTrait, VerifyParams},
27    signature_verification::VerifiedDigestCache,
28};
29#[cfg(test)]
30#[path = "unit_tests/zk_login_authenticator_test.rs"]
31mod zk_login_authenticator_test;
32
/// A zkLogin authenticator with all the necessary fields.
///
/// Serialized (serde/BCS) as `inputs`, `max_epoch`, `user_signature` in that
/// order; `bytes` is skipped. Field order therefore defines the wire format
/// and must not be changed.
#[derive(Debug, Clone, JsonSchema, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ZkLoginAuthenticator {
    /// The zkLogin inputs that are checked by `verify_zk_login`.
    pub inputs: ZkLoginInputs,
    /// Last epoch in which this authenticator is still accepted.
    max_epoch: EpochId,
    /// Ephemeral signature over the intent message of the transaction.
    user_signature: Signature,
    /// Lazily populated cache of the flag-prefixed BCS encoding returned by
    /// `AsRef<[u8]>`; never serialized.
    #[serde(skip)]
    pub bytes: OnceCell<Vec<u8>>,
}
43
/// A helper struct that contains the necessary fields to calculate caching key.
/// If the verify_zk_login() api changes, additional fields must be added here
/// so the cache is not skipped.
///
/// The BCS encoding of this struct is what gets hashed into the cache key, so
/// the field order is part of the key format.
#[derive(Serialize, Deserialize)]
struct ZkLoginCachingParams {
    /// The zkLogin inputs passed to `verify_zk_login`.
    inputs: ZkLoginInputs,
    /// The authenticator's max epoch passed to `verify_zk_login`.
    max_epoch: EpochId,
    /// The ephemeral public key prefixed with its scheme flag (`flag || pk_bytes`).
    extended_pk_bytes: Vec<u8>,
}
53
54impl ZkLoginAuthenticator {
55    /// The caching key for zklogin signature, it is the hash of bcs bytes of
56    /// ZkLoginInputs || max_epoch || flagged_pk_bytes. If any of these fields
57    /// change, zklogin signature is re-verified without using the caching
58    /// result.
59    fn get_caching_params(&self) -> ZkLoginCachingParams {
60        let mut extended_pk_bytes = vec![self.user_signature.scheme().flag()];
61        extended_pk_bytes.extend(self.user_signature.public_key_bytes());
62        ZkLoginCachingParams {
63            inputs: self.inputs.clone(),
64            max_epoch: self.max_epoch,
65            extended_pk_bytes,
66        }
67    }
68
69    pub fn hash_inputs(&self) -> ZKLoginInputsDigest {
70        use fastcrypto::hash::HashFunction;
71        let mut hasher = DefaultHash::default();
72        hasher.update(bcs::to_bytes(&self.get_caching_params()).expect("serde should not fail"));
73        ZKLoginInputsDigest::new(hasher.finalize().into())
74    }
75
76    /// Create a new [struct ZkLoginAuthenticator] with necessary fields.
77    pub fn new(inputs: ZkLoginInputs, max_epoch: EpochId, user_signature: Signature) -> Self {
78        Self {
79            inputs,
80            max_epoch,
81            user_signature,
82            bytes: OnceCell::new(),
83        }
84    }
85
86    pub fn get_pk(&self) -> IotaResult<PublicKey> {
87        PublicKey::from_zklogin_inputs(&self.inputs)
88    }
89
90    pub fn get_iss(&self) -> &str {
91        self.inputs.get_iss()
92    }
93
94    pub fn get_max_epoch(&self) -> EpochId {
95        self.max_epoch
96    }
97
98    pub fn user_signature_mut_for_testing(&mut self) -> &mut Signature {
99        &mut self.user_signature
100    }
101
102    pub fn max_epoch_mut_for_testing(&mut self) -> &mut EpochId {
103        &mut self.max_epoch
104    }
105
106    pub fn zk_login_inputs_mut_for_testing(&mut self) -> &mut ZkLoginInputs {
107        &mut self.inputs
108    }
109}
110
111/// Necessary trait for [struct SenderSignedData].
112impl PartialEq for ZkLoginAuthenticator {
113    fn eq(&self, other: &Self) -> bool {
114        self.as_ref() == other.as_ref()
115    }
116}
117
/// Necessary trait for [struct SenderSignedData].
///
/// `PartialEq` compares byte slices, which is a total equivalence relation,
/// so `Eq` holds.
impl Eq for ZkLoginAuthenticator {}
120
121/// Necessary trait for [struct SenderSignedData].
122impl Hash for ZkLoginAuthenticator {
123    fn hash<H: Hasher>(&self, state: &mut H) {
124        self.as_ref().hash(state);
125    }
126}
127
128impl AuthenticatorTrait for ZkLoginAuthenticator {
129    fn verify_user_authenticator_epoch(
130        &self,
131        epoch: EpochId,
132        max_epoch_upper_bound_delta: Option<u64>,
133    ) -> IotaResult {
134        // the checks here ensure that `current_epoch + max_epoch_upper_bound_delta >=
135        // self.max_epoch >= current_epoch`.
136        // 1. if the config for upper bound is set, ensure that the max epoch in
137        //    signature is not larger than epoch + upper_bound.
138        if let Some(delta) = max_epoch_upper_bound_delta {
139            let max_epoch_upper_bound = epoch + delta;
140            if self.get_max_epoch() > max_epoch_upper_bound {
141                return Err(IotaError::InvalidSignature {
142                    error: format!(
143                        "ZKLogin max epoch too large {}, current epoch {}, max accepted: {}",
144                        self.get_max_epoch(),
145                        epoch,
146                        max_epoch_upper_bound
147                    ),
148                });
149            }
150        }
151        // 2. ensure that max epoch in signature is greater than the current epoch.
152        if epoch > self.get_max_epoch() {
153            return Err(IotaError::InvalidSignature {
154                error: format!(
155                    "ZKLogin expired at epoch {}, current epoch {}",
156                    self.get_max_epoch(),
157                    epoch
158                ),
159            });
160        }
161        Ok(())
162    }
163
164    /// Verify an intent message of a transaction with an zk login
165    /// authenticator.
166    fn verify_claims<T>(
167        &self,
168        intent_msg: &IntentMessage<T>,
169        author: IotaAddress,
170        aux_verify_data: &VerifyParams,
171        zklogin_inputs_cache: Arc<VerifiedDigestCache<ZKLoginInputsDigest>>,
172    ) -> IotaResult
173    where
174        T: Serialize,
175    {
176        // Always evaluate the unpadded address derivation.
177        if author != IotaAddress::try_from_unpadded(&self.inputs)? {
178            return Err(IotaError::InvalidAddress);
179        }
180
181        // Verify the ephemeral signature over the intent message of the transaction
182        // data.
183        self.user_signature.verify_secure(
184            intent_msg,
185            author,
186            SignatureScheme::ZkLoginAuthenticator,
187        )?;
188
189        if zklogin_inputs_cache.is_cached(&self.hash_inputs()) {
190            // If the zklogin inputs hits the cache, we don't need to verify the zklogin
191            // again that contains the heavy computation.
192            Ok(())
193        } else {
194            // if it is not cached, we verify the full zklogin inputs.
195            // build extended_pk_bytes as flag || pk_bytes.
196            let mut extended_pk_bytes = vec![self.user_signature.scheme().flag()];
197            extended_pk_bytes.extend(self.user_signature.public_key_bytes());
198            let res = verify_zklogin_inputs_wrapper(
199                self.get_caching_params(),
200                &aux_verify_data.oidc_provider_jwks,
201                &aux_verify_data.zk_login_env,
202            )
203            .map_err(|e| IotaError::InvalidSignature {
204                error: e.to_string(),
205            });
206            match res {
207                Ok(_) => {
208                    // If it's verified ok, we cache the digest.
209                    zklogin_inputs_cache.cache_digest(self.hash_inputs());
210                    Ok(())
211                }
212                Err(e) => Err(e),
213            }
214        }
215    }
216}
217
218fn verify_zklogin_inputs_wrapper(
219    params: ZkLoginCachingParams,
220    all_jwk: &im::HashMap<JwkId, JWK>,
221    env: &ZkLoginEnv,
222) -> IotaResult<()> {
223    verify_zk_login(
224        &params.inputs,
225        params.max_epoch,
226        &params.extended_pk_bytes,
227        all_jwk,
228        env,
229    )
230    .map_err(|e| IotaError::InvalidSignature {
231        error: e.to_string(),
232    })
233}
234
235impl ToFromBytes for ZkLoginAuthenticator {
236    fn from_bytes(bytes: &[u8]) -> Result<Self, FastCryptoError> {
237        // The first byte matches the flag of MultiSig.
238        if bytes.first().ok_or(FastCryptoError::InvalidInput)?
239            != &SignatureScheme::ZkLoginAuthenticator.flag()
240        {
241            return Err(FastCryptoError::InvalidInput);
242        }
243        let mut zk_login: ZkLoginAuthenticator =
244            bcs::from_bytes(&bytes[1..]).map_err(|_| FastCryptoError::InvalidSignature)?;
245        zk_login.inputs.init()?;
246        Ok(zk_login)
247    }
248}
249
250impl AsRef<[u8]> for ZkLoginAuthenticator {
251    fn as_ref(&self) -> &[u8] {
252        self.bytes
253            .get_or_try_init::<_, eyre::Report>(|| {
254                let as_bytes = bcs::to_bytes(self).expect("BCS serialization should not fail");
255                let mut bytes = Vec::with_capacity(1 + as_bytes.len());
256                bytes.push(SignatureScheme::ZkLoginAuthenticator.flag());
257                bytes.extend_from_slice(as_bytes.as_slice());
258                Ok(bytes)
259            })
260            .expect("OnceCell invariant violated")
261    }
262}
263
/// A 32-byte address seed, stored as a zero-padded big-endian integer.
#[derive(Debug, Clone)]
pub struct AddressSeed([u8; 32]);

impl AddressSeed {
    /// Returns the seed bytes with all leading zero bytes removed.
    ///
    /// An all-zero seed yields a one-byte slice (its final `0` byte) rather
    /// than an empty slice.
    pub fn unpadded(&self) -> &[u8] {
        let first_significant = self.0.iter().position(|&byte| byte != 0);
        match first_significant {
            Some(start) => &self.0[start..],
            None => &self.0[self.0.len() - 1..],
        }
    }

    /// Returns the full 32-byte, zero-padded representation.
    pub fn padded(&self) -> &[u8] {
        self.0.as_slice()
    }
}
283
284impl std::fmt::Display for AddressSeed {
285    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286        let big_int = num_bigint::BigUint::from_bytes_be(&self.0);
287        let radix10 = big_int.to_str_radix(10);
288        f.write_str(&radix10)
289    }
290}
291
292#[derive(thiserror::Error, Debug)]
293pub enum AddressSeedParseError {
294    #[error("unable to parse radix10 encoded value `{0}`")]
295    Parse(#[from] num_bigint::ParseBigIntError),
296    #[error("larger than 32 bytes")]
297    TooBig,
298}
299
300impl std::str::FromStr for AddressSeed {
301    type Err = AddressSeedParseError;
302
303    fn from_str(s: &str) -> Result<Self, Self::Err> {
304        let big_int = <num_bigint::BigUint as num_traits::Num>::from_str_radix(s, 10)?;
305        let be_bytes = big_int.to_bytes_be();
306        let len = be_bytes.len();
307        let mut buf = [0; 32];
308
309        if len > 32 {
310            return Err(AddressSeedParseError::TooBig);
311        }
312
313        buf[32 - len..].copy_from_slice(&be_bytes);
314        Ok(Self(buf))
315    }
316}
317
318// AddressSeed's serialized format is as a radix10 encoded string
319impl Serialize for AddressSeed {
320    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
321    where
322        S: serde::Serializer,
323    {
324        self.to_string().serialize(serializer)
325    }
326}
327
328impl<'de> Deserialize<'de> for AddressSeed {
329    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
330    where
331        D: serde::Deserializer<'de>,
332    {
333        let s = std::borrow::Cow::<'de, str>::deserialize(deserializer)?;
334        std::str::FromStr::from_str(&s).map_err(serde::de::Error::custom)
335    }
336}
337
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use num_bigint::BigUint;
    use proptest::prelude::*;

    use super::{AddressSeed, AddressSeedParseError};

    #[test]
    fn unpadded_slice() {
        // An all-zero seed collapses to a single zero byte, not an empty slice.
        let seed = AddressSeed([0; 32]);
        let zero: [u8; 1] = [0];
        assert_eq!(seed.unpadded(), zero.as_slice());

        // Exactly the one leading zero byte is stripped.
        let mut seed = AddressSeed([1; 32]);
        seed.0[0] = 0;
        assert_eq!(seed.unpadded(), [1; 31].as_slice());
    }

    proptest! {
        #[test]
        fn dont_crash_on_large_inputs(
            bytes in proptest::collection::vec(any::<u8>(), 33..1024)
        ) {
            let big_int = BigUint::from_bytes_be(&bytes);
            let radix10 = big_int.to_str_radix(10);

            // Must never panic; leading zero bytes may still let the value fit,
            // so the expected outcome depends on the value's real width.
            let result = AddressSeed::from_str(&radix10);
            if big_int.to_bytes_be().len() > 32 {
                prop_assert!(matches!(result, Err(AddressSeedParseError::TooBig)));
            } else {
                prop_assert!(result.is_ok());
            }
        }

        #[test]
        fn valid_address_seeds(
            bytes in proptest::collection::vec(any::<u8>(), 1..=32)
        ) {
            let big_int = BigUint::from_bytes_be(&bytes);
            let radix10 = big_int.to_str_radix(10);

            let seed = AddressSeed::from_str(&radix10).unwrap();
            prop_assert_eq!(&radix10, &seed.to_string());

            // `unpadded` must encode the same integer with no redundant
            // leading zero bytes (a lone zero byte is allowed for value 0).
            let unpadded = seed.unpadded();
            prop_assert!(!unpadded.is_empty());
            prop_assert!(unpadded.len() == 1 || unpadded[0] != 0);
            prop_assert_eq!(BigUint::from_bytes_be(unpadded), big_int);
        }
    }
}