use std::{
hash::{Hash, Hasher},
sync::Arc,
};
use fastcrypto::{error::FastCryptoError, traits::ToFromBytes};
use fastcrypto_zkp::bn254::{
zk_login::{JWK, JwkId, ZkLoginInputs},
zk_login_api::{ZkLoginEnv, verify_zk_login},
};
use once_cell::sync::OnceCell;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use shared_crypto::intent::IntentMessage;
use crate::{
base_types::{EpochId, IotaAddress},
crypto::{DefaultHash, IotaSignature, PublicKey, Signature, SignatureScheme},
digests::ZKLoginInputsDigest,
error::{IotaError, IotaResult},
signature::{AuthenticatorTrait, VerifyParams},
signature_verification::VerifiedDigestCache,
};
#[cfg(test)]
#[path = "unit_tests/zk_login_authenticator_test.rs"]
mod zk_login_authenticator_test;
/// A zkLogin authenticator: the zkLogin inputs, the last epoch at which the
/// authenticator is accepted, and a user signature over the intent message.
///
/// Field order matters: the byte encoding exposed through `AsRef<[u8]>` is the
/// BCS serialization of this struct, and BCS encodes fields in declaration order.
///
/// NOTE(review): deserializing through serde directly does not call
/// `inputs.init()`; only `ToFromBytes::from_bytes` does — confirm callers that
/// use serde do not depend on initialized inputs.
#[derive(Debug, Clone, JsonSchema, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ZkLoginAuthenticator {
    /// zkLogin inputs; used to derive the signer address and public key and to
    /// run the zkLogin proof verification.
    pub inputs: ZkLoginInputs,
    /// Last epoch (inclusive) at which this authenticator is accepted.
    max_epoch: EpochId,
    /// Signature over the intent message. Its scheme flag and public key are
    /// bound into the proof check via `extended_pk_bytes`.
    user_signature: Signature,
    /// Lazily computed `flag || bcs(self)` encoding returned by `as_ref()`.
    /// Skipped by serde so it never appears in the serialized form.
    #[serde(skip)]
    pub bytes: OnceCell<Vec<u8>>,
}
/// The data the zkLogin proof verification depends on.
///
/// Its BCS encoding is what `ZkLoginAuthenticator::hash_inputs` hashes to key the
/// verified-inputs cache, so field order and contents determine that digest.
#[derive(Serialize, Deserialize)]
struct ZkLoginCachingParams {
    inputs: ZkLoginInputs,
    max_epoch: EpochId,
    /// Signature-scheme flag byte followed by the user signature's public key bytes.
    extended_pk_bytes: Vec<u8>,
}
impl ZkLoginAuthenticator {
    /// Collects the data the zkLogin proof check depends on: the inputs, the max
    /// epoch, and the user signature's public key prefixed with its scheme flag.
    fn get_caching_params(&self) -> ZkLoginCachingParams {
        let mut extended_pk_bytes = vec![self.user_signature.scheme().flag()];
        extended_pk_bytes.extend(self.user_signature.public_key_bytes());
        ZkLoginCachingParams {
            inputs: self.inputs.clone(),
            max_epoch: self.max_epoch,
            extended_pk_bytes,
        }
    }

    /// Digest of the BCS-encoded caching params; used as the key into the
    /// verified-zkLogin-inputs cache.
    ///
    /// # Panics
    /// Panics if BCS serialization of the caching params fails.
    pub fn hash_inputs(&self) -> ZKLoginInputsDigest {
        use fastcrypto::hash::HashFunction;
        let mut hasher = DefaultHash::default();
        hasher.update(bcs::to_bytes(&self.get_caching_params()).expect("serde should not fail"));
        ZKLoginInputsDigest::new(hasher.finalize().into())
    }

    /// Creates an authenticator. The byte encoding is computed lazily on first
    /// call to `as_ref()`.
    pub fn new(inputs: ZkLoginInputs, max_epoch: EpochId, user_signature: Signature) -> Self {
        Self {
            inputs,
            max_epoch,
            user_signature,
            bytes: OnceCell::new(),
        }
    }

    /// Derives the public key from the zkLogin inputs.
    pub fn get_pk(&self) -> IotaResult<PublicKey> {
        PublicKey::from_zklogin_inputs(&self.inputs)
    }

