1use std::{
7 hash::{Hash, Hasher},
8 sync::Arc,
9};
10
11use fastcrypto::{error::FastCryptoError, traits::ToFromBytes};
12use fastcrypto_zkp::bn254::{
13 zk_login::{JWK, JwkId, ZkLoginInputs},
14 zk_login_api::{ZkLoginEnv, verify_zk_login},
15};
16use once_cell::sync::OnceCell;
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19use shared_crypto::intent::IntentMessage;
20
21use crate::{
22 base_types::{EpochId, IotaAddress},
23 crypto::{DefaultHash, IotaSignature, PublicKey, Signature, SignatureScheme},
24 digests::ZKLoginInputsDigest,
25 error::{IotaError, IotaResult},
26 signature::{AuthenticatorTrait, VerifyParams},
27 signature_verification::VerifiedDigestCache,
28};
// Additional unit tests for this module are compiled from a separate file under
// `unit_tests/` (test builds only).
#[cfg(test)]
#[path = "unit_tests/zk_login_authenticator_test.rs"]
mod zk_login_authenticator_test;
32
/// An authenticator built from zkLogin proof inputs, a max epoch and a user signature.
///
/// Equality and hashing are defined on the memoized `flag || BCS(self)` encoding produced by
/// its `AsRef<[u8]>` implementation.
#[derive(Debug, Clone, JsonSchema, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ZkLoginAuthenticator {
    /// The zkLogin proof inputs; the author address is checked against
    /// `IotaAddress::try_from_unpadded(&inputs)` in `verify_claims`.
    pub inputs: ZkLoginInputs,
    /// Last epoch (inclusive) in which this authenticator is accepted.
    max_epoch: EpochId,
    /// User signature over the intent message, checked with `verify_secure`.
    user_signature: Signature,
    /// Lazily computed `flag || BCS(self)` encoding, filled by `as_ref`; never serialized.
    #[serde(skip)]
    pub bytes: OnceCell<Vec<u8>>,
}
43
/// The authenticator data handed to zkLogin proof verification.
///
/// Its BCS encoding is hashed by `ZkLoginAuthenticator::hash_inputs` to key the cache of
/// already-verified inputs. BCS encodes struct fields in declaration order, so reordering
/// these fields changes every cache digest.
#[derive(Serialize, Deserialize)]
struct ZkLoginCachingParams {
    /// The zkLogin proof inputs.
    inputs: ZkLoginInputs,
    /// Last epoch (inclusive) in which the authenticator is accepted.
    max_epoch: EpochId,
    /// The user signature's scheme flag byte followed by its public key bytes.
    extended_pk_bytes: Vec<u8>,
}
53
54impl ZkLoginAuthenticator {
55 fn get_caching_params(&self) -> ZkLoginCachingParams {
60 let mut extended_pk_bytes = vec![self.user_signature.scheme().flag()];
61 extended_pk_bytes.extend(self.user_signature.public_key_bytes());
62 ZkLoginCachingParams {
63 inputs: self.inputs.clone(),
64 max_epoch: self.max_epoch,
65 extended_pk_bytes,
66 }
67 }
68
69 pub fn hash_inputs(&self) -> ZKLoginInputsDigest {
70 use fastcrypto::hash::HashFunction;
71 let mut hasher = DefaultHash::default();
72 hasher.update(bcs::to_bytes(&self.get_caching_params()).expect("serde should not fail"));
73 ZKLoginInputsDigest::new(hasher.finalize().into())
74 }
75
76 pub fn new(inputs: ZkLoginInputs, max_epoch: EpochId, user_signature: Signature) -> Self {
78 Self {
79 inputs,
80 max_epoch,
81 user_signature,
82 bytes: OnceCell::new(),
83 }
84 }
85
86 pub fn get_pk(&self) -> IotaResult<PublicKey> {
87 PublicKey::from_zklogin_inputs(&self.inputs)
88 }
89
90 pub fn get_iss(&self) -> &str {
91 self.inputs.get_iss()
92 }
93
94 pub fn get_max_epoch(&self) -> EpochId {
95 self.max_epoch
96 }
97
98 pub fn user_signature_mut_for_testing(&mut self) -> &mut Signature {
99 &mut self.user_signature
100 }
101
102 pub fn max_epoch_mut_for_testing(&mut self) -> &mut EpochId {
103 &mut self.max_epoch
104 }
105
106 pub fn zk_login_inputs_mut_for_testing(&mut self) -> &mut ZkLoginInputs {
107 &mut self.inputs
108 }
109}
110
111impl PartialEq for ZkLoginAuthenticator {
113 fn eq(&self, other: &Self) -> bool {
114 self.as_ref() == other.as_ref()
115 }
116}
117
// Equality compares the byte encodings returned by `as_ref` (see the `PartialEq` impl).
// Byte-slice comparison is a total equivalence, so the `Eq` marker holds.
impl Eq for ZkLoginAuthenticator {}
120
121impl Hash for ZkLoginAuthenticator {
123 fn hash<H: Hasher>(&self, state: &mut H) {
124 self.as_ref().hash(state);
125 }
126}
127
128impl AuthenticatorTrait for ZkLoginAuthenticator {
129 fn verify_user_authenticator_epoch(
130 &self,
131 epoch: EpochId,
132 max_epoch_upper_bound_delta: Option<u64>,
133 ) -> IotaResult {
134 if let Some(delta) = max_epoch_upper_bound_delta {
139 let max_epoch_upper_bound = epoch + delta;
140 if self.get_max_epoch() > max_epoch_upper_bound {
141 return Err(IotaError::InvalidSignature {
142 error: format!(
143 "ZKLogin max epoch too large {}, current epoch {}, max accepted: {}",
144 self.get_max_epoch(),
145 epoch,
146 max_epoch_upper_bound
147 ),
148 });
149 }
150 }
151 if epoch > self.get_max_epoch() {
153 return Err(IotaError::InvalidSignature {
154 error: format!(
155 "ZKLogin expired at epoch {}, current epoch {}",
156 self.get_max_epoch(),
157 epoch
158 ),
159 });
160 }
161 Ok(())
162 }
163
164 fn verify_claims<T>(
167 &self,
168 intent_msg: &IntentMessage<T>,
169 author: IotaAddress,
170 aux_verify_data: &VerifyParams,
171 zklogin_inputs_cache: Arc<VerifiedDigestCache<ZKLoginInputsDigest>>,
172 ) -> IotaResult
173 where
174 T: Serialize,
175 {
176 if author != IotaAddress::try_from_unpadded(&self.inputs)? {
178 return Err(IotaError::InvalidAddress);
179 }
180
181 self.user_signature.verify_secure(
184 intent_msg,
185 author,
186 SignatureScheme::ZkLoginAuthenticator,
187 )?;
188
189 if zklogin_inputs_cache.is_cached(&self.hash_inputs()) {
190 Ok(())
193 } else {
194 let mut extended_pk_bytes = vec![self.user_signature.scheme().flag()];
197 extended_pk_bytes.extend(self.user_signature.public_key_bytes());
198 let res = verify_zklogin_inputs_wrapper(
199 self.get_caching_params(),
200 &aux_verify_data.oidc_provider_jwks,
201 &aux_verify_data.zk_login_env,
202 )
203 .map_err(|e| IotaError::InvalidSignature {
204 error: e.to_string(),
205 });
206 match res {
207 Ok(_) => {
208 zklogin_inputs_cache.cache_digest(self.hash_inputs());
210 Ok(())
211 }
212 Err(e) => Err(e),
213 }
214 }
215 }
216}
217
218fn verify_zklogin_inputs_wrapper(
219 params: ZkLoginCachingParams,
220 all_jwk: &im::HashMap<JwkId, JWK>,
221 env: &ZkLoginEnv,
222) -> IotaResult<()> {
223 verify_zk_login(
224 ¶ms.inputs,
225 params.max_epoch,
226 ¶ms.extended_pk_bytes,
227 all_jwk,
228 env,
229 )
230 .map_err(|e| IotaError::InvalidSignature {
231 error: e.to_string(),
232 })
233}
234
235impl ToFromBytes for ZkLoginAuthenticator {
236 fn from_bytes(bytes: &[u8]) -> Result<Self, FastCryptoError> {
237 if bytes.first().ok_or(FastCryptoError::InvalidInput)?
239 != &SignatureScheme::ZkLoginAuthenticator.flag()
240 {
241 return Err(FastCryptoError::InvalidInput);
242 }
243 let mut zk_login: ZkLoginAuthenticator =
244 bcs::from_bytes(&bytes[1..]).map_err(|_| FastCryptoError::InvalidSignature)?;
245 zk_login.inputs.init()?;
246 Ok(zk_login)
247 }
248}
249
250impl AsRef<[u8]> for ZkLoginAuthenticator {
251 fn as_ref(&self) -> &[u8] {
252 self.bytes
253 .get_or_try_init::<_, eyre::Report>(|| {
254 let as_bytes = bcs::to_bytes(self).expect("BCS serialization should not fail");
255 let mut bytes = Vec::with_capacity(1 + as_bytes.len());
256 bytes.push(SignatureScheme::ZkLoginAuthenticator.flag());
257 bytes.extend_from_slice(as_bytes.as_slice());
258 Ok(bytes)
259 })
260 .expect("OnceCell invariant violated")
261 }
262}
263
/// A 32-byte seed stored as a big-endian unsigned integer.
///
/// Its textual form (`Display` / `FromStr`, and therefore its serde representation) is the
/// base-10 integer.
#[derive(Debug, Clone)]
pub struct AddressSeed([u8; 32]);
266
267impl AddressSeed {
268 pub fn unpadded(&self) -> &[u8] {
269 let mut buf = self.0.as_slice();
270
271 while !buf.is_empty() && buf[0] == 0 {
272 buf = &buf[1..];
273 }
274
275 if buf.is_empty() { &self.0[31..] } else { buf }
277 }
278
279 pub fn padded(&self) -> &[u8] {
280 &self.0
281 }
282}
283
284impl std::fmt::Display for AddressSeed {
285 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286 let big_int = num_bigint::BigUint::from_bytes_be(&self.0);
287 let radix10 = big_int.to_str_radix(10);
288 f.write_str(&radix10)
289 }
290}
291
/// Errors returned when parsing an [`AddressSeed`] from its base-10 string form.
#[derive(thiserror::Error, Debug)]
pub enum AddressSeedParseError {
    /// The input could not be parsed as a base-10 unsigned integer.
    #[error("unable to parse radix10 encoded value `{0}`")]
    Parse(#[from] num_bigint::ParseBigIntError),
    /// The parsed value needs more than 32 bytes to represent.
    #[error("larger than 32 bytes")]
    TooBig,
}
299
300impl std::str::FromStr for AddressSeed {
301 type Err = AddressSeedParseError;
302
303 fn from_str(s: &str) -> Result<Self, Self::Err> {
304 let big_int = <num_bigint::BigUint as num_traits::Num>::from_str_radix(s, 10)?;
305 let be_bytes = big_int.to_bytes_be();
306 let len = be_bytes.len();
307 let mut buf = [0; 32];
308
309 if len > 32 {
310 return Err(AddressSeedParseError::TooBig);
311 }
312
313 buf[32 - len..].copy_from_slice(&be_bytes);
314 Ok(Self(buf))
315 }
316}
317
318impl Serialize for AddressSeed {
320 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
321 where
322 S: serde::Serializer,
323 {
324 self.to_string().serialize(serializer)
325 }
326}
327
328impl<'de> Deserialize<'de> for AddressSeed {
329 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
330 where
331 D: serde::Deserializer<'de>,
332 {
333 let s = std::borrow::Cow::<'de, str>::deserialize(deserializer)?;
334 std::str::FromStr::from_str(&s).map_err(serde::de::Error::custom)
335 }
336}
337
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use num_bigint::BigUint;
    use proptest::prelude::*;

