iota_types/
zk_login_authenticator.rs

1// Copyright (c) 2021, Facebook, Inc. and its affiliates
2// Copyright (c) Mysten Labs, Inc.
3// Modifications Copyright (c) 2024 IOTA Stiftung
4// SPDX-License-Identifier: Apache-2.0
5
6use std::{
7    hash::{Hash, Hasher},
8    sync::Arc,
9};
10
11use fastcrypto::{error::FastCryptoError, traits::ToFromBytes};
12use fastcrypto_zkp::bn254::{
13    zk_login::{JWK, JwkId, ZkLoginInputs},
14    zk_login_api::{ZkLoginEnv, verify_zk_login},
15};
16use once_cell::sync::OnceCell;
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19use shared_crypto::intent::IntentMessage;
20
21use crate::{
22    base_types::{EpochId, IotaAddress},
23    crypto::{DefaultHash, IotaSignature, PublicKey, Signature, SignatureScheme},
24    digests::ZKLoginInputsDigest,
25    error::{IotaError, IotaResult},
26    signature::{AuthenticatorTrait, VerifyParams},
27    signature_verification::VerifiedDigestCache,
28};
29#[cfg(test)]
30#[path = "unit_tests/zk_login_authenticator_test.rs"]
31mod zk_login_authenticator_test;
32
/// A zk login authenticator with all the necessary fields.
#[derive(Debug, Clone, JsonSchema, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ZkLoginAuthenticator {
    // zkLogin proof inputs; the sender address is derived from these during
    // verification (`IotaAddress::try_from_unpadded`).
    pub inputs: ZkLoginInputs,
    // Last epoch (inclusive) in which this authenticator is accepted; see
    // `verify_user_authenticator_epoch`.
    max_epoch: EpochId,
    // Signature by the ephemeral key over the intent message.
    user_signature: Signature,
    // Lazily computed `flag || bcs(self)` encoding, used by `as_ref`,
    // `PartialEq` and `Hash`. Not serialized.
    #[serde(skip)]
    pub bytes: OnceCell<Vec<u8>>,
}
43
/// A helper struct holding every value that determines the outcome of
/// `verify_zk_login()`. Its BCS encoding is hashed to form the verification
/// cache key (see `ZkLoginAuthenticator::hash_inputs`).
///
/// If the `verify_zk_login()` API changes, any additional inputs must be added
/// here so that a cached result is never reused for different inputs.
/// NOTE: field order is part of the BCS encoding and therefore of the cache key.
#[derive(Serialize, Deserialize)]
struct ZkLoginCachingParams {
    // zkLogin proof inputs.
    inputs: ZkLoginInputs,
    // Maximum epoch the authenticator is valid for.
    max_epoch: EpochId,
    // Ephemeral public key encoded as `flag || public_key_bytes`.
    extended_pk_bytes: Vec<u8>,
}
53
54impl ZkLoginAuthenticator {
55    /// The caching key for zklogin signature, it is the hash of bcs bytes of
56    /// ZkLoginInputs || max_epoch || flagged_pk_bytes. If any of these fields
57    /// change, zklogin signature is re-verified without using the caching
58    /// result.
59    fn get_caching_params(&self) -> ZkLoginCachingParams {
60        let mut extended_pk_bytes = vec![self.user_signature.scheme().flag()];
61        extended_pk_bytes.extend(self.user_signature.public_key_bytes());
62        ZkLoginCachingParams {
63            inputs: self.inputs.clone(),
64            max_epoch: self.max_epoch,
65            extended_pk_bytes,
66        }
67    }
68
69    pub fn hash_inputs(&self) -> ZKLoginInputsDigest {
70        use fastcrypto::hash::HashFunction;
71        let mut hasher = DefaultHash::default();
72        hasher.update(bcs::to_bytes(&self.get_caching_params()).expect("serde should not fail"));
73        ZKLoginInputsDigest::new(hasher.finalize().into())
74    }
75
76    /// Create a new [struct ZkLoginAuthenticator] with necessary fields.
77    pub fn new(inputs: ZkLoginInputs, max_epoch: EpochId, user_signature: Signature) -> Self {
78        Self {
79            inputs,
80            max_epoch,
81            user_signature,
82            bytes: OnceCell::new(),
83        }
84    }
85
86    pub fn get_pk(&self) -> IotaResult<PublicKey> {
87        PublicKey::from_zklogin_inputs(&self.inputs)
88    }
89
90    pub fn get_iss(&self) -> &str {
91        self.inputs.get_iss()
92    }
93
94    pub fn get_max_epoch(&self) -> EpochId {
95        self.max_epoch
96    }
97
98    pub fn user_signature_mut_for_testing(&mut self) -> &mut Signature {
99        &mut self.user_signature
100    }
101    pub fn max_epoch_mut_for_testing(&mut self) -> &mut EpochId {
102        &mut self.max_epoch
103    }
104    pub fn zk_login_inputs_mut_for_testing(&mut self) -> &mut ZkLoginInputs {
105        &mut self.inputs
106    }
107}
108
109/// Necessary trait for [struct SenderSignedData].
110impl PartialEq for ZkLoginAuthenticator {
111    fn eq(&self, other: &Self) -> bool {
112        self.as_ref() == other.as_ref()
113    }
114}
115
/// Required by `SenderSignedData`. Equality compares byte encodings, which is
/// a total equivalence relation, so `Eq` holds.
impl Eq for ZkLoginAuthenticator {}
118
119/// Necessary trait for [struct SenderSignedData].
120impl Hash for ZkLoginAuthenticator {
121    fn hash<H: Hasher>(&self, state: &mut H) {
122        self.as_ref().hash(state);
123    }
124}
125
126impl AuthenticatorTrait for ZkLoginAuthenticator {
127    fn verify_user_authenticator_epoch(
128        &self,
129        epoch: EpochId,
130        max_epoch_upper_bound_delta: Option<u64>,
131    ) -> IotaResult {
132        // the checks here ensure that `current_epoch + max_epoch_upper_bound_delta >=
133        // self.max_epoch >= current_epoch`.
134        // 1. if the config for upper bound is set, ensure that the max epoch in
135        //    signature is not larger than epoch + upper_bound.
136        if let Some(delta) = max_epoch_upper_bound_delta {
137            let max_epoch_upper_bound = epoch + delta;
138            if self.get_max_epoch() > max_epoch_upper_bound {
139                return Err(IotaError::InvalidSignature {
140                    error: format!(
141                        "ZKLogin max epoch too large {}, current epoch {}, max accepted: {}",
142                        self.get_max_epoch(),
143                        epoch,
144                        max_epoch_upper_bound
145                    ),
146                });
147            }
148        }
149        // 2. ensure that max epoch in signature is greater than the current epoch.
150        if epoch > self.get_max_epoch() {
151            return Err(IotaError::InvalidSignature {
152                error: format!(
153                    "ZKLogin expired at epoch {}, current epoch {}",
154                    self.get_max_epoch(),
155                    epoch
156                ),
157            });
158        }
159        Ok(())
160    }
161
162    /// Verify an intent message of a transaction with an zk login
163    /// authenticator.
164    fn verify_claims<T>(
165        &self,
166        intent_msg: &IntentMessage<T>,
167        author: IotaAddress,
168        aux_verify_data: &VerifyParams,
169        zklogin_inputs_cache: Arc<VerifiedDigestCache<ZKLoginInputsDigest>>,
170    ) -> IotaResult
171    where
172        T: Serialize,
173    {
174        // Always evaluate the unpadded address derivation.
175        if author != IotaAddress::try_from_unpadded(&self.inputs)? {
176            return Err(IotaError::InvalidAddress);
177        }
178
179        // Verify the ephemeral signature over the intent message of the transaction
180        // data.
181        self.user_signature.verify_secure(
182            intent_msg,
183            author,
184            SignatureScheme::ZkLoginAuthenticator,
185        )?;
186
187        if zklogin_inputs_cache.is_cached(&self.hash_inputs()) {
188            // If the zklogin inputs hits the cache, we don't need to verify the zklogin
189            // again that contains the heavy computation.
190            Ok(())
191        } else {
192            // if it is not cached, we verify the full zklogin inputs.
193            // build extended_pk_bytes as flag || pk_bytes.
194            let mut extended_pk_bytes = vec![self.user_signature.scheme().flag()];
195            extended_pk_bytes.extend(self.user_signature.public_key_bytes());
196            let res = verify_zklogin_inputs_wrapper(
197                self.get_caching_params(),
198                &aux_verify_data.oidc_provider_jwks,
199                &aux_verify_data.zk_login_env,
200            )
201            .map_err(|e| IotaError::InvalidSignature {
202                error: e.to_string(),
203            });
204            match res {
205                Ok(_) => {
206                    // If it's verified ok, we cache the digest.
207                    zklogin_inputs_cache.cache_digest(self.hash_inputs());
208                    Ok(())
209                }
210                Err(e) => Err(e),
211            }
212        }
213    }
214}
215
216fn verify_zklogin_inputs_wrapper(
217    params: ZkLoginCachingParams,
218    all_jwk: &im::HashMap<JwkId, JWK>,
219    env: &ZkLoginEnv,
220) -> IotaResult<()> {
221    verify_zk_login(
222        &params.inputs,
223        params.max_epoch,
224        &params.extended_pk_bytes,
225        all_jwk,
226        env,
227    )
228    .map_err(|e| IotaError::InvalidSignature {
229        error: e.to_string(),
230    })
231}
232
233impl ToFromBytes for ZkLoginAuthenticator {
234    fn from_bytes(bytes: &[u8]) -> Result<Self, FastCryptoError> {
235        // The first byte matches the flag of MultiSig.
236        if bytes.first().ok_or(FastCryptoError::InvalidInput)?
237            != &SignatureScheme::ZkLoginAuthenticator.flag()
238        {
239            return Err(FastCryptoError::InvalidInput);
240        }
241        let mut zk_login: ZkLoginAuthenticator =
242            bcs::from_bytes(&bytes[1..]).map_err(|_| FastCryptoError::InvalidSignature)?;
243        zk_login.inputs.init()?;
244        Ok(zk_login)
245    }
246}
247
248impl AsRef<[u8]> for ZkLoginAuthenticator {
249    fn as_ref(&self) -> &[u8] {
250        self.bytes
251            .get_or_try_init::<_, eyre::Report>(|| {
252                let as_bytes = bcs::to_bytes(self).expect("BCS serialization should not fail");
253                let mut bytes = Vec::with_capacity(1 + as_bytes.len());
254                bytes.push(SignatureScheme::ZkLoginAuthenticator.flag());
255                bytes.extend_from_slice(as_bytes.as_slice());
256                Ok(bytes)
257            })
258            .expect("OnceCell invariant violated")
259    }
260}
261
/// A 32-byte address seed, stored big-endian (see the `Display`/`FromStr`
/// impls, which convert via big-endian bytes).
#[derive(Debug, Clone)]
pub struct AddressSeed([u8; 32]);