    /// Issuer recorded in the zkLogin inputs.
    pub fn get_iss(&self) -> &str {
        self.inputs.get_iss()
    }

    /// Last epoch (inclusive) at which this authenticator is accepted.
    pub fn get_max_epoch(&self) -> EpochId {
        self.max_epoch
    }

    /// Mutable access to the user signature (tests only).
    ///
    /// Resets the cached byte encoding, since the caller may change the field.
    #[cfg(feature = "test-utils")]
    pub fn user_signature_mut_for_testing(&mut self) -> &mut Signature {
        // Without this, `as_ref()`, `Eq` and `Hash` would keep reporting the
        // encoding of the pre-mutation value.
        self.bytes = OnceCell::new();
        &mut self.user_signature
    }

    /// Mutable access to the max epoch (tests only).
    ///
    /// Resets the cached byte encoding, since the caller may change the field.
    #[cfg(feature = "test-utils")]
    pub fn max_epoch_mut_for_testing(&mut self) -> &mut EpochId {
        self.bytes = OnceCell::new();
        &mut self.max_epoch
    }

    /// Mutable access to the zkLogin inputs (tests only).
    ///
    /// Resets the cached byte encoding, since the caller may change the field.
    #[cfg(feature = "test-utils")]
    pub fn zk_login_inputs_mut_for_testing(&mut self) -> &mut ZkLoginInputs {
        self.bytes = OnceCell::new();
        &mut self.inputs
    }
}
/// Two authenticators are equal exactly when their flag-prefixed BCS encodings
/// (as returned by `as_ref()`) are byte-for-byte identical.
impl PartialEq for ZkLoginAuthenticator {
    fn eq(&self, other: &Self) -> bool {
        let lhs: &[u8] = self.as_ref();
        let rhs: &[u8] = other.as_ref();
        lhs == rhs
    }
}

impl Eq for ZkLoginAuthenticator {}

/// Hashes the same bytes that `PartialEq` compares, keeping `Hash` consistent with `Eq`.
impl Hash for ZkLoginAuthenticator {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let encoded: &[u8] = self.as_ref();
        encoded.hash(state);
    }
}
impl AuthenticatorTrait for ZkLoginAuthenticator {
    /// Checks that this authenticator may be used in `epoch`.
    ///
    /// # Errors
    /// Returns `IotaError::InvalidSignature` if a delta is given and `max_epoch`
    /// exceeds `epoch + delta`, or if `epoch` is past `max_epoch`.
    fn verify_user_authenticator_epoch(
        &self,
        epoch: EpochId,
        max_epoch_upper_bound_delta: Option<u64>,
    ) -> IotaResult {
        if let Some(delta) = max_epoch_upper_bound_delta {
            // Saturate so an extreme delta cannot overflow. Overflow would panic
            // in debug builds, and in release builds it would wrap to a tiny bound
            // that rejects valid authenticators.
            let max_epoch_upper_bound = epoch.saturating_add(delta);
            if self.get_max_epoch() > max_epoch_upper_bound {
                return Err(IotaError::InvalidSignature {
                    error: format!(
                        "ZKLogin max epoch too large {}, current epoch {}, max accepted: {}",
                        self.get_max_epoch(),
                        epoch,
                        max_epoch_upper_bound
                    ),
                });
            }
        }
        if epoch > self.get_max_epoch() {
            return Err(IotaError::InvalidSignature {
                error: format!(
                    "ZKLogin expired at epoch {}, current epoch {}",
                    self.get_max_epoch(),
                    epoch
                ),
            });
        }
        Ok(())
    }

    /// Verifies that `author` is the address derived from the zkLogin inputs,
    /// that the user signature is valid for `intent_msg`, and that the zkLogin
    /// inputs verify against the JWKs and environment in `aux_verify_data`.
    ///
    /// Successful input verifications are recorded in `zklogin_inputs_cache`.
    /// A cache hit skips the proof check, but never the user-signature check.
    ///
    /// # Errors
    /// Returns `IotaError::InvalidAddress` on an author mismatch. Otherwise it
    /// propagates errors from address derivation, signature verification, or
    /// the zkLogin proof check.
    fn verify_claims<T>(
        &self,
        intent_msg: &IntentMessage<T>,
        author: IotaAddress,
        aux_verify_data: &VerifyParams,
        zklogin_inputs_cache: Arc<VerifiedDigestCache<ZKLoginInputsDigest>>,
    ) -> IotaResult
    where
        T: Serialize,
    {
        if author != IotaAddress::try_from_unpadded(&self.inputs)? {
            return Err(IotaError::InvalidAddress);
        }

        // The user signature is checked on every call; the cache only covers
        // the zkLogin inputs.
        self.user_signature.verify_secure(
            intent_msg,
            author,
            SignatureScheme::ZkLoginAuthenticator,
        )?;

        // Hashing BCS-encodes and clones the inputs, so compute the digest once
        // and reuse it for both the lookup and the insert.
        let digest = self.hash_inputs();
        if zklogin_inputs_cache.is_cached(&digest) {
            return Ok(());
        }

        // Cache miss: run the full zkLogin proof verification. The wrapper
        // already maps failures to `InvalidSignature`, so propagate its error
        // unchanged instead of wrapping that message a second time.
        verify_zklogin_inputs_wrapper(
            self.get_caching_params(),
            &aux_verify_data.oidc_provider_jwks,
            &aux_verify_data.zk_login_env,
        )?;
        zklogin_inputs_cache.cache_digest(digest);
        Ok(())
    }
}
/// Runs the zkLogin proof verification for `params` against the JWK set
/// `all_jwk` and the verification environment `env`.
///
/// # Errors
/// Any failure from `verify_zk_login` is mapped to `IotaError::InvalidSignature`,
/// which carries the underlying error's message.
fn verify_zklogin_inputs_wrapper(
    params: ZkLoginCachingParams,
    all_jwk: &im::HashMap<JwkId, JWK>,
    env: &ZkLoginEnv,
) -> IotaResult<()> {
    verify_zk_login(
        &params.inputs,
        params.max_epoch,
        &params.extended_pk_bytes,
        all_jwk,
        env,
    )
    .map_err(|e| IotaError::InvalidSignature {
        error: e.to_string(),
    })
}
impl ToFromBytes for ZkLoginAuthenticator {
    /// Decodes `flag || bcs(ZkLoginAuthenticator)`, the layout produced by `as_ref()`.
    ///
    /// # Errors
    /// - `InvalidInput` if `bytes` is empty or its first byte is not the zkLogin
    ///   scheme flag.
    /// - `InvalidSignature` if the remaining bytes are not a valid BCS encoding.
    /// - Any error returned by `ZkLoginInputs::init`.
    fn from_bytes(bytes: &[u8]) -> Result<Self, FastCryptoError> {
        let (flag, payload) = bytes.split_first().ok_or(FastCryptoError::InvalidInput)?;
        if *flag != SignatureScheme::ZkLoginAuthenticator.flag() {
            return Err(FastCryptoError::InvalidInput);
        }
        let mut authenticator: ZkLoginAuthenticator =
            bcs::from_bytes(payload).map_err(|_| FastCryptoError::InvalidSignature)?;
        authenticator.inputs.init()?;
        Ok(authenticator)
    }
}
/// Byte encoding of the authenticator: the zkLogin signature-scheme flag byte
/// followed by the BCS serialization of `self`.
///
/// The encoding is computed on first use and cached in `self.bytes`.
///
/// # Panics
/// Panics if BCS serialization fails.
impl AsRef<[u8]> for ZkLoginAuthenticator {
    fn as_ref(&self) -> &[u8] {
        // The initializer cannot fail (serialization errors already panic), so a
        // plain `get_or_init` is enough. A fallible initializer plus `expect` on
        // the cell is not needed.
        self.bytes.get_or_init(|| {
            let payload = bcs::to_bytes(self).expect("BCS serialization should not fail");
            let mut bytes = Vec::with_capacity(1 + payload.len());
            bytes.push(SignatureScheme::ZkLoginAuthenticator.flag());
            bytes.extend_from_slice(&payload);
            bytes
        })
    }
}
/// A zkLogin address seed: an unsigned integer of up to 256 bits, stored as
/// 32 big-endian bytes.
#[derive(Debug, Clone)]
pub struct AddressSeed([u8; 32]);