    use super::AddressSeed;

    #[test]
    fn unpadded_slice() {
        // An all-zero seed still unpads to a single zero byte.
        let all_zero = AddressSeed([0; 32]);
        let expected_zero: [u8; 1] = [0];
        assert_eq!(all_zero.unpadded(), &expected_zero[..]);

        // A single leading zero byte is stripped, leaving the remaining 31 bytes.
        let mut ones = AddressSeed([1; 32]);
        ones.0[0] = 0;
        let expected_ones: [u8; 31] = [1; 31];
        assert_eq!(ones.unpadded(), &expected_ones[..]);
    }

    proptest! {
        // Parsing values wider than 32 bytes must return (not panic); the result is ignored.
        #[test]
        fn dont_crash_on_large_inputs(
            bytes in proptest::collection::vec(any::<u8>(), 33..1024)
        ) {
            let decimal = BigUint::from_bytes_be(&bytes).to_str_radix(10);
            let _ = AddressSeed::from_str(&decimal);
        }

        // Any value of at most 32 bytes round-trips through `FromStr` / `Display`.
        #[test]
        fn valid_address_seeds(
            bytes in proptest::collection::vec(any::<u8>(), 1..=32)
        ) {
            let decimal = BigUint::from_bytes_be(&bytes).to_str_radix(10);
            let parsed = AddressSeed::from_str(&decimal).unwrap();
            assert_eq!(decimal, parsed.to_string());
            let _ = parsed.unpadded();
        }
    }
}