1use std::{
7 hash::{Hash, Hasher},
8 sync::Arc,
9};
10
11use fastcrypto::{error::FastCryptoError, traits::ToFromBytes};
12use fastcrypto_zkp::bn254::{
13 zk_login::{JWK, JwkId, ZkLoginInputs},
14 zk_login_api::{ZkLoginEnv, verify_zk_login},
15};
16use once_cell::sync::OnceCell;
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19use shared_crypto::intent::IntentMessage;
20
21use crate::{
22 base_types::{EpochId, IotaAddress},
23 crypto::{DefaultHash, IotaSignature, PublicKey, Signature, SignatureScheme},
24 digests::ZKLoginInputsDigest,
25 error::{IotaError, IotaResult},
26 signature::{AuthenticatorTrait, VerifyParams},
27 signature_verification::VerifiedDigestCache,
28};
// Unit tests for this module live in a separate file under `unit_tests/`.
#[cfg(test)]
#[path = "unit_tests/zk_login_authenticator_test.rs"]
mod zk_login_authenticator_test;
32
/// A zkLogin authenticator: the zkLogin inputs, the last epoch at which it is
/// accepted, and the user signature over the intent message.
///
/// Field order matters: it defines the serde/BCS encoding returned by `as_ref`.
#[derive(Debug, Clone, JsonSchema, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ZkLoginAuthenticator {
    /// zkLogin inputs from which the address and public key are derived.
    pub inputs: ZkLoginInputs,
    /// Last epoch at which this authenticator is still accepted.
    max_epoch: EpochId,
    /// User signature, verified against the intent message in `verify_claims`.
    user_signature: Signature,
    /// Lazily computed `flag || bcs(self)` encoding; excluded from serialization.
    #[serde(skip)]
    pub bytes: OnceCell<Vec<u8>>,
}
43
/// Per-authenticator parameters passed to zkLogin input verification.
///
/// Its BCS encoding is hashed to form the key of the verified-inputs cache, so
/// the field order here determines the digest.
#[derive(Serialize, Deserialize)]
struct ZkLoginCachingParams {
    inputs: ZkLoginInputs,
    max_epoch: EpochId,
    /// Scheme flag of the user signature followed by its public key bytes.
    extended_pk_bytes: Vec<u8>,
}
53
54impl ZkLoginAuthenticator {
55 fn get_caching_params(&self) -> ZkLoginCachingParams {
60 let mut extended_pk_bytes = vec![self.user_signature.scheme().flag()];
61 extended_pk_bytes.extend(self.user_signature.public_key_bytes());
62 ZkLoginCachingParams {
63 inputs: self.inputs.clone(),
64 max_epoch: self.max_epoch,
65 extended_pk_bytes,
66 }
67 }
68
69 pub fn hash_inputs(&self) -> ZKLoginInputsDigest {
70 use fastcrypto::hash::HashFunction;
71 let mut hasher = DefaultHash::default();
72 hasher.update(bcs::to_bytes(&self.get_caching_params()).expect("serde should not fail"));
73 ZKLoginInputsDigest::new(hasher.finalize().into())
74 }
75
76 pub fn new(inputs: ZkLoginInputs, max_epoch: EpochId, user_signature: Signature) -> Self {
78 Self {
79 inputs,
80 max_epoch,
81 user_signature,
82 bytes: OnceCell::new(),
83 }
84 }
85
86 pub fn get_pk(&self) -> IotaResult<PublicKey> {
87 PublicKey::from_zklogin_inputs(&self.inputs)
88 }
89
90 pub fn get_iss(&self) -> &str {
91 self.inputs.get_iss()
92 }
93
94 pub fn get_max_epoch(&self) -> EpochId {
95 self.max_epoch
96 }
97
98 pub fn user_signature_mut_for_testing(&mut self) -> &mut Signature {
99 &mut self.user_signature
100 }
101 pub fn max_epoch_mut_for_testing(&mut self) -> &mut EpochId {
102 &mut self.max_epoch
103 }
104 pub fn zk_login_inputs_mut_for_testing(&mut self) -> &mut ZkLoginInputs {
105 &mut self.inputs
106 }
107}
108
109impl PartialEq for ZkLoginAuthenticator {
111 fn eq(&self, other: &Self) -> bool {
112 self.as_ref() == other.as_ref()
113 }
114}
115
// `PartialEq` compares serialized bytes, which is a total equivalence relation,
// so the `Eq` marker is sound.
impl Eq for ZkLoginAuthenticator {}
118
119impl Hash for ZkLoginAuthenticator {
121 fn hash<H: Hasher>(&self, state: &mut H) {
122 self.as_ref().hash(state);
123 }
124}
125
126impl AuthenticatorTrait for ZkLoginAuthenticator {
127 fn verify_user_authenticator_epoch(
128 &self,
129 epoch: EpochId,
130 max_epoch_upper_bound_delta: Option<u64>,
131 ) -> IotaResult {
132 if let Some(delta) = max_epoch_upper_bound_delta {
137 let max_epoch_upper_bound = epoch + delta;
138 if self.get_max_epoch() > max_epoch_upper_bound {
139 return Err(IotaError::InvalidSignature {
140 error: format!(
141 "ZKLogin max epoch too large {}, current epoch {}, max accepted: {}",
142 self.get_max_epoch(),
143 epoch,
144 max_epoch_upper_bound
145 ),
146 });
147 }
148 }
149 if epoch > self.get_max_epoch() {
151 return Err(IotaError::InvalidSignature {
152 error: format!(
153 "ZKLogin expired at epoch {}, current epoch {}",
154 self.get_max_epoch(),
155 epoch
156 ),
157 });
158 }
159 Ok(())
160 }
161
162 fn verify_claims<T>(
165 &self,
166 intent_msg: &IntentMessage<T>,
167 author: IotaAddress,
168 aux_verify_data: &VerifyParams,
169 zklogin_inputs_cache: Arc<VerifiedDigestCache<ZKLoginInputsDigest>>,
170 ) -> IotaResult
171 where
172 T: Serialize,
173 {
174 if author != IotaAddress::try_from_unpadded(&self.inputs)? {
176 return Err(IotaError::InvalidAddress);
177 }
178
179 self.user_signature.verify_secure(
182 intent_msg,
183 author,
184 SignatureScheme::ZkLoginAuthenticator,
185 )?;
186
187 if zklogin_inputs_cache.is_cached(&self.hash_inputs()) {
188 Ok(())
191 } else {
192 let mut extended_pk_bytes = vec![self.user_signature.scheme().flag()];
195 extended_pk_bytes.extend(self.user_signature.public_key_bytes());
196 let res = verify_zklogin_inputs_wrapper(
197 self.get_caching_params(),
198 &aux_verify_data.oidc_provider_jwks,
199 &aux_verify_data.zk_login_env,
200 )
201 .map_err(|e| IotaError::InvalidSignature {
202 error: e.to_string(),
203 });
204 match res {
205 Ok(_) => {
206 zklogin_inputs_cache.cache_digest(self.hash_inputs());
208 Ok(())
209 }
210 Err(e) => Err(e),
211 }
212 }
213 }
214}
215
216fn verify_zklogin_inputs_wrapper(
217 params: ZkLoginCachingParams,
218 all_jwk: &im::HashMap<JwkId, JWK>,
219 env: &ZkLoginEnv,
220) -> IotaResult<()> {
221 verify_zk_login(
222 ¶ms.inputs,
223 params.max_epoch,
224 ¶ms.extended_pk_bytes,
225 all_jwk,
226 env,
227 )
228 .map_err(|e| IotaError::InvalidSignature {
229 error: e.to_string(),
230 })
231}
232
233impl ToFromBytes for ZkLoginAuthenticator {
234 fn from_bytes(bytes: &[u8]) -> Result<Self, FastCryptoError> {
235 if bytes.first().ok_or(FastCryptoError::InvalidInput)?
237 != &SignatureScheme::ZkLoginAuthenticator.flag()
238 {
239 return Err(FastCryptoError::InvalidInput);
240 }
241 let mut zk_login: ZkLoginAuthenticator =
242 bcs::from_bytes(&bytes[1..]).map_err(|_| FastCryptoError::InvalidSignature)?;
243 zk_login.inputs.init()?;
244 Ok(zk_login)
245 }
246}
247
248impl AsRef<[u8]> for ZkLoginAuthenticator {
249 fn as_ref(&self) -> &[u8] {
250 self.bytes
251 .get_or_try_init::<_, eyre::Report>(|| {
252 let as_bytes = bcs::to_bytes(self).expect("BCS serialization should not fail");
253 let mut bytes = Vec::with_capacity(1 + as_bytes.len());
254 bytes.push(SignatureScheme::ZkLoginAuthenticator.flag());
255 bytes.extend_from_slice(as_bytes.as_slice());
256 Ok(bytes)
257 })
258 .expect("OnceCell invariant violated")
259 }
260}
261
/// A 32-byte address seed, stored big-endian and zero-padded on the left.
#[derive(Debug, Clone)]
pub struct AddressSeed([u8; 32]);

