iota_sdk/
wallet_context.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{collections::BTreeSet, path::Path, sync::Arc};
6
7use anyhow::{anyhow, bail, ensure};
8use colored::Colorize;
9use futures::{StreamExt, TryStreamExt, future};
10use getset::{Getters, MutGetters};
11use iota_config::{Config, PersistedConfig};
12use iota_json_rpc_types::{
13    IotaObjectData, IotaObjectDataFilter, IotaObjectDataOptions, IotaObjectResponseQuery,
14    IotaTransactionBlockResponse, IotaTransactionBlockResponseOptions,
15};
16use iota_keys::keystore::{AccountKeystore, Keystore};
17use iota_sdk_types::{StructTag, crypto::Intent};
18use iota_types::{
19    base_types::{IotaAddress, ObjectID, ObjectRef},
20    crypto::IotaKeyPair,
21    gas_coin::GasCoin,
22    transaction::{Transaction, TransactionData, TransactionDataAPI},
23};
24use tokio::sync::RwLock;
25use tracing::warn;
26
27use crate::{
28    IotaClient, PagedFn,
29    iota_client_config::{IotaClientConfig, IotaEnv},
30};
31
32/// Wallet for managing accounts, objects, and interact with client APIs.
33// Mainly used in the CLI and tests.
34#[derive(Getters, MutGetters)]
35#[getset(get = "pub", get_mut = "pub")]
36pub struct WalletContext {
37    config: PersistedConfig<IotaClientConfig>,
38    request_timeout: Option<std::time::Duration>,
39    client: Arc<RwLock<Option<IotaClient>>>,
40    max_concurrent_requests: Option<u64>,
41    env_override: Option<String>,
42}
43
44impl WalletContext {
45    /// Create a new [`WalletContext`] with the config path to an existing
46    /// [`IotaClientConfig`] and optional parameters for the client.
47    pub fn new(config_path: &Path) -> Result<Self, anyhow::Error> {
48        let config: IotaClientConfig = PersistedConfig::read(config_path).map_err(|err| {
49            anyhow!(
50                "Cannot open wallet config file at {:?}. Err: {err}",
51                config_path
52            )
53        })?;
54
55        if let Some(active_address) = &config.active_address {
56            let addresses = match &config.keystore {
57                Keystore::File(file) => file.addresses(),
58                Keystore::InMem(mem) => mem.addresses(),
59            };
60            ensure!(
61                addresses.contains(active_address),
62                "error in '{}': active address not found in the keystore",
63                config_path.display()
64            );
65        }
66
67        if let Some(active_env) = &config.active_env {
68            ensure!(
69                config.get_env(active_env).is_some(),
70                "error in '{}': active environment not found in the envs list",
71                config_path.display()
72            );
73        }
74
75        let config = config.persisted(config_path);
76        let context = Self {
77            config,
78            request_timeout: None,
79            client: Default::default(),
80            max_concurrent_requests: None,
81            env_override: None,
82        };
83        Ok(context)
84    }
85
86    pub fn with_request_timeout(mut self, request_timeout: std::time::Duration) -> Self {
87        self.request_timeout = Some(request_timeout);
88        self
89    }
90
91    pub fn with_max_concurrent_requests(mut self, max_concurrent_requests: u64) -> Self {
92        self.max_concurrent_requests = Some(max_concurrent_requests);
93        self
94    }
95
96    pub fn with_env_override(mut self, env_override: String) -> Self {
97        self.env_override = Some(env_override);
98        self
99    }
100
101    /// Get all addresses from the keystore.
102    pub fn get_addresses(&self) -> Vec<IotaAddress> {
103        self.config.keystore.addresses()
104    }
105
106    pub fn get_env_override(&self) -> Option<String> {
107        self.env_override.clone()
108    }
109
110    /// Get the configured [`IotaClient`].
111    pub async fn get_client(&self) -> Result<IotaClient, anyhow::Error> {
112        let read = self.client.read().await;
113
114        Ok(if let Some(client) = read.as_ref() {
115            client.clone()
116        } else {
117            drop(read);
118            let client = self
119                .active_env()?
120                .create_rpc_client(self.request_timeout, self.max_concurrent_requests)
121                .await?;
122            if let Err(e) = client.check_api_version() {
123                warn!("{e}");
124                eprintln!("{}", format!("[warn] {e}").yellow().bold());
125            }
126            self.client.write().await.insert(client).clone()
127        })
128    }
129
130    /// Get the active [`IotaAddress`].
131    /// If not set, defaults to the first address in the keystore.
132    pub fn active_address(&self) -> Result<IotaAddress, anyhow::Error> {
133        if self.config.keystore.addresses().is_empty() {
134            bail!("No managed addresses. Create new address with the `new-address` command.");
135        }
136
137        Ok(if let Some(addr) = self.config.active_address() {
138            *addr
139        } else {
140            self.config.keystore().addresses()[0]
141        })
142    }
143
144    /// Get the active [`IotaEnv`].
145    /// If not set, defaults to the first environment in the config.
146    pub fn active_env(&self) -> Result<&IotaEnv, anyhow::Error> {
147        if self.config.envs.is_empty() {
148            bail!("No managed environments. Create new environment with the `new-env` command.");
149        }
150
151        if let Some(env_override) = &self.env_override {
152            self.config.get_env(env_override).ok_or_else(|| {
153                anyhow!(
154                    "Environment configuration not found for env [{}]",
155                    env_override
156                )
157            })
158        } else {
159            Ok(if self.config.active_env().is_some() {
160                self.config.get_active_env()?
161            } else {
162                &self.config.envs()[0]
163            })
164        }
165    }
166
167    /// Get the latest object reference given a object id.
168    pub async fn get_object_ref(&self, object_id: ObjectID) -> Result<ObjectRef, anyhow::Error> {
169        let client = self.get_client().await?;
170        Ok(client
171            .read_api()
172            .get_object_with_options(object_id, IotaObjectDataOptions::new())
173            .await?
174            .into_object()?
175            .object_ref())
176    }
177
178    /// Get all the gas objects (and conveniently, gas amounts) for the address.
179    pub async fn gas_objects(
180        &self,
181        address: IotaAddress,
182    ) -> Result<Vec<(u64, IotaObjectData)>, anyhow::Error> {
183        let client = self.get_client().await?;
184
185        let values_objects = PagedFn::stream(async |cursor| {
186            client
187                .read_api()
188                .get_owned_objects(
189                    address,
190                    IotaObjectResponseQuery::new(
191                        Some(IotaObjectDataFilter::StructType(StructTag::new_gas_coin())),
192                        Some(IotaObjectDataOptions::full_content()),
193                    ),
194                    cursor,
195                    None,
196                )
197                .await
198        })
199        .filter_map(|res| async {
200            match res {
201                Ok(res) => {
202                    if let Some(o) = res.data {
203                        match GasCoin::try_from(&o) {
204                            Ok(gas_coin) => Some(Ok((gas_coin.value(), o))),
205                            Err(e) => Some(Err(anyhow!("{e}"))),
206                        }
207                    } else {
208                        None
209                    }
210                }
211                Err(e) => Some(Err(anyhow!("{e}"))),
212            }
213        })
214        .try_collect::<Vec<_>>()
215        .await?;
216
217        Ok(values_objects)
218    }
219
220    /// Get the address that owns the object of the provided [`ObjectID`].
221    pub async fn get_object_owner(&self, id: &ObjectID) -> Result<IotaAddress, anyhow::Error> {
222        let client = self.get_client().await?;
223        let object = client
224            .read_api()
225            .get_object_with_options(*id, IotaObjectDataOptions::new().with_owner())
226            .await?
227            .into_object()?;
228        Ok(*object
229            .owner
230            .ok_or_else(|| anyhow!("Owner field is None"))?
231            .address_or_object()
232            .ok_or_else(|| anyhow::anyhow!("not an address or object owner"))?)
233    }
234
235    /// Get the address that owns the object, if an [`ObjectID`] is provided.
236    pub async fn try_get_object_owner(
237        &self,
238        id: &Option<ObjectID>,
239    ) -> Result<Option<IotaAddress>, anyhow::Error> {
240        if let Some(id) = id {
241            Ok(Some(self.get_object_owner(id).await?))
242        } else {
243            Ok(None)
244        }
245    }
246
247    /// Infer the sender of a transaction based on the gas objects provided. If
248    /// no gas objects are provided, assume the active address is the
249    /// sender.
250    pub async fn infer_sender(&mut self, gas: &[ObjectID]) -> Result<IotaAddress, anyhow::Error> {
251        if gas.is_empty() {
252            return self.active_address();
253        }
254
255        // Find the owners of all supplied object IDs
256        let owners = future::try_join_all(gas.iter().map(|id| self.get_object_owner(id))).await?;
257
258        // SAFETY `gas` is non-empty.
259        let owner = owners[0];
260
261        ensure!(
262            owners.iter().all(|o| o == &owner),
263            "Cannot infer sender, not all gas objects have the same owner."
264        );
265
266        Ok(owner)
267    }
268
269    /// Find a gas object which fits the budget.
270    pub async fn gas_for_owner_budget(
271        &self,
272        address: IotaAddress,
273        budget: u64,
274        forbidden_gas_objects: BTreeSet<ObjectID>,
275    ) -> Result<(u64, IotaObjectData), anyhow::Error> {
276        for o in self.gas_objects(address).await? {
277            if o.0 >= budget && !forbidden_gas_objects.contains(&o.1.object_id) {
278                return Ok((o.0, o.1));
279            }
280        }
281        bail!(
282            "No non-argument gas objects found for this address with value >= budget {budget}. Run iota client gas to check for gas objects."
283        )
284    }
285
286    /// Get the [`ObjectRef`] for gas objects owned by the provided address.
287    /// Maximum is RPC_QUERY_MAX_RESULT_LIMIT (50 by default).
288    pub async fn get_all_gas_objects_owned_by_address(
289        &self,
290        address: IotaAddress,
291    ) -> anyhow::Result<Vec<ObjectRef>> {
292        self.get_gas_objects_owned_by_address(address, None).await
293    }
294
295    /// Get a limited amount of [`ObjectRef`]s for gas objects owned by the
296    /// provided address. Max limit is RPC_QUERY_MAX_RESULT_LIMIT (50 by
297    /// default).
298    pub async fn get_gas_objects_owned_by_address(
299        &self,
300        address: IotaAddress,
301        limit: impl Into<Option<usize>>,
302    ) -> anyhow::Result<Vec<ObjectRef>> {
303        let client = self.get_client().await?;
304        let results: Vec<_> = client
305            .read_api()
306            .get_owned_objects(
307                address,
308                IotaObjectResponseQuery::new(
309                    Some(IotaObjectDataFilter::StructType(StructTag::new_gas_coin())),
310                    Some(IotaObjectDataOptions::full_content()),
311                ),
312                None,
313                limit,
314            )
315            .await?
316            .data
317            .into_iter()
318            .filter_map(|r| r.data.map(|o| o.object_ref()))
319            .collect();
320        Ok(results)
321    }
322
323    /// Given an address, return one gas object owned by this address.
324    /// The actual implementation just returns the first one returned by the
325    /// read api.
326    pub async fn get_one_gas_object_owned_by_address(
327        &self,
328        address: IotaAddress,
329    ) -> anyhow::Result<Option<ObjectRef>> {
330        Ok(self
331            .get_gas_objects_owned_by_address(address, 1)
332            .await?
333            .pop())
334    }
335
336    /// Return one address and all gas objects owned by that address.
337    pub async fn get_one_account(&self) -> anyhow::Result<(IotaAddress, Vec<ObjectRef>)> {
338        let address = self.get_addresses().pop().unwrap();
339        Ok((
340            address,
341            self.get_all_gas_objects_owned_by_address(address).await?,
342        ))
343    }
344
345    /// Return a gas object owned by an arbitrary address managed by the wallet.
346    pub async fn get_one_gas_object(&self) -> anyhow::Result<Option<(IotaAddress, ObjectRef)>> {
347        for address in self.get_addresses() {
348            if let Some(gas_object) = self.get_one_gas_object_owned_by_address(address).await? {
349                return Ok(Some((address, gas_object)));
350            }
351        }
352        Ok(None)
353    }
354
355    /// Return all the account addresses managed by the wallet and their owned
356    /// gas objects.
357    pub async fn get_all_accounts_and_gas_objects(
358        &self,
359    ) -> anyhow::Result<Vec<(IotaAddress, Vec<ObjectRef>)>> {
360        let mut result = vec![];
361        for address in self.get_addresses() {
362            let objects = self
363                .gas_objects(address)
364                .await?
365                .into_iter()
366                .map(|(_, o)| o.object_ref())
367                .collect();
368            result.push((address, objects));
369        }
370        Ok(result)
371    }
372
373    pub async fn get_reference_gas_price(&self) -> Result<u64, anyhow::Error> {
374        let client = self.get_client().await?;
375        let gas_price = client.governance_api().get_reference_gas_price().await?;
376        Ok(gas_price)
377    }
378
379    /// Add an account.
380    pub fn add_account(&mut self, alias: impl Into<Option<String>>, keypair: IotaKeyPair) {
381        self.config.keystore.add_key(alias.into(), keypair).unwrap();
382    }
383
384    /// Sign a transaction with a key currently managed by the WalletContext.
385    pub fn sign_transaction(&self, data: &TransactionData) -> Transaction {
386        let sig = self
387            .config
388            .keystore
389            .sign_secure(&data.sender(), data, Intent::iota_transaction())
390            .unwrap();
391        // TODO: To support sponsored transaction, we should also look at the gas owner.
392        Transaction::from_data(data.clone(), vec![sig])
393    }
394
395    /// Execute a transaction and wait for it to be locally executed on the
396    /// fullnode. Also expects the effects status to be
397    /// ExecutionStatus::Success.
398    pub async fn execute_transaction_must_succeed(
399        &self,
400        tx: Transaction,
401    ) -> IotaTransactionBlockResponse {
402        tracing::debug!("Executing transaction: {:?}", tx);
403        let response = self.execute_transaction_may_fail(tx).await.unwrap();
404        assert!(
405            response.status_ok().unwrap(),
406            "Transaction failed: {response:?}"
407        );
408        response
409    }
410
411    /// Execute a transaction and wait for it to be locally executed on the
412    /// fullnode. The transaction execution is not guaranteed to succeed and
413    /// may fail. This is usually only needed in non-test environment or the
414    /// caller is explicitly testing some failure behavior.
415    pub async fn execute_transaction_may_fail(
416        &self,
417        tx: Transaction,
418    ) -> anyhow::Result<IotaTransactionBlockResponse> {
419        let client = self.get_client().await?;
420        Ok(client
421            .quorum_driver_api()
422            .execute_transaction_block(
423                tx,
424                IotaTransactionBlockResponseOptions::new()
425                    .with_effects()
426                    .with_input()
427                    .with_events()
428                    .with_object_changes()
429                    .with_balance_changes(),
430                iota_types::quorum_driver_types::ExecuteTransactionRequestType::WaitForLocalExecution,
431            )
432            .await?)
433    }
434}