iota_rpc_loadgen/payload/
rpc_command_processor.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    fmt,
7    fs::{self, File},
8    path::PathBuf,
9    sync::Arc,
10    time::{Duration, Instant},
11};
12
13use anyhow::{Result, anyhow};
14use async_trait::async_trait;
15use dashmap::{DashMap, DashSet};
16use futures::future::join_all;
17use iota_json_rpc_types::{
18    IotaExecutionStatus, IotaObjectDataOptions, IotaTransactionBlockDataAPI,
19    IotaTransactionBlockEffectsAPI, IotaTransactionBlockResponse,
20    IotaTransactionBlockResponseOptions,
21};
22use iota_sdk::{IotaClient, IotaClientBuilder};
23use iota_types::{
24    base_types::{IotaAddress, ObjectID, ObjectRef},
25    crypto::{AccountKeyPair, EncodeDecodeBase64, IotaKeyPair, Signature, get_key_pair},
26    digests::TransactionDigest,
27    quorum_driver_types::ExecuteTransactionRequestType,
28    transaction::{Transaction, TransactionData},
29};
30use serde::{Serialize, de::DeserializeOwned};
31use shared_crypto::intent::{Intent, IntentMessage};
32use tokio::{sync::RwLock, time::sleep};
33use tracing::{debug, info};
34
35use super::MultiGetTransactionBlocks;
36use crate::{
37    load_test::LoadTestConfig,
38    payload::{
39        Command, CommandData, DryRun, GetAllBalances, GetCheckpoints, GetObject, MultiGetObjects,
40        Payload, ProcessPayload, Processor, QueryTransactionBlocks, SignerInfo,
41        checkpoint_utils::get_latest_checkpoint_stats, validation::chunk_entities,
42    },
43};
44
45pub(crate) const DEFAULT_GAS_BUDGET: u64 = 500_000_000;
46pub(crate) const DEFAULT_LARGE_GAS_BUDGET: u64 = 50_000_000_000;
47pub(crate) const MAX_NUM_NEW_OBJECTS_IN_SINGLE_TRANSACTION: usize = 120;
48
49#[derive(Clone)]
50pub struct RpcCommandProcessor {
51    clients: Arc<RwLock<Vec<IotaClient>>>,
52    // for equivocation prevention in `WaitForEffectsCert` mode
53    object_ref_cache: Arc<DashMap<ObjectID, ObjectRef>>,
54    transaction_digests: Arc<DashSet<TransactionDigest>>,
55    addresses: Arc<DashSet<IotaAddress>>,
56    data_dir: String,
57}
58
59impl RpcCommandProcessor {
60    pub async fn new(urls: &[String], data_dir: String) -> Self {
61        let clients = join_all(urls.iter().map(|url| async {
62            IotaClientBuilder::default()
63                .max_concurrent_requests(usize::MAX)
64                .request_timeout(Duration::from_secs(60))
65                .build(url.clone())
66                .await
67                .unwrap()
68        }))
69        .await;
70
71        Self {
72            clients: Arc::new(RwLock::new(clients)),
73            object_ref_cache: Arc::new(DashMap::new()),
74            transaction_digests: Arc::new(DashSet::new()),
75            addresses: Arc::new(DashSet::new()),
76            data_dir,
77        }
78    }
79
80    async fn process_command_data(
81        &self,
82        command: &CommandData,
83        signer_info: &Option<SignerInfo>,
84    ) -> Result<()> {
85        match command {
86            CommandData::DryRun(ref v) => self.process(v, signer_info).await,
87            CommandData::GetCheckpoints(ref v) => self.process(v, signer_info).await,
88            CommandData::PayIota(ref v) => self.process(v, signer_info).await,
89            CommandData::QueryTransactionBlocks(ref v) => self.process(v, signer_info).await,
90            CommandData::MultiGetTransactionBlocks(ref v) => self.process(v, signer_info).await,
91            CommandData::MultiGetObjects(ref v) => self.process(v, signer_info).await,
92            CommandData::GetObject(ref v) => self.process(v, signer_info).await,
93            CommandData::GetAllBalances(ref v) => self.process(v, signer_info).await,
94            CommandData::GetReferenceGasPrice(ref v) => self.process(v, signer_info).await,
95        }
96    }
97
98    pub(crate) async fn get_clients(&self) -> Result<Vec<IotaClient>> {
99        let read = self.clients.read().await;
100        Ok(read.clone())
101    }
102
103    /// sign_and_execute transaction and update `object_ref_cache`
104    pub(crate) async fn sign_and_execute(
105        &self,
106        client: &IotaClient,
107        keypair: &IotaKeyPair,
108        txn_data: TransactionData,
109        request_type: ExecuteTransactionRequestType,
110    ) -> IotaTransactionBlockResponse {
111        let resp = sign_and_execute(client, keypair, txn_data, request_type).await;
112        let effects = resp.effects.as_ref().unwrap();
113        let object_ref_cache = self.object_ref_cache.clone();
114        // NOTE: for now we don't need to care about deleted objects
115        for (owned_object_ref, _) in effects.all_changed_objects() {
116            let id = owned_object_ref.object_id();
117            let current = object_ref_cache.get_mut(&id);
118            match current {
119                Some(mut c) => {
120                    if c.1 < owned_object_ref.version() {
121                        *c = owned_object_ref.reference.to_object_ref();
122                    }
123                }
124                None => {
125                    object_ref_cache.insert(id, owned_object_ref.reference.to_object_ref());
126                }
127            };
128        }
129        resp
130    }
131
132    /// get the latest object ref from local cache, and if not exist, fetch from
133    /// fullnode
134    pub(crate) async fn get_object_ref(
135        &self,
136        client: &IotaClient,
137        object_id: &ObjectID,
138    ) -> ObjectRef {
139        let object_ref_cache = self.object_ref_cache.clone();
140        let current = object_ref_cache.get_mut(object_id);
141        match current {
142            Some(c) => *c,
143            None => {
144                let resp = client
145                    .read_api()
146                    .get_object_with_options(*object_id, IotaObjectDataOptions::new())
147                    .await
148                    .unwrap_or_else(|_| panic!("Unable to fetch object reference {object_id}"));
149                let object_ref = resp.object_ref_if_exists().unwrap_or_else(|| {
150                    panic!("Unable to extract object reference {object_id} from response {resp:?}")
151                });
152                object_ref_cache.insert(*object_id, object_ref);
153                object_ref
154            }
155        }
156    }
157
158    pub(crate) fn add_transaction_digests(&self, digests: Vec<TransactionDigest>) {
159        // extend method requires mutable access to the underlying DashSet, which is not
160        // allowed by the Arc
161        for digest in digests {
162            self.transaction_digests.insert(digest);
163        }
164    }
165
166    pub(crate) fn add_addresses_from_response(&self, responses: &[IotaTransactionBlockResponse]) {
167        for response in responses {
168            let transaction = &response.transaction;
169            if let Some(transaction) = transaction {
170                let data = &transaction.data;
171                self.addresses.insert(*data.sender());
172            }
173        }
174    }
175
176    pub(crate) fn add_object_ids_from_response(&self, responses: &[IotaTransactionBlockResponse]) {
177        for response in responses {
178            let effects = &response.effects;
179            if let Some(effects) = effects {
180                let all_changed_objects = effects.all_changed_objects();
181                for (object_ref, _) in all_changed_objects {
182                    self.object_ref_cache
183                        .insert(object_ref.object_id(), object_ref.reference.to_object_ref());
184                }
185            }
186        }
187    }
188
189    pub(crate) fn dump_cache_to_file(&self) {
190        // TODO: be more granular
191        let digests: Vec<TransactionDigest> = self.transaction_digests.iter().map(|x| *x).collect();
192        if !digests.is_empty() {
193            debug!("dumping transaction digests to file {:?}", digests.len());
194            write_data_to_file(
195                &digests,
196                &format!("{}/{}", &self.data_dir, CacheType::TransactionDigest),
197            )
198            .unwrap();
199        }
200
201        let addresses: Vec<IotaAddress> = self.addresses.iter().map(|x| *x).collect();
202        if !addresses.is_empty() {
203            debug!("dumping addresses to file {:?}", addresses.len());
204            write_data_to_file(
205                &addresses,
206                &format!("{}/{}", &self.data_dir, CacheType::IotaAddress),
207            )
208            .unwrap();
209        }
210
211        let mut object_ids: Vec<ObjectID> = Vec::new();
212        let cloned_object_cache = self.object_ref_cache.clone();
213
214        for item in cloned_object_cache.iter() {
215            let object_id = item.key();
216            object_ids.push(*object_id);
217        }
218
219        if !object_ids.is_empty() {
220            debug!("dumping object_ids to file {:?}", object_ids.len());
221            write_data_to_file(
222                &object_ids,
223                &format!("{}/{}", &self.data_dir, CacheType::ObjectID),
224            )
225            .unwrap();
226        }
227    }
228}
229
230#[async_trait]
231impl Processor for RpcCommandProcessor {
232    async fn apply(&self, payload: &Payload) -> Result<()> {
233        let commands = &payload.commands;
234        for command in commands.iter() {
235            let repeat_interval = command.repeat_interval;
236            let repeat_n_times = command.repeat_n_times;
237            for i in 0..=repeat_n_times {
238                let start_time = Instant::now();
239
240                self.process_command_data(&command.data, &payload.signer_info)
241                    .await?;
242
243                let elapsed_time = start_time.elapsed();
244                if elapsed_time < repeat_interval {
245                    let sleep_duration = repeat_interval - elapsed_time;
246                    sleep(sleep_duration).await;
247                }
248                let clients = self.get_clients().await?;
249                let checkpoint_stats = get_latest_checkpoint_stats(&clients, None).await;
250                info!(
251                    "Repeat {i}: Checkpoint stats {checkpoint_stats}, elapse {:.4} since last repeat",
252                    elapsed_time.as_secs_f64()
253                );
254            }
255        }
256        Ok(())
257    }
258
259    async fn prepare(&self, config: &LoadTestConfig) -> Result<Vec<Payload>> {
260        let clients = self.get_clients().await?;
261        let Command {
262            repeat_n_times,
263            repeat_interval,
264            ..
265        } = &config.command;
266        let command_payloads = match &config.command.data {
267            CommandData::GetCheckpoints(data) => {
268                if !config.divide_tasks {
269                    vec![config.command.clone(); config.num_threads]
270                } else {
271                    divide_checkpoint_tasks(&clients, data, config.num_threads).await
272                }
273            }
274            CommandData::QueryTransactionBlocks(data) => {
275                if !config.divide_tasks {
276                    vec![config.command.clone(); config.num_threads]
277                } else {
278                    divide_query_transaction_blocks_tasks(data, config.num_threads).await
279                }
280            }
281            CommandData::MultiGetTransactionBlocks(data) => {
282                if !config.divide_tasks {
283                    vec![config.command.clone(); config.num_threads]
284                } else {
285                    divide_multi_get_transaction_blocks_tasks(data, config.num_threads).await
286                }
287            }
288            CommandData::GetAllBalances(data) => {
289                if !config.divide_tasks {
290                    vec![config.command.clone(); config.num_threads]
291                } else {
292                    divide_get_all_balances_tasks(data, config.num_threads).await
293                }
294            }
295            CommandData::MultiGetObjects(data) => {
296                if !config.divide_tasks {
297                    vec![config.command.clone(); config.num_threads]
298                } else {
299                    divide_multi_get_objects_tasks(data, config.num_threads).await
300                }
301            }
302            CommandData::GetObject(data) => {
303                if !config.divide_tasks {
304                    vec![config.command.clone(); config.num_threads]
305                } else {
306                    divide_get_object_tasks(data, config.num_threads).await
307                }
308            }
309            _ => vec![config.command.clone(); config.num_threads],
310        };
311
312        let command_payloads = command_payloads.into_iter().map(|command| {
313            command
314                .with_repeat_interval(*repeat_interval)
315                .with_repeat_n_times(*repeat_n_times)
316        });
317
318        let coins_and_keys = if config.signer_info.is_some() {
319            Some(
320                prepare_new_signer_and_coins(
321                    clients.first().unwrap(),
322                    config.signer_info.as_ref().unwrap(),
323                    config.num_threads * config.num_chunks_per_thread,
324                    config.max_repeat as u64 + 1,
325                )
326                .await,
327            )
328        } else {
329            None
330        };
331
332        let num_chunks = config.num_chunks_per_thread;
333        Ok(command_payloads
334            .into_iter()
335            .enumerate()
336            .map(|(i, command)| Payload {
337                commands: vec![command], // note commands is also a vector
338                signer_info: coins_and_keys
339                    .as_ref()
340                    .map(|(coins, encoded_keypair)| SignerInfo {
341                        encoded_keypair: encoded_keypair.clone(),
342                        gas_payment: Some(coins[num_chunks * i..(i + 1) * num_chunks].to_vec()),
343                        gas_budget: None,
344                    }),
345            })
346            .collect())
347    }
348
349    fn dump_cache_to_file(&self, config: &LoadTestConfig) {
350        if let CommandData::GetCheckpoints(data) = &config.command.data {
351            if data.record {
352                self.dump_cache_to_file();
353            }
354        }
355    }
356}
357
358#[async_trait]
359impl<'a> ProcessPayload<'a, &'a DryRun> for RpcCommandProcessor {
360    async fn process(&'a self, _op: &'a DryRun, _signer_info: &Option<SignerInfo>) -> Result<()> {
361        debug!("DryRun");
362        Ok(())
363    }
364}
365
366fn write_data_to_file<T: Serialize>(data: &T, file_path: &str) -> Result<(), anyhow::Error> {
367    let mut path_buf = PathBuf::from(&file_path);
368    path_buf.pop();
369    fs::create_dir_all(&path_buf).map_err(|e| anyhow!("Error creating directory: {}", e))?;
370
371    let file_name = format!("{}.json", file_path);
372    let file = File::create(file_name).map_err(|e| anyhow!("Error creating file: {}", e))?;
373    serde_json::to_writer(file, data).map_err(|e| anyhow!("Error writing to file: {}", e))?;
374
375    Ok(())
376}
377
/// Kind of cached data; its `Display` form is the file stem used when the
/// cache is written to or read from disk.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheType {
    IotaAddress,
    TransactionDigest,
    ObjectID,
}