impl AddressSeed {
    /// Big-endian bytes with leading zero bytes removed.
    ///
    /// A zero seed yields a single `0` byte rather than an empty slice.
    pub fn unpadded(&self) -> &[u8] {
        let first_significant = self
            .0
            .iter()
            .position(|&byte| byte != 0)
            .unwrap_or(self.0.len() - 1);
        &self.0[first_significant..]
    }

    /// All 32 big-endian bytes, including any leading zeros.
    pub fn padded(&self) -> &[u8] {
        &self.0[..]
    }
}
/// Formats the seed as its base-10 integer value, e.g. a seed whose last byte
/// is `0x0a` (all others zero) prints as `10`.
impl std::fmt::Display for AddressSeed {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&num_bigint::BigUint::from_bytes_be(&self.0).to_str_radix(10))
    }
}
/// Errors returned when parsing an `AddressSeed` from a decimal string.
#[derive(thiserror::Error, Debug)]
pub enum AddressSeedParseError {
    /// The input is not a valid base-10 unsigned integer.
    #[error("unable to parse radix10 encoded value `{0}`")]
    Parse(#[from] num_bigint::ParseBigIntError),
    /// The parsed value needs more than 32 big-endian bytes.
    #[error("larger than 32 bytes")]
    TooBig,
}
/// Parses a base-10 string into an `AddressSeed`, left-padding the big-endian
/// bytes with zeros to 32 bytes.
///
/// # Errors
/// - `Parse` if `s` is not a valid base-10 integer.
/// - `TooBig` if the value needs more than 32 bytes.
impl std::str::FromStr for AddressSeed {
    type Err = AddressSeedParseError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let value = <num_bigint::BigUint as num_traits::Num>::from_str_radix(s, 10)?;
        let digits = value.to_bytes_be();
        let offset = 32usize
            .checked_sub(digits.len())
            .ok_or(AddressSeedParseError::TooBig)?;
        let mut padded = [0u8; 32];
        padded[offset..].copy_from_slice(&digits);
        Ok(Self(padded))
    }
}
/// Serializes the seed as its base-10 string form.
impl Serialize for AddressSeed {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let decimal = self.to_string();
        serializer.serialize_str(&decimal)
    }
}
impl<'de> Deserialize<'de> for AddressSeed {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = std::borrow::Cow::<'de, str>::deserialize(deserializer)?;
std::str::FromStr::from_str(&s).map_err(serde::de::Error::custom)
}
}
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use num_bigint::BigUint;
    use proptest::prelude::*;

    use super::AddressSeed;

    #[test]
    fn unpadded_slice() {
        // An all-zero seed keeps a single zero byte.
        let seed = AddressSeed([0; 32]);
        let zero: [u8; 1] = [0];
        assert_eq!(seed.unpadded(), zero.as_slice());

        // Exactly the leading zero byte is stripped.
        let mut seed = AddressSeed([1; 32]);
        seed.0[0] = 0;
        assert_eq!(seed.unpadded(), [1; 31].as_slice());

        // A seed without leading zeros is returned whole.
        let seed = AddressSeed([1; 32]);
        assert_eq!(seed.unpadded(), seed.padded());
    }

    proptest! {
        // Oversized values must be rejected with an error, never a panic.
        // Leading zero bytes in `bytes` can still make the value fit in 32 bytes.
        #[test]
        fn dont_crash_on_large_inputs(
            bytes in proptest::collection::vec(any::<u8>(), 33..1024)
        ) {
            let big_int = BigUint::from_bytes_be(&bytes);
            let radix10 = big_int.to_str_radix(10);
            let fits = big_int.to_bytes_be().len() <= 32;
            prop_assert_eq!(AddressSeed::from_str(&radix10).is_ok(), fits);
        }

        // Any value of at most 32 bytes round-trips through its decimal form,
        // and `unpadded` preserves the numeric value.
        #[test]
        fn valid_address_seeds(
            bytes in proptest::collection::vec(any::<u8>(), 1..=32)
        ) {
            let big_int = BigUint::from_bytes_be(&bytes);
            let radix10 = big_int.to_str_radix(10);
            let seed = AddressSeed::from_str(&radix10).unwrap();
            prop_assert_eq!(seed.to_string(), radix10);
            prop_assert_eq!(seed.padded().len(), 32);
            prop_assert_eq!(BigUint::from_bytes_be(seed.unpadded()), big_int);
        }
    }
}