impl AddressSeed {
    /// Returns the seed with its leading zero bytes stripped.
    ///
    /// The result is never empty. An all-zero seed yields a one-byte slice
    /// holding the final (zero) byte.
    pub fn unpadded(&self) -> &[u8] {
        let first_nonzero = self.0.iter().position(|&byte| byte != 0);
        match first_nonzero {
            Some(start) => &self.0[start..],
            None => &self.0[31..],
        }
    }

    /// Returns the full 32-byte seed, including any leading zero bytes.
    pub fn padded(&self) -> &[u8] {
        &self.0
    }
}
281
282impl std::fmt::Display for AddressSeed {
283    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284        let big_int = num_bigint::BigUint::from_bytes_be(&self.0);
285        let radix10 = big_int.to_str_radix(10);
286        f.write_str(&radix10)
287    }
288}
289
290#[derive(thiserror::Error, Debug)]
291pub enum AddressSeedParseError {
292    #[error("unable to parse radix10 encoded value `{0}`")]
293    Parse(#[from] num_bigint::ParseBigIntError),
294    #[error("larger than 32 bytes")]
295    TooBig,
296}
297
298impl std::str::FromStr for AddressSeed {
299    type Err = AddressSeedParseError;
300
301    fn from_str(s: &str) -> Result<Self, Self::Err> {
302        let big_int = <num_bigint::BigUint as num_traits::Num>::from_str_radix(s, 10)?;
303        let be_bytes = big_int.to_bytes_be();
304        let len = be_bytes.len();
305        let mut buf = [0; 32];
306
307        if len > 32 {
308            return Err(AddressSeedParseError::TooBig);
309        }
310
311        buf[32 - len..].copy_from_slice(&be_bytes);
312        Ok(Self(buf))
313    }
314}
315
316// AddressSeed's serialized format is as a radix10 encoded string
317impl Serialize for AddressSeed {
318    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
319    where
320        S: serde::Serializer,
321    {
322        self.to_string().serialize(serializer)
323    }
324}
325
326impl<'de> Deserialize<'de> for AddressSeed {
327    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
328    where
329        D: serde::Deserializer<'de>,
330    {
331        let s = std::borrow::Cow::<'de, str>::deserialize(deserializer)?;
332        std::str::FromStr::from_str(&s).map_err(serde::de::Error::custom)
333    }
334}
335
// Unit and property tests for `AddressSeed` parsing, formatting and unpadding.
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use num_bigint::BigUint;
    use proptest::prelude::*;

    use super::AddressSeed;

    // `unpadded` strips leading zero bytes, but keeps a single byte for an
    // all-zero seed.
    #[test]
    fn unpadded_slice() {
        let seed = AddressSeed([0; 32]);
        let zero: [u8; 1] = [0];
        assert_eq!(seed.unpadded(), zero.as_slice());

        let mut seed = AddressSeed([1; 32]);
        seed.0[0] = 0;
        assert_eq!(seed.unpadded(), [1; 31].as_slice());
    }

    proptest! {
        // Parsing decimal strings built from 33..1024 random bytes must return
        // (an error or a value) without panicking.
        #[test]
        fn dont_crash_on_large_inputs(
            bytes in proptest::collection::vec(any::<u8>(), 33..1024)
        ) {
            let big_int = BigUint::from_bytes_be(&bytes);
            let radix10 = big_int.to_str_radix(10);

            // doesn't crash
            let _ = AddressSeed::from_str(&radix10);
        }

        // Any value of at most 32 bytes must parse and round-trip through
        // `Display` to the same decimal string.
        #[test]
        fn valid_address_seeds(
            bytes in proptest::collection::vec(any::<u8>(), 1..=32)
        ) {
            let big_int = BigUint::from_bytes_be(&bytes);
            let radix10 = big_int.to_str_radix(10);

            let seed = AddressSeed::from_str(&radix10).unwrap();
            assert_eq!(radix10, seed.to_string());
            // Ensure unpadded doesn't crash
            seed.unpadded();
        }
    }
}