impl fmt::Display for CacheType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            CacheType::IotaAddress => "IotaAddress",
            CacheType::TransactionDigest => "TransactionDigest",
            CacheType::ObjectID => "ObjectID",
        })
    }
}
393
394// TODO(Will): Consider using enums for input and output? Would mean we need to
395// do checks any time we use generic load_cache_from_file
396pub fn load_addresses_from_file(filepath: String) -> Vec<IotaAddress> {
397    let path = format!("{}/{}", filepath, CacheType::IotaAddress);
398    let addresses: Vec<IotaAddress> = read_data_from_file(&path).expect("Failed to read addresses");
399    addresses
400}
401
402pub fn load_objects_from_file(filepath: String) -> Vec<ObjectID> {
403    let path = format!("{}/{}", filepath, CacheType::ObjectID);
404    let objects: Vec<ObjectID> = read_data_from_file(&path).expect("Failed to read objects");
405    objects
406}
407
408pub fn load_digests_from_file(filepath: String) -> Vec<TransactionDigest> {
409    let path = format!("{}/{}", filepath, CacheType::TransactionDigest);
410    let digests: Vec<TransactionDigest> =
411        read_data_from_file(&path).expect("Failed to read transaction digests");
412    digests
413}
414
415fn read_data_from_file<T: DeserializeOwned>(file_path: &str) -> Result<T, anyhow::Error> {
416    let mut path_buf = PathBuf::from(file_path);
417
418    // Check if the file has a JSON extension
419    if path_buf.extension().is_none_or(|ext| ext != "json") {
420        // If not, add .json to the filename
421        path_buf.set_extension("json");
422    }
423
424    let path = path_buf.as_path();
425    if !path.exists() {
426        return Err(anyhow!("File not found: {}", file_path));
427    }
428
429    let file = File::open(path).map_err(|e| anyhow::anyhow!("Error opening file: {}", e))?;
430    let deserialized_data: T =
431        serde_json::from_reader(file).map_err(|e| anyhow!("Deserialization error: {}", e))?;
432
433    Ok(deserialized_data)
434}
435
436async fn divide_checkpoint_tasks(
437    clients: &[IotaClient],
438    data: &GetCheckpoints,
439    num_chunks: usize,
440) -> Vec<Command> {
441    let start = data.start;
442    let end = match data.end {
443        Some(end) => end,
444        None => {
445            let end_checkpoints = join_all(clients.iter().map(|client| async {
446                client
447                    .read_api()
448                    .get_latest_checkpoint_sequence_number()
449                    .await
450                    .expect("get_latest_checkpoint_sequence_number should not fail")
451            }))
452            .await;
453            *end_checkpoints
454                .iter()
455                .max()
456                .expect("get_latest_checkpoint_sequence_number should not return empty")
457        }
458    };
459
460    let chunk_size = (end - start) / num_chunks as u64;
461    (0..num_chunks)
462        .map(|i| {
463            let start_checkpoint = start + (i as u64) * chunk_size;
464            let end_checkpoint = end.min(start + ((i + 1) as u64) * chunk_size);
465            Command::new_get_checkpoints(
466                start_checkpoint,
467                Some(end_checkpoint),
468                data.verify_transactions,
469                data.verify_objects,
470                data.record,
471            )
472        })
473        .collect()
474}
475
476async fn divide_query_transaction_blocks_tasks(
477    data: &QueryTransactionBlocks,
478    num_chunks: usize,
479) -> Vec<Command> {
480    let chunk_size = if data.addresses.len() < num_chunks {
481        1
482    } else {
483        data.addresses.len() as u64 / num_chunks as u64
484    };
485    let chunked = chunk_entities(data.addresses.as_slice(), Some(chunk_size as usize));
486    chunked
487        .into_iter()
488        .map(|chunk| Command::new_query_transaction_blocks(data.address_type.clone(), chunk))
489        .collect()
490}
491
492async fn divide_multi_get_transaction_blocks_tasks(
493    data: &MultiGetTransactionBlocks,
494    num_chunks: usize,
495) -> Vec<Command> {
496    let chunk_size = if data.digests.len() < num_chunks {
497        1
498    } else {
499        data.digests.len() as u64 / num_chunks as u64
500    };
501    let chunked = chunk_entities(data.digests.as_slice(), Some(chunk_size as usize));
502    chunked
503        .into_iter()
504        .map(Command::new_multi_get_transaction_blocks)
505        .collect()
506}
507
508async fn divide_get_all_balances_tasks(data: &GetAllBalances, num_threads: usize) -> Vec<Command> {
509    let per_thread_size = if data.addresses.len() < num_threads {
510        1
511    } else {
512        data.addresses.len() / num_threads
513    };
514
515    let chunked = chunk_entities(data.addresses.as_slice(), Some(per_thread_size));
516    chunked
517        .into_iter()
518        .map(|chunk| Command::new_get_all_balances(chunk, data.chunk_size))
519        .collect()
520}
521
522// TODO: probs can do generic divide tasks
523async fn divide_multi_get_objects_tasks(data: &MultiGetObjects, num_chunks: usize) -> Vec<Command> {
524    let chunk_size = if data.object_ids.len() < num_chunks {
525        1
526    } else {
527        data.object_ids.len() as u64 / num_chunks as u64
528    };
529    let chunked = chunk_entities(data.object_ids.as_slice(), Some(chunk_size as usize));
530    chunked
531        .into_iter()
532        .map(Command::new_multi_get_objects)
533        .collect()
534}
535
536async fn divide_get_object_tasks(data: &GetObject, num_threads: usize) -> Vec<Command> {
537    let per_thread_size = if data.object_ids.len() < num_threads {
538        1
539    } else {
540        data.object_ids.len() / num_threads
541    };
542
543    let chunked = chunk_entities(data.object_ids.as_slice(), Some(per_thread_size));
544    chunked
545        .into_iter()
546        .map(|chunk| Command::new_get_object(chunk, data.chunk_size))
547        .collect()
548}
549
550async fn prepare_new_signer_and_coins(
551    client: &IotaClient,
552    signer_info: &SignerInfo,
553    num_coins: usize,
554    num_transactions_per_coin: u64,
555) -> (Vec<ObjectID>, String) {
556    // TODO(chris): consider reference gas price
557    let amount_per_coin = num_transactions_per_coin * DEFAULT_GAS_BUDGET;
558    let pay_amount = amount_per_coin * num_coins as u64;
559    let num_split_txns =
560        num_transactions_needed(num_coins, MAX_NUM_NEW_OBJECTS_IN_SINGLE_TRANSACTION);
561    let (gas_fee_for_split, gas_fee_for_pay_iota) = (
562        DEFAULT_LARGE_GAS_BUDGET * num_split_txns as u64,
563        DEFAULT_GAS_BUDGET,
564    );
565
566    let primary_keypair = IotaKeyPair::decode_base64(&signer_info.encoded_keypair)
567        .expect("Decoding keypair should not fail");
568    let sender = IotaAddress::from(&primary_keypair.public());
569    let (coin, balance) = get_coin_with_max_balance(client, sender).await;
570    // The balance needs to cover `pay_amount` plus
571    // 1. gas fee for pay_iota from the primary address to the burner address
572    // 2. gas fee for splitting the primary coin into `num_coins`
573    let required_balance = pay_amount + gas_fee_for_split + gas_fee_for_pay_iota;
574    if required_balance > balance {
575        panic!(
576            "Current balance {balance} is smaller than require amount of NANOS to fund the operation {required_balance}"
577        );
578    }
579
580    // There is a limit for the number of new objects in a transactions, therefore
581    // we need multiple split transactions if the `num_coins` is large
582    let split_amounts = calculate_split_amounts(
583        num_coins,
584        amount_per_coin,
585        MAX_NUM_NEW_OBJECTS_IN_SINGLE_TRANSACTION,
586    );
587
588    debug!("split_amounts {split_amounts:?}");
589
590    // We don't want to split coins in our primary address because we want to avoid
591    // having a million coin objects in our address. We can also fetch directly
592    // from the faucet, but in some environment that might not be possible when
593    // faucet resource is scarce
594    let (burner_address, burner_keypair): (_, AccountKeyPair) = get_key_pair();
595    let burner_keypair = IotaKeyPair::Ed25519(burner_keypair);
596    let pay_amounts = split_amounts
597        .iter()
598        .map(|(amount, _)| *amount)
599        .chain(std::iter::once(gas_fee_for_split))
600        .collect::<Vec<_>>();
601
602    debug!("pay_amounts {pay_amounts:?}");
603
604    pay_iota(
605        client,
606        &primary_keypair,
607        vec![coin],
608        DEFAULT_GAS_BUDGET,
609        vec![burner_address; pay_amounts.len()],
610        pay_amounts,
611    )
612    .await;
613
614    let coins = get_iota_coin_ids(client, burner_address).await;
615    let gas_coin_id = get_coin_with_balance(&coins, gas_fee_for_split);
616    let primary_coin = get_coin_with_balance(&coins, split_amounts[0].0);
617    assert!(!coins.is_empty());
618    let mut results: Vec<ObjectID> = vec![];
619    assert!(!split_amounts.is_empty());
620    if split_amounts.len() == 1 && split_amounts[0].1 == 0 {
621        results.push(get_coin_with_balance(&coins, split_amounts[0].0));
622    } else if split_amounts.len() == 1 {
623        results.extend(
624            split_coins(
625                client,
626                &burner_keypair,
627                primary_coin,
628                gas_coin_id,
629                split_amounts[0].1 as u64,
630            )
631            .await,
632        );
633    } else {
634        let (max_amount, max_split) = &split_amounts[0];
635        let (remainder_amount, remainder_split) = split_amounts.last().unwrap();
636        let primary_coins = coins
637            .iter()
638            .filter(|(_, balance)| balance == max_amount)
639            .map(|(id, _)| (*id, *max_split as u64))
640            .chain(
641                coins
642                    .iter()
643                    .filter(|(_, balance)| balance == remainder_amount)
644                    .map(|(id, _)| (*id, *remainder_split as u64)),
645            )
646            .collect::<Vec<_>>();
647
648        for (coin_id, splits) in primary_coins {
649            results
650                .extend(split_coins(client, &burner_keypair, coin_id, gas_coin_id, splits).await);
651        }
652    }
653    assert_eq!(results.len(), num_coins);
654    debug!("Split off {} coins for gas payment {results:?}", num_coins);
655    (results, burner_keypair.encode_base64())
656}
657
/// Returns how many split transactions are needed to produce `num_coins`
/// coins when one transaction can create at most `new_coins_per_txn` of them.
///
/// A single coin (or none) needs no split, so the result is `0` in that case.
///
/// # Panics
/// Panics if `new_coins_per_txn` is zero.
fn num_transactions_needed(num_coins: usize, new_coins_per_txn: usize) -> usize {
    assert!(new_coins_per_txn > 0);
    match num_coins {
        1 => 0,
        n => n.div_ceil(new_coins_per_txn),
    }
}

/// Plans the split transactions for `num_coins` coins of `amount_per_coin`
/// each, with at most `max_coins_per_txn` coins per transaction.
///
/// Each element is `(primary_coin_amount, split_into_n_coins)`. Every group
/// but the last is full; the last takes the leftover coins. When no split is
/// needed, the single element is `(total_amount, 0)`.
///
/// # Panics
/// Panics if `max_coins_per_txn` is zero.
fn calculate_split_amounts(
    num_coins: usize,
    amount_per_coin: u64,
    max_coins_per_txn: usize,
) -> Vec<(u64, usize)> {
    let total_amount = amount_per_coin * num_coins as u64;
    let num_transactions = num_transactions_needed(num_coins, max_coins_per_txn);
    let Some(full_batches) = num_transactions.checked_sub(1) else {
        // Zero or one coin: nothing to split.
        return vec![(total_amount, 0)];
    };

    let batch_amount = max_coins_per_txn as u64 * amount_per_coin;
    let leftover_amount = total_amount - batch_amount * full_batches as u64;
    let leftover_coins = num_coins - max_coins_per_txn * full_batches;
    std::iter::repeat((batch_amount, max_coins_per_txn))
        .take(full_batches)
        .chain(std::iter::once((leftover_amount, leftover_coins)))
        .collect()
}
693
694async fn get_coin_with_max_balance(client: &IotaClient, address: IotaAddress) -> (ObjectID, u64) {
695    let coins = get_iota_coin_ids(client, address).await;
696    assert!(!coins.is_empty());
697    coins.into_iter().max_by(|a, b| a.1.cmp(&b.1)).unwrap()
698}
699
700fn get_coin_with_balance(coins: &[(ObjectID, u64)], target: u64) -> ObjectID {
701    coins.iter().find(|(_, b)| b == &target).unwrap().0
702}
703
704// TODO: move this to the Rust SDK
705async fn get_iota_coin_ids(client: &IotaClient, address: IotaAddress) -> Vec<(ObjectID, u64)> {
706    match client
707        .coin_read_api()
708        .get_coins(address, None, None, None)
709        .await
710    {
711        Ok(page) => page
712            .data
713            .into_iter()
714            .map(|c| (c.coin_object_id, c.balance))
715            .collect::<Vec<_>>(),
716        Err(e) => {
717            panic!("get_iota_coin_ids error for address {address} {e}")
718        }
719    }
720    // TODO: implement iteration over next page
721}
722
723async fn pay_iota(
724    client: &IotaClient,
725    keypair: &IotaKeyPair,
726    input_coins: Vec<ObjectID>,
727    gas_budget: u64,
728    recipients: Vec<IotaAddress>,
729    amounts: Vec<u64>,
730) -> IotaTransactionBlockResponse {
731    let sender = IotaAddress::from(&keypair.public());
732    let tx = client
733        .transaction_builder()
734        .pay(sender, input_coins, recipients, amounts, None, gas_budget)
735        .await
736        .expect("Failed to construct pay iota transaction");
737    sign_and_execute(
738        client,
739        keypair,
740        tx,
741        ExecuteTransactionRequestType::WaitForLocalExecution,
742    )
743    .await
744}
745
746async fn split_coins(
747    client: &IotaClient,
748    keypair: &IotaKeyPair,
749    coin_to_split: ObjectID,
750    gas_payment: ObjectID,
751    num_coins: u64,
752) -> Vec<ObjectID> {
753    let sender = IotaAddress::from(&keypair.public());
754    let split_coin_tx = client
755        .transaction_builder()
756        .split_coin_equal(
757            sender,
758            coin_to_split,
759            num_coins,
760            Some(gas_payment),
761            DEFAULT_LARGE_GAS_BUDGET,
762        )
763        .await
764        .expect("Failed to construct split coin transaction");
765    sign_and_execute(
766        client,
767        keypair,
768        split_coin_tx,
769        ExecuteTransactionRequestType::WaitForLocalExecution,
770    )
771    .await
772    .effects
773    .unwrap()
774    .created()
775    .iter()
776    .map(|owned_object_ref| owned_object_ref.reference.object_id)
777    .chain(std::iter::once(coin_to_split))
778    .collect::<Vec<_>>()
779}
780
781pub(crate) async fn sign_and_execute(
782    client: &IotaClient,
783    keypair: &IotaKeyPair,
784    txn_data: TransactionData,
785    request_type: ExecuteTransactionRequestType,
786) -> IotaTransactionBlockResponse {
787    let signature = Signature::new_secure(
788        &IntentMessage::new(Intent::iota_transaction(), &txn_data),
789        keypair,
790    );
791
792    let transaction_response = match client
793        .quorum_driver_api()
794        .execute_transaction_block(
795            Transaction::from_data(txn_data, vec![signature]),
796            IotaTransactionBlockResponseOptions::new().with_effects(),
797            Some(request_type),
798        )
799        .await
800    {
801        Ok(response) => response,
802        Err(e) => {
803            panic!("sign_and_execute error {e}")
804        }
805    };
806
807    match &transaction_response.effects {
808        Some(effects) => {
809            if let IotaExecutionStatus::Failure { error } = effects.status() {
810                panic!(
811                    "Transaction {} failed with error: {}. Transaction Response: {:?}",
812                    transaction_response.digest, error, &transaction_response
813                );
814            }
815        }
816        None => {
817            panic!(
818                "Transaction {} has no effects. Response {:?}",
819                transaction_response.digest, &transaction_response
820            );
821        }
822    };
823    transaction_response
824}
825
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `calculate_split_amounts` yields exactly `expected`.
    fn check_split(
        num_coins: usize,
        amount_per_coin: u64,
        max_coins_per_txn: usize,
        expected: &[(u64, usize)],
    ) {
        let actual = calculate_split_amounts(num_coins, amount_per_coin, max_coins_per_txn);
        assert_eq!(expected, actual.as_slice());
    }

    #[test]
    fn test_calculate_split_amounts_no_split_needed() {
        check_split(10, 100, 20, &[(1000, 10)]);
    }

    #[test]
    fn test_calculate_split_amounts_exact_split() {
        check_split(10, 100, 5, &[(500, 5), (500, 5)]);
    }

    #[test]
    fn test_calculate_split_amounts_with_remainder() {
        check_split(12, 100, 5, &[(500, 5), (500, 5), (200, 2)]);
    }

    #[test]
    fn test_calculate_split_amounts_single_coin() {
        check_split(1, 100, 5, &[(100, 0)]);
    }

    #[test]
    fn test_calculate_split_amounts_max_coins_equals_num_coins() {
        check_split(5, 100, 5, &[(500, 5)]);
    }

    #[test]
    #[should_panic(expected = "assertion failed: new_coins_per_txn > 0")]
    fn test_calculate_split_amounts_zero_max_coins() {
        calculate_split_amounts(5, 100, 0);
    }
}