impl AddressSeed {
    /// Returns the seed with its leading zero bytes stripped.
    ///
    /// An all-zero seed yields the single final byte (`[0]`), never an empty slice.
    pub fn unpadded(&self) -> &[u8] {
        match self.0.iter().position(|&byte| byte != 0) {
            Some(start) => &self.0[start..],
            None => &self.0[self.0.len() - 1..],
        }
    }

    /// Returns all 32 seed bytes, including any leading zeros.
    pub fn padded(&self) -> &[u8] {
        self.0.as_slice()
    }
}
281
282impl std::fmt::Display for AddressSeed {
283 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284 let big_int = num_bigint::BigUint::from_bytes_be(&self.0);
285 let radix10 = big_int.to_str_radix(10);
286 f.write_str(&radix10)
287 }
288}
289
/// Errors returned when parsing an [`AddressSeed`] from a decimal string.
#[derive(thiserror::Error, Debug)]
pub enum AddressSeedParseError {
    /// The input is not a valid base-10 unsigned integer.
    // NOTE(review): despite the wording, `{0}` interpolates the underlying
    // `ParseBigIntError`, not the rejected input value.
    #[error("unable to parse radix10 encoded value `{0}`")]
    Parse(#[from] num_bigint::ParseBigIntError),
    /// The value's minimal big-endian encoding needs more than 32 bytes.
    #[error("larger than 32 bytes")]
    TooBig,
}
297
298impl std::str::FromStr for AddressSeed {
299 type Err = AddressSeedParseError;
300
301 fn from_str(s: &str) -> Result<Self, Self::Err> {
302 let big_int = <num_bigint::BigUint as num_traits::Num>::from_str_radix(s, 10)?;
303 let be_bytes = big_int.to_bytes_be();
304 let len = be_bytes.len();
305 let mut buf = [0; 32];
306
307 if len > 32 {
308 return Err(AddressSeedParseError::TooBig);
309 }
310
311 buf[32 - len..].copy_from_slice(&be_bytes);
312 Ok(Self(buf))
313 }
314}
315
316impl Serialize for AddressSeed {
318 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
319 where
320 S: serde::Serializer,
321 {
322 self.to_string().serialize(serializer)
323 }
324}
325
326impl<'de> Deserialize<'de> for AddressSeed {
327 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
328 where
329 D: serde::Deserializer<'de>,
330 {
331 let s = std::borrow::Cow::<'de, str>::deserialize(deserializer)?;
332 std::str::FromStr::from_str(&s).map_err(serde::de::Error::custom)
333 }
334}
335
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use num_bigint::BigUint;
    use proptest::prelude::*;

    use super::{AddressSeed, AddressSeedParseError};

    #[test]
    fn unpadded_slice() {
        // An all-zero seed must still yield one byte, not an empty slice.
        let seed = AddressSeed([0; 32]);
        let zero: [u8; 1] = [0];
        assert_eq!(seed.unpadded(), zero.as_slice());

        let mut seed = AddressSeed([1; 32]);
        seed.0[0] = 0;
        assert_eq!(seed.unpadded(), [1; 31].as_slice());
        assert_eq!(seed.padded().len(), 32);
    }

    proptest! {
        #[test]
        fn dont_crash_on_large_inputs(
            bytes in proptest::collection::vec(any::<u8>(), 33..1024)
        ) {
            let big_int = BigUint::from_bytes_be(&bytes);
            let radix10 = big_int.to_str_radix(10);

            // Leading zero bytes can make the value fit in 32 bytes, so the
            // expected outcome depends on the minimal big-endian length.
            let fits = big_int.to_bytes_be().len() <= 32;
            match AddressSeed::from_str(&radix10) {
                Ok(seed) => {
                    assert!(fits);
                    assert_eq!(radix10, seed.to_string());
                }
                Err(err) => {
                    assert!(!fits);
                    assert!(matches!(err, AddressSeedParseError::TooBig));
                }
            }
        }

        #[test]
        fn valid_address_seeds(
            bytes in proptest::collection::vec(any::<u8>(), 1..=32)
        ) {
            let big_int = BigUint::from_bytes_be(&bytes);
            let radix10 = big_int.to_str_radix(10);

            let seed = AddressSeed::from_str(&radix10).unwrap();
            assert_eq!(radix10, seed.to_string());
            // `unpadded` must equal the minimal big-endian encoding (`[0]` for zero).
            assert_eq!(seed.unpadded(), big_int.to_bytes_be().as_slice());
            assert_eq!(seed.padded().len(), 32);
        }
    }
}