iota_sdk/
wallet_context.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{collections::BTreeSet, path::Path, sync::Arc};
6
7use anyhow::{anyhow, bail, ensure};
8use colored::Colorize;
9use futures::{StreamExt, TryStreamExt, future};
10use getset::{Getters, MutGetters};
11use iota_config::{Config, PersistedConfig};
12use iota_json_rpc_types::{
13    IotaObjectData, IotaObjectDataFilter, IotaObjectDataOptions, IotaObjectResponseQuery,
14    IotaTransactionBlockResponse, IotaTransactionBlockResponseOptions,
15};
16use iota_keys::keystore::{AccountKeystore, Keystore};
17use iota_sdk_types::crypto::Intent;
18use iota_types::{
19    base_types::{IotaAddress, ObjectID, ObjectRef},
20    crypto::IotaKeyPair,
21    gas_coin::GasCoin,
22    transaction::{Transaction, TransactionData, TransactionDataAPI},
23};
24use tokio::sync::RwLock;
25use tracing::warn;
26
27use crate::{
28    IotaClient, PagedFn,
29    iota_client_config::{IotaClientConfig, IotaEnv},
30};
31
32/// Wallet for managing accounts, objects, and interact with client APIs.
33// Mainly used in the CLI and tests.
34#[derive(Getters, MutGetters)]
35#[getset(get = "pub", get_mut = "pub")]
36pub struct WalletContext {
37    config: PersistedConfig<IotaClientConfig>,
38    request_timeout: Option<std::time::Duration>,
39    client: Arc<RwLock<Option<IotaClient>>>,
40    max_concurrent_requests: Option<u64>,
41    env_override: Option<String>,
42}
43
44impl WalletContext {
45    /// Create a new [`WalletContext`] with the config path to an existing
46    /// [`IotaClientConfig`] and optional parameters for the client.
47    pub fn new(config_path: &Path) -> Result<Self, anyhow::Error> {
48        let config: IotaClientConfig = PersistedConfig::read(config_path).map_err(|err| {
49            anyhow!(
50                "Cannot open wallet config file at {:?}. Err: {err}",
51                config_path
52            )
53        })?;
54
55        if let Some(active_address) = &config.active_address {
56            let addresses = match &config.keystore {
57                Keystore::File(file) => file.addresses(),
58                Keystore::InMem(mem) => mem.addresses(),
59            };
60            ensure!(
61                addresses.contains(active_address),
62                "error in '{}': active address not found in the keystore",
63                config_path.display()
64            );
65        }
66
67        if let Some(active_env) = &config.active_env {
68            ensure!(
69                config.get_env(active_env).is_some(),
70                "error in '{}': active environment not found in the envs list",
71                config_path.display()
72            );
73        }
74
75        let config = config.persisted(config_path);
76        let context = Self {
77            config,
78            request_timeout: None,
79            client: Default::default(),
80            max_concurrent_requests: None,
81            env_override: None,
82        };
83        Ok(context)
84    }
85
86    pub fn with_request_timeout(mut self, request_timeout: std::time::Duration) -> Self {
87        self.request_timeout = Some(request_timeout);
88        self
89    }
90
91    pub fn with_max_concurrent_requests(mut self, max_concurrent_requests: u64) -> Self {
92        self.max_concurrent_requests = Some(max_concurrent_requests);
93        self
94    }
95
96    pub fn with_env_override(mut self, env_override: String) -> Self {
97        self.env_override = Some(env_override);
98        self
99    }
100
101    /// Get all addresses from the keystore.
102    pub fn get_addresses(&self) -> Vec<IotaAddress> {
103        self.config.keystore.addresses()
104    }
105
106    pub fn get_env_override(&self) -> Option<String> {
107        self.env_override.clone()
108    }
109
110    /// Get the configured [`IotaClient`].
111    pub async fn get_client(&self) -> Result<IotaClient, anyhow::Error> {
112        let read = self.client.read().await;
113
114        Ok(if let Some(client) = read.as_ref() {
115            client.clone()
116        } else {
117            drop(read);
118            let client = self
119                .active_env()?
120                .create_rpc_client(self.request_timeout, self.max_concurrent_requests)
121                .await?;
122            if let Err(e) = client.check_api_version() {
123                warn!("{e}");
124                eprintln!("{}", format!("[warn] {e}").yellow().bold());
125            }
126            self.client.write().await.insert(client).clone()
127        })
128    }
129
130    /// Get the active [`IotaAddress`].
131    /// If not set, defaults to the first address in the keystore.
132    pub fn active_address(&self) -> Result<IotaAddress, anyhow::Error> {
133        if self.config.keystore.addresses().is_empty() {
134            bail!("No managed addresses. Create new address with the `new-address` command.");
135        }
136
137        Ok(if let Some(addr) = self.config.active_address() {
138            *addr
139        } else {
140            self.config.keystore().addresses()[0]
141        })
142    }
143
144    /// Get the active [`IotaEnv`].
145    /// If not set, defaults to the first environment in the config.
146    pub fn active_env(&self) -> Result<&IotaEnv, anyhow::Error> {
147        if self.config.envs.is_empty() {
148            bail!("No managed environments. Create new environment with the `new-env` command.");
149        }
150
151        if let Some(env_override) = &self.env_override {
152            self.config.get_env(env_override).ok_or_else(|| {
153                anyhow!(
154                    "Environment configuration not found for env [{}]",
155                    env_override
156                )
157            })
158        } else {
159            Ok(if self.config.active_env().is_some() {
160                self.config.get_active_env()?
161            } else {
162                &self.config.envs()[0]
163            })
164        }
165    }
166
167    /// Get the latest object reference given a object id.
168    pub async fn get_object_ref(&self, object_id: ObjectID) -> Result<ObjectRef, anyhow::Error> {
169        let client = self.get_client().await?;
170        Ok(client
171            .read_api()
172            .get_object_with_options(object_id, IotaObjectDataOptions::new())
173            .await?
174            .into_object()?
175            .object_ref())
176    }
177
178    /// Get all the gas objects (and conveniently, gas amounts) for the address.
179    pub async fn gas_objects(
180        &self,
181        address: IotaAddress,
182    ) -> Result<Vec<(u64, IotaObjectData)>, anyhow::Error> {
183        let client = self.get_client().await?;
184
185        let values_objects = PagedFn::stream(async |cursor| {
186            client
187                .read_api()
188                .get_owned_objects(
189                    address,
190                    IotaObjectResponseQuery::new(
191                        Some(IotaObjectDataFilter::StructType(GasCoin::type_())),
192                        Some(IotaObjectDataOptions::full_content()),
193                    ),
194                    cursor,
195                    None,
196                )
197                .await
198        })
199        .filter_map(|res| async {
200            match res {
201                Ok(res) => {
202                    if let Some(o) = res.data {
203                        match GasCoin::try_from(&o) {
204                            Ok(gas_coin) => Some(Ok((gas_coin.value(), o.clone()))),
205                            Err(e) => Some(Err(anyhow!("{e}"))),
206                        }
207                    } else {
208                        None
209                    }
210                }
211                Err(e) => Some(Err(anyhow!("{e}"))),
212            }
213        })
214        .try_collect::<Vec<_>>()
215        .await?;
216
217        Ok(values_objects)
218    }
219
220    /// Get the address that owns the object of the provided [`ObjectID`].
221    pub async fn get_object_owner(&self, id: &ObjectID) -> Result<IotaAddress, anyhow::Error> {
222        let client = self.get_client().await?;
223        let object = client
224            .read_api()
225            .get_object_with_options(*id, IotaObjectDataOptions::new().with_owner())
226            .await?
227            .into_object()?;
228        Ok(object
229            .owner
230            .ok_or_else(|| anyhow!("Owner field is None"))?
231            .get_owner_address()?)
232    }
233
234    /// Get the address that owns the object, if an [`ObjectID`] is provided.
235    pub async fn try_get_object_owner(
236        &self,
237        id: &Option<ObjectID>,
238    ) -> Result<Option<IotaAddress>, anyhow::Error> {
239        if let Some(id) = id {
240            Ok(Some(self.get_object_owner(id).await?))
241        } else {
242            Ok(None)
243        }
244    }
245
246    /// Infer the sender of a transaction based on the gas objects provided. If
247    /// no gas objects are provided, assume the active address is the
248    /// sender.
249    pub async fn infer_sender(&mut self, gas: &[ObjectID]) -> Result<IotaAddress, anyhow::Error> {
250        if gas.is_empty() {
251            return self.active_address();
252        }
253
254        // Find the owners of all supplied object IDs
255        let owners = future::try_join_all(gas.iter().map(|id| self.get_object_owner(id))).await?;
256
257        // SAFETY `gas` is non-empty.
258        let owner = owners[0];
259
260        ensure!(
261            owners.iter().all(|o| o == &owner),
262            "Cannot infer sender, not all gas objects have the same owner."
263        );
264
265        Ok(owner)
266    }
267
268    /// Find a gas object which fits the budget.
269    pub async fn gas_for_owner_budget(
270        &self,
271        address: IotaAddress,
272        budget: u64,
273        forbidden_gas_objects: BTreeSet<ObjectID>,
274    ) -> Result<(u64, IotaObjectData), anyhow::Error> {
275        for o in self.gas_objects(address).await? {
276            if o.0 >= budget && !forbidden_gas_objects.contains(&o.1.object_id) {
277                return Ok((o.0, o.1));
278            }
279        }
280        bail!(
281            "No non-argument gas objects found for this address with value >= budget {budget}. Run iota client gas to check for gas objects."
282        )
283    }
284
285    /// Get the [`ObjectRef`] for gas objects owned by the provided address.
286    /// Maximum is RPC_QUERY_MAX_RESULT_LIMIT (50 by default).
287    pub async fn get_all_gas_objects_owned_by_address(
288        &self,
289        address: IotaAddress,
290    ) -> anyhow::Result<Vec<ObjectRef>> {
291        self.get_gas_objects_owned_by_address(address, None).await
292    }
293
294    /// Get a limited amount of [`ObjectRef`]s for gas objects owned by the
295    /// provided address. Max limit is RPC_QUERY_MAX_RESULT_LIMIT (50 by
296    /// default).
297    pub async fn get_gas_objects_owned_by_address(
298        &self,
299        address: IotaAddress,
300        limit: impl Into<Option<usize>>,
301    ) -> anyhow::Result<Vec<ObjectRef>> {
302        let client = self.get_client().await?;
303        let results: Vec<_> = client
304            .read_api()
305            .get_owned_objects(
306                address,
307                IotaObjectResponseQuery::new(
308                    Some(IotaObjectDataFilter::StructType(GasCoin::type_())),
309                    Some(IotaObjectDataOptions::full_content()),
310                ),
311                None,
312                limit,
313            )
314            .await?
315            .data
316            .into_iter()
317            .filter_map(|r| r.data.map(|o| o.object_ref()))
318            .collect();
319        Ok(results)
320    }
321
322    /// Given an address, return one gas object owned by this address.
323    /// The actual implementation just returns the first one returned by the
324    /// read api.
325    pub async fn get_one_gas_object_owned_by_address(
326        &self,
327        address: IotaAddress,
328    ) -> anyhow::Result<Option<ObjectRef>> {
329        Ok(self
330            .get_gas_objects_owned_by_address(address, 1)
331            .await?
332            .pop())
333    }
334
335    /// Return one address and all gas objects owned by that address.
336    pub async fn get_one_account(&self) -> anyhow::Result<(IotaAddress, Vec<ObjectRef>)> {
337        let address = self.get_addresses().pop().unwrap();
338        Ok((
339            address,
340            self.get_all_gas_objects_owned_by_address(address).await?,
341        ))
342    }
343
344    /// Return a gas object owned by an arbitrary address managed by the wallet.
345    pub async fn get_one_gas_object(&self) -> anyhow::Result<Option<(IotaAddress, ObjectRef)>> {
346        for address in self.get_addresses() {
347            if let Some(gas_object) = self.get_one_gas_object_owned_by_address(address).await? {
348                return Ok(Some((address, gas_object)));
349            }
350        }
351        Ok(None)
352    }
353
354    /// Return all the account addresses managed by the wallet and their owned
355    /// gas objects.
356    pub async fn get_all_accounts_and_gas_objects(
357        &self,
358    ) -> anyhow::Result<Vec<(IotaAddress, Vec<ObjectRef>)>> {
359        let mut result = vec![];
360        for address in self.get_addresses() {
361            let objects = self
362                .gas_objects(address)
363                .await?
364                .into_iter()
365                .map(|(_, o)| o.object_ref())
366                .collect();
367            result.push((address, objects));
368        }
369        Ok(result)
370    }
371
372    pub async fn get_reference_gas_price(&self) -> Result<u64, anyhow::Error> {
373        let client = self.get_client().await?;
374        let gas_price = client.governance_api().get_reference_gas_price().await?;
375        Ok(gas_price)
376    }
377
378    /// Add an account.
379    pub fn add_account(&mut self, alias: impl Into<Option<String>>, keypair: IotaKeyPair) {
380        self.config.keystore.add_key(alias.into(), keypair).unwrap();
381    }
382
383    /// Sign a transaction with a key currently managed by the WalletContext.
384    pub fn sign_transaction(&self, data: &TransactionData) -> Transaction {
385        let sig = self
386            .config
387            .keystore
388            .sign_secure(&data.sender(), data, Intent::iota_transaction())
389            .unwrap();
390        // TODO: To support sponsored transaction, we should also look at the gas owner.
391        Transaction::from_data(data.clone(), vec![sig])
392    }
393
394    /// Execute a transaction and wait for it to be locally executed on the
395    /// fullnode. Also expects the effects status to be
396    /// ExecutionStatus::Success.
397    pub async fn execute_transaction_must_succeed(
398        &self,
399        tx: Transaction,
400    ) -> IotaTransactionBlockResponse {
401        tracing::debug!("Executing transaction: {:?}", tx);
402        let response = self.execute_transaction_may_fail(tx).await.unwrap();
403        assert!(
404            response.status_ok().unwrap(),
405            "Transaction failed: {response:?}"
406        );
407        response
408    }
409
410    /// Execute a transaction and wait for it to be locally executed on the
411    /// fullnode. The transaction execution is not guaranteed to succeed and
412    /// may fail. This is usually only needed in non-test environment or the
413    /// caller is explicitly testing some failure behavior.
414    pub async fn execute_transaction_may_fail(
415        &self,
416        tx: Transaction,
417    ) -> anyhow::Result<IotaTransactionBlockResponse> {
418        let client = self.get_client().await?;
419        Ok(client
420            .quorum_driver_api()
421            .execute_transaction_block(
422                tx,
423                IotaTransactionBlockResponseOptions::new()
424                    .with_effects()
425                    .with_input()
426                    .with_events()
427                    .with_object_changes()
428                    .with_balance_changes(),
429                iota_types::quorum_driver_types::ExecuteTransactionRequestType::WaitForLocalExecution,
430            )
431            .await?)
432    }
433}