iota_sdk/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5//! The IOTA Rust SDK
6//!
//! It aims to provide functionality similar to the existing
//! [TypeScript SDK](https://github.com/iotaledger/iota/tree/main/sdk/typescript/).
//! The IOTA Rust SDK builds on top of the [JSON RPC API](https://docs.iota.org/iota-api-ref)
10//! and therefore many of the return types are the ones specified in
11//! [iota_types].
12//!
//! The API is split into several parts corresponding to different
//! functionalities, as follows:
15//! * [CoinReadApi] - provides read-only functions to work with the coins
//! * [EventApi] - provides event-related functions
17//! * [GovernanceApi] - provides functionality related to staking
18//! * [QuorumDriverApi] - provides functionality to execute a transaction block
19//!   and submit it to the fullnode(s)
20//! * [ReadApi] - provides functions for retrieving data about different objects
21//!   and transactions
22//! * <a href="../iota_transaction_builder/struct.TransactionBuilder.html"
23//!   title="struct
24//!   iota_transaction_builder::TransactionBuilder">TransactionBuilder</a> -
25//!   provides functions for building transactions
26//!
27//! # Usage
28//! The main way to interact with the API is through the [IotaClientBuilder],
29//! which returns an [IotaClient] object from which the user can access the
30//! various APIs.
31//!
32//! ## Getting Started
33//! Add the Rust SDK to the project by running `cargo add iota-sdk` in the root
34//! folder of your Rust project.
35//!
36//! The main building block for the IOTA Rust SDK is the [IotaClientBuilder],
37//! which provides a simple and straightforward way of connecting to an IOTA
38//! network and having access to the different available APIs.
39//!
//! Below is a simple example which connects to a running IOTA local network,
//! devnet, testnet, and mainnet.
42//! To successfully run this program, make sure to spin up a local
43//! network with a local validator, a fullnode, and a faucet server
44//! (see [the README](https://github.com/iotaledger/iota/tree/develop/crates/iota-sdk/README.md#prerequisites) for more information).
45//!
46//! ```rust,no_run
47//! use iota_sdk::IotaClientBuilder;
48//!
49//! #[tokio::main]
50//! async fn main() -> Result<(), anyhow::Error> {
51//!     let iota = IotaClientBuilder::default()
52//!         .build("http://127.0.0.1:9000") // provide the IOTA network URL
53//!         .await?;
54//!     println!("IOTA local network version: {:?}", iota.api_version());
55//!
56//!     // local IOTA network, same result as above except using the dedicated function
57//!     let iota_local = IotaClientBuilder::default().build_localnet().await?;
58//!     println!("IOTA local network version: {:?}", iota_local.api_version());
59//!
60//!     // IOTA devnet running at `https://api.devnet.iota.cafe`
61//!     let iota_devnet = IotaClientBuilder::default().build_devnet().await?;
62//!     println!("IOTA devnet version: {:?}", iota_devnet.api_version());
63//!
64//!     // IOTA testnet running at `https://api.testnet.iota.cafe`
65//!     let iota_testnet = IotaClientBuilder::default().build_testnet().await?;
66//!     println!("IOTA testnet version: {:?}", iota_testnet.api_version());
67//!
68//!     // IOTA mainnet running at `https://api.mainnet.iota.cafe`
69//!     let iota_mainnet = IotaClientBuilder::default().build_mainnet().await?;
70//!     println!("IOTA mainnet version: {:?}", iota_mainnet.api_version());
71//!
72//!     Ok(())
73//! }
74//! ```
75//!
76//! ## Examples
77//!
//! For detailed examples, please check the API docs and the examples folder
79//! in the [repository](https://github.com/iotaledger/iota/tree/main/crates/iota-sdk/examples).
80
81pub mod apis;
82pub mod error;
83pub mod iota_client_config;
84pub mod json_rpc_error;
85pub mod wallet_context;
86
87use std::{
88    collections::{HashMap, VecDeque},
89    fmt::{Debug, Formatter},
90    marker::PhantomData,
91    pin::Pin,
92    str::FromStr,
93    sync::Arc,
94    task::Poll,
95    time::Duration,
96};
97
98use async_trait::async_trait;
99use base64::Engine;
100use futures::{StreamExt, TryStreamExt};
101pub use iota_json as json;
102use iota_json_rpc_api::{
103    CLIENT_SDK_TYPE_HEADER, CLIENT_SDK_VERSION_HEADER, CLIENT_TARGET_API_VERSION_HEADER,
104};
105pub use iota_json_rpc_types as rpc_types;
106use iota_json_rpc_types::{
107    IotaObjectDataFilter, IotaObjectDataOptions, IotaObjectResponse, IotaObjectResponseQuery, Page,
108};
109use iota_transaction_builder::{DataReader, TransactionBuilder};
110pub use iota_types as types;
111use iota_types::base_types::{IotaAddress, ObjectID, ObjectInfo};
112use jsonrpsee::{
113    core::client::ClientT,
114    http_client::{HeaderMap, HeaderValue, HttpClient, HttpClientBuilder},
115    rpc_params,
116    ws_client::{PingConfig, WsClient, WsClientBuilder},
117};
118use move_core_types::language_storage::StructTag;
119use reqwest::header::HeaderName;
120use rustls::crypto::{CryptoProvider, ring};
121use serde_json::Value;
122
123use crate::{
124    apis::{CoinReadApi, EventApi, GovernanceApi, QuorumDriverApi, ReadApi},
125    error::{Error, IotaRpcResult},
126};
127
/// Move type of the native IOTA coin.
pub const IOTA_COIN_TYPE: &str = "0x2::iota::IOTA";

// --- Local network -----------------------------------------------------------

/// JSON-RPC endpoint of a local network (used by
/// [`IotaClientBuilder::build_localnet`]).
pub const IOTA_LOCAL_NETWORK_URL: &str = "http://127.0.0.1:9000";
/// JSON-RPC endpoint of a local network, addressed via `0.0.0.0`.
pub const IOTA_LOCAL_NETWORK_URL_0: &str = "http://0.0.0.0:9000";
/// GraphQL endpoint of a local network.
pub const IOTA_LOCAL_NETWORK_GRAPHQL_URL: &str = "http://127.0.0.1:9125";
/// Gas request endpoint of a local network.
pub const IOTA_LOCAL_NETWORK_GAS_URL: &str = "http://127.0.0.1:9123/v1/gas";

// --- Devnet ------------------------------------------------------------------

/// JSON-RPC endpoint of the devnet (used by [`IotaClientBuilder::build_devnet`]).
pub const IOTA_DEVNET_URL: &str = "https://api.devnet.iota.cafe";
/// GraphQL endpoint of the devnet.
pub const IOTA_DEVNET_GRAPHQL_URL: &str = "https://graphql.devnet.iota.cafe";
/// Faucet gas request endpoint of the devnet.
pub const IOTA_DEVNET_GAS_URL: &str = "https://faucet.devnet.iota.cafe/v1/gas";

// --- Testnet -----------------------------------------------------------------

/// JSON-RPC endpoint of the testnet (used by
/// [`IotaClientBuilder::build_testnet`]).
pub const IOTA_TESTNET_URL: &str = "https://api.testnet.iota.cafe";
/// GraphQL endpoint of the testnet.
pub const IOTA_TESTNET_GRAPHQL_URL: &str = "https://graphql.testnet.iota.cafe";
/// Faucet gas request endpoint of the testnet.
pub const IOTA_TESTNET_GAS_URL: &str = "https://faucet.testnet.iota.cafe/v1/gas";

// --- Mainnet -----------------------------------------------------------------

/// JSON-RPC endpoint of the mainnet (used by
/// [`IotaClientBuilder::build_mainnet`]).
pub const IOTA_MAINNET_URL: &str = "https://api.mainnet.iota.cafe";
/// GraphQL endpoint of the mainnet.
pub const IOTA_MAINNET_GRAPHQL_URL: &str = "https://graphql.mainnet.iota.cafe";
141
142/// Builder for creating an [IotaClient] for connecting to the IOTA network.
143///
144/// By default `maximum concurrent requests` is set to 256 and `request timeout`
145/// is set to 60 seconds. These can be adjusted using
146/// [`Self::max_concurrent_requests()`], and the [`Self::request_timeout()`].
147/// If you use the WebSocket, consider setting `ws_ping_interval` appropriately
148/// to prevent an inactive WS subscription being disconnected due to proxy
149/// timeout.
150///
151/// # Examples
152///
153/// ```rust,no_run
154/// use iota_sdk::IotaClientBuilder;
155///
156/// #[tokio::main]
157/// async fn main() -> Result<(), anyhow::Error> {
158///     let iota = IotaClientBuilder::default()
159///         .build("http://127.0.0.1:9000")
160///         .await?;
161///
162///     println!("IOTA local network version: {:?}", iota.api_version());
163///     Ok(())
164/// }
165/// ```
166pub struct IotaClientBuilder {
167    request_timeout: Duration,
168    max_concurrent_requests: Option<usize>,
169    ws_url: Option<String>,
170    ws_ping_interval: Option<Duration>,
171    basic_auth: Option<(String, String)>,
172    tls_config: Option<rustls::ClientConfig>,
173    headers: Option<HashMap<String, String>>,
174}
175
176impl Default for IotaClientBuilder {
177    fn default() -> Self {
178        Self {
179            request_timeout: Duration::from_secs(60),
180            max_concurrent_requests: None,
181            ws_url: None,
182            ws_ping_interval: None,
183            basic_auth: None,
184            tls_config: None,
185            headers: None,
186        }
187    }
188}
189
190impl IotaClientBuilder {
191    /// Set the request timeout to the specified duration.
192    pub fn request_timeout(mut self, request_timeout: Duration) -> Self {
193        self.request_timeout = request_timeout;
194        self
195    }
196
197    /// Set the max concurrent requests allowed.
198    pub fn max_concurrent_requests(mut self, max_concurrent_requests: usize) -> Self {
199        self.max_concurrent_requests = Some(max_concurrent_requests);
200        self
201    }
202
203    /// Set the WebSocket URL for the IOTA network.
204    pub fn ws_url(mut self, url: impl AsRef<str>) -> Self {
205        self.ws_url = Some(url.as_ref().to_string());
206        self
207    }
208
209    /// Set the WebSocket ping interval.
210    pub fn ws_ping_interval(mut self, duration: Duration) -> Self {
211        self.ws_ping_interval = Some(duration);
212        self
213    }
214
215    /// Set the basic auth credentials for the HTTP client.
216    pub fn basic_auth(mut self, username: impl AsRef<str>, password: impl AsRef<str>) -> Self {
217        self.basic_auth = Some((username.as_ref().to_string(), password.as_ref().to_string()));
218        self
219    }
220
221    /// Set custom headers for the HTTP client
222    pub fn custom_headers(mut self, headers: HashMap<String, String>) -> Self {
223        self.headers = Some(headers);
224        self
225    }
226
227    /// Set a TLS configuration for the HTTP client.
228    pub fn tls_config(mut self, config: rustls::ClientConfig) -> Self {
229        self.tls_config = Some(config);
230        self
231    }
232
233    /// Return an [IotaClient] object connected to the IOTA network accessible
234    /// via the provided URI.
235    ///
236    /// # Examples
237    ///
238    /// ```rust,no_run
239    /// use iota_sdk::IotaClientBuilder;
240    ///
241    /// #[tokio::main]
242    /// async fn main() -> Result<(), anyhow::Error> {
243    ///     let iota = IotaClientBuilder::default()
244    ///         .build("http://127.0.0.1:9000")
245    ///         .await?;
246    ///
247    ///     println!("IOTA local version: {:?}", iota.api_version());
248    ///     Ok(())
249    /// }
250    /// ```
251    pub async fn build(self, http: impl AsRef<str>) -> IotaRpcResult<IotaClient> {
252        if CryptoProvider::get_default().is_none() {
253            ring::default_provider().install_default().ok();
254        }
255
256        let client_version = env!("CARGO_PKG_VERSION");
257        let mut headers = HeaderMap::new();
258        headers.insert(
259            CLIENT_TARGET_API_VERSION_HEADER,
260            // in rust, the client version is the same as the target api version
261            HeaderValue::from_static(client_version),
262        );
263        headers.insert(
264            CLIENT_SDK_VERSION_HEADER,
265            HeaderValue::from_static(client_version),
266        );
267        headers.insert(CLIENT_SDK_TYPE_HEADER, HeaderValue::from_static("rust"));
268
269        if let Some((username, password)) = self.basic_auth {
270            let auth =
271                base64::engine::general_purpose::STANDARD.encode(format!("{username}:{password}"));
272            headers.insert(
273                "authorization",
274                // reqwest::header::AUTHORIZATION,
275                HeaderValue::from_str(&format!("Basic {auth}")).unwrap(),
276            );
277        }
278
279        if let Some(custom_headers) = self.headers {
280            for (key, value) in custom_headers {
281                let header_name =
282                    HeaderName::from_str(&key).map_err(|e| Error::CustomHeaders(e.to_string()))?;
283                let header_value = HeaderValue::from_str(&value)
284                    .map_err(|e| Error::CustomHeaders(e.to_string()))?;
285                headers.insert(header_name, header_value);
286            }
287        }
288
289        let ws = if let Some(url) = self.ws_url {
290            let mut builder = WsClientBuilder::default()
291                .max_request_size(2 << 30)
292                .set_headers(headers.clone())
293                .request_timeout(self.request_timeout);
294
295            if let Some(duration) = self.ws_ping_interval {
296                builder = builder.enable_ws_ping(PingConfig::new().ping_interval(duration))
297            }
298
299            if let Some(max_concurrent_requests) = self.max_concurrent_requests {
300                builder = builder.max_concurrent_requests(max_concurrent_requests);
301            }
302
303            builder.build(url).await.ok()
304        } else {
305            None
306        };
307
308        let mut http_builder = HttpClientBuilder::default()
309            .max_request_size(2 << 30)
310            .set_headers(headers)
311            .request_timeout(self.request_timeout);
312
313        if let Some(max_concurrent_requests) = self.max_concurrent_requests {
314            http_builder = http_builder.max_concurrent_requests(max_concurrent_requests);
315        }
316
317        if let Some(tls_config) = self.tls_config {
318            http_builder = http_builder.with_custom_cert_store(tls_config);
319        }
320
321        let http = http_builder.build(http)?;
322
323        let info = Self::get_server_info(&http, &ws).await?;
324
325        let rpc = RpcClient { http, ws, info };
326        let api = Arc::new(rpc);
327        let read_api = Arc::new(ReadApi::new(api.clone()));
328        let quorum_driver_api = QuorumDriverApi::new(api.clone());
329        let event_api = EventApi::new(api.clone());
330        let transaction_builder = TransactionBuilder::new(read_api.clone());
331        let coin_read_api = CoinReadApi::new(api.clone());
332        let governance_api = GovernanceApi::new(api.clone());
333
334        Ok(IotaClient {
335            api,
336            transaction_builder,
337            read_api,
338            coin_read_api,
339            event_api,
340            quorum_driver_api,
341            governance_api,
342        })
343    }
344
345    /// Return an [IotaClient] object that is ready to interact with the local
346    /// development network (by default it expects the IOTA network to be up
347    /// and running at `127.0.0.1:9000`).
348    ///
349    /// For connecting to a custom URI, use the `build` function instead.
350    ///
351    /// # Examples
352    ///
353    /// ```rust,no_run
354    /// use iota_sdk::IotaClientBuilder;
355    ///
356    /// #[tokio::main]
357    /// async fn main() -> Result<(), anyhow::Error> {
358    ///     let iota = IotaClientBuilder::default().build_localnet().await?;
359    ///
360    ///     println!("IOTA local version: {:?}", iota.api_version());
361    ///     Ok(())
362    /// }
363    /// ```
364    pub async fn build_localnet(self) -> IotaRpcResult<IotaClient> {
365        self.build(IOTA_LOCAL_NETWORK_URL).await
366    }
367
368    /// Return an [IotaClient] object that is ready to interact with the IOTA
369    /// devnet.
370    ///
371    /// For connecting to a custom URI, use the `build` function instead.
372    ///
373    /// # Examples
374    ///
375    /// ```rust,no_run
376    /// use iota_sdk::IotaClientBuilder;
377    ///
378    /// #[tokio::main]
379    /// async fn main() -> Result<(), anyhow::Error> {
380    ///     let iota = IotaClientBuilder::default().build_devnet().await?;
381    ///
382    ///     println!("{:?}", iota.api_version());
383    ///     Ok(())
384    /// }
385    /// ```
386    pub async fn build_devnet(self) -> IotaRpcResult<IotaClient> {
387        self.build(IOTA_DEVNET_URL).await
388    }
389
390    /// Return an [IotaClient] object that is ready to interact with the IOTA
391    /// testnet.
392    ///
393    /// For connecting to a custom URI, use the `build` function instead.
394    ///
395    /// # Examples
396    ///
397    /// ```rust,no_run
398    /// use iota_sdk::IotaClientBuilder;
399    ///
400    /// #[tokio::main]
401    /// async fn main() -> Result<(), anyhow::Error> {
402    ///     let iota = IotaClientBuilder::default().build_testnet().await?;
403    ///
404    ///     println!("{:?}", iota.api_version());
405    ///     Ok(())
406    /// }
407    /// ```
408    pub async fn build_testnet(self) -> IotaRpcResult<IotaClient> {
409        self.build(IOTA_TESTNET_URL).await
410    }
411
412    /// Returns an [IotaClient] object that is ready to interact with the IOTA
413    /// mainnet.
414    ///
415    /// For connecting to a custom URI, use the `build` function instead.
416    ///
417    /// # Examples
418    ///
419    /// ```rust,no_run
420    /// use iota_sdk::IotaClientBuilder;
421    ///
422    /// #[tokio::main]
423    /// async fn main() -> Result<(), anyhow::Error> {
424    ///     let iota = IotaClientBuilder::default().build_mainnet().await?;
425    ///
426    ///     println!("{:?}", iota.api_version());
427    ///     Ok(())
428    /// }
429    /// ```
430    pub async fn build_mainnet(self) -> IotaRpcResult<IotaClient> {
431        self.build(IOTA_MAINNET_URL).await
432    }
433
434    /// Return the server information as a `ServerInfo` structure.
435    ///
436    /// Fails with an error if it cannot call the RPC discover.
437    async fn get_server_info(
438        http: &HttpClient,
439        ws: &Option<WsClient>,
440    ) -> Result<ServerInfo, Error> {
441        let rpc_spec: Value = http.request("rpc.discover", rpc_params![]).await?;
442        let version = rpc_spec
443            .pointer("/info/version")
444            .and_then(|v| v.as_str())
445            .ok_or_else(|| {
446                Error::Data("Fail parsing server version from rpc.discover endpoint.".into())
447            })?;
448        let rpc_methods = Self::parse_methods(&rpc_spec)?;
449
450        let subscriptions = if let Some(ws) = ws {
451            match ws.request("rpc.discover", rpc_params![]).await {
452                Ok(rpc_spec) => Self::parse_methods(&rpc_spec)?,
453                Err(_) => Vec::new(),
454            }
455        } else {
456            Vec::new()
457        };
458        let iota_system_state_v2_support =
459            rpc_methods.contains(&"iotax_getLatestIotaSystemStateV2".to_string());
460        Ok(ServerInfo {
461            rpc_methods,
462            subscriptions,
463            version: version.to_string(),
464            iota_system_state_v2_support,
465        })
466    }
467
468    fn parse_methods(server_spec: &Value) -> Result<Vec<String>, Error> {
469        let methods = server_spec
470            .pointer("/methods")
471            .and_then(|methods| methods.as_array())
472            .ok_or_else(|| {
473                Error::Data("Fail parsing server information from rpc.discover endpoint.".into())
474            })?;
475
476        Ok(methods
477            .iter()
478            .flat_map(|method| method["name"].as_str())
479            .map(|s| s.into())
480            .collect())
481    }
482}
483
484/// Provides all the necessary abstractions for interacting with the IOTA
485/// network.
486///
487/// # Usage
488///
489/// Use [IotaClientBuilder] to build an [IotaClient].
490///
491/// # Examples
492///
493/// ```rust,no_run
494/// use std::str::FromStr;
495///
496/// use iota_sdk::{IotaClientBuilder, types::base_types::IotaAddress};
497///
498/// #[tokio::main]
499/// async fn main() -> Result<(), anyhow::Error> {
500///     let iota = IotaClientBuilder::default()
501///         .build("http://127.0.0.1:9000")
502///         .await?;
503///
504///     println!("{:?}", iota.available_rpc_methods());
505///     println!("{:?}", iota.available_subscriptions());
506///     println!("{:?}", iota.api_version());
507///
508///     let address = IotaAddress::from_str("0x0000....0000")?;
509///     let owned_objects = iota
510///         .read_api()
511///         .get_owned_objects(address, None, None, None)
512///         .await?;
513///
514///     println!("{:?}", owned_objects);
515///
516///     Ok(())
517/// }
518/// ```
519#[derive(Clone)]
520pub struct IotaClient {
521    api: Arc<RpcClient>,
522    transaction_builder: TransactionBuilder,
523    read_api: Arc<ReadApi>,
524    coin_read_api: CoinReadApi,
525    event_api: EventApi,
526    quorum_driver_api: QuorumDriverApi,
527    governance_api: GovernanceApi,
528}
529
530pub(crate) struct RpcClient {
531    http: HttpClient,
532    ws: Option<WsClient>,
533    info: ServerInfo,
534}
535
536impl Debug for RpcClient {
537    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
538        write!(
539            f,
540            "RPC client. Http: {:?}, Websocket: {:?}",
541            self.http, self.ws
542        )
543    }
544}
545
546/// Contains all the useful information regarding the API version, the available
547/// RPC calls, and subscriptions.
548struct ServerInfo {
549    rpc_methods: Vec<String>,
550    subscriptions: Vec<String>,
551    version: String,
552    iota_system_state_v2_support: bool,
553}
554
555impl IotaClient {
556    /// Return a list of RPC methods supported by the node the client is
557    /// connected to.
558    pub fn available_rpc_methods(&self) -> &Vec<String> {
559        &self.api.info.rpc_methods
560    }
561
562    /// Return a list of streaming/subscription APIs supported by the node the
563    /// client is connected to.
564    pub fn available_subscriptions(&self) -> &Vec<String> {
565        &self.api.info.subscriptions
566    }
567
568    /// Return the API version information as a string.
569    ///
570    /// The format of this string is `<major>.<minor>.<patch>`, e.g., `1.6.0`,
571    /// and it is retrieved from the OpenRPC specification via the discover
572    /// service method.
573    pub fn api_version(&self) -> &str {
574        &self.api.info.version
575    }
576
577    /// Verify if the API version matches the server version and returns an
578    /// error if they do not match.
579    pub fn check_api_version(&self) -> IotaRpcResult<()> {
580        let server_version = self.api_version();
581        let client_version = env!("CARGO_PKG_VERSION");
582        if server_version != client_version {
583            return Err(Error::ServerVersionMismatch {
584                client_version: client_version.to_string(),
585                server_version: server_version.to_string(),
586            });
587        };
588        Ok(())
589    }
590
591    /// Return a reference to the coin read API.
592    pub fn coin_read_api(&self) -> &CoinReadApi {
593        &self.coin_read_api
594    }
595
596    /// Return a reference to the event API.
597    pub fn event_api(&self) -> &EventApi {
598        &self.event_api
599    }
600
601    /// Return a reference to the governance API.
602    pub fn governance_api(&self) -> &GovernanceApi {
603        &self.governance_api
604    }
605
606    /// Return a reference to the quorum driver API.
607    pub fn quorum_driver_api(&self) -> &QuorumDriverApi {
608        &self.quorum_driver_api
609    }
610
611    /// Return a reference to the read API.
612    pub fn read_api(&self) -> &ReadApi {
613        &self.read_api
614    }
615
616    /// Return a reference to the transaction builder API.
617    pub fn transaction_builder(&self) -> &TransactionBuilder {
618        &self.transaction_builder
619    }
620
621    /// Return a reference to the underlying http client.
622    pub fn http(&self) -> &HttpClient {
623        &self.api.http
624    }
625
626    /// Return a reference to the underlying WebSocket client, if any.
627    pub fn ws(&self) -> Option<&WsClient> {
628        self.api.ws.as_ref()
629    }
630}
631
632#[async_trait]
633impl DataReader for ReadApi {
634    async fn get_owned_objects(
635        &self,
636        address: IotaAddress,
637        object_type: StructTag,
638    ) -> Result<Vec<ObjectInfo>, anyhow::Error> {
639        let query = Some(IotaObjectResponseQuery {
640            filter: Some(IotaObjectDataFilter::StructType(object_type)),
641            options: Some(
642                IotaObjectDataOptions::new()
643                    .with_previous_transaction()
644                    .with_type()
645                    .with_owner(),
646            ),
647        });
648
649        let result = PagedFn::stream(async |cursor| {
650            self.get_owned_objects(address, query.clone(), cursor, None)
651                .await
652        })
653        .map(|v| v?.try_into())
654        .try_collect::<Vec<_>>()
655        .await?;
656
657        Ok(result)
658    }
659
660    async fn get_object_with_options(
661        &self,
662        object_id: ObjectID,
663        options: IotaObjectDataOptions,
664    ) -> Result<IotaObjectResponse, anyhow::Error> {
665        Ok(self.get_object_with_options(object_id, options).await?)
666    }
667
668    /// Return the reference gas price as a u64 or an error otherwise
669    async fn get_reference_gas_price(&self) -> Result<u64, anyhow::Error> {
670        Ok(self.get_reference_gas_price().await?)
671    }
672}
673
674/// A helper trait for repeatedly calling an async function which returns pages
675/// of data.
676pub trait PagedFn<O, C, F, E>: Sized + Fn(Option<C>) -> F
677where
678    O: Send,
679    C: Send,
680    F: futures::Future<Output = Result<Page<O, C>, E>> + Send,
681{
682    /// Get all items from the source and collect them into a vector.
683    fn collect<T>(self) -> impl futures::Future<Output = Result<T, E>>
684    where
685        T: Default + Extend<O>,
686    {
687        self.stream().try_collect::<T>()
688    }
689
690    /// Get a stream which will return all items from the source.
691    fn stream(self) -> PagedStream<O, C, F, E, Self> {
692        PagedStream::new(self)
693    }
694}
695
696impl<O, C, F, E, Fun> PagedFn<O, C, F, E> for Fun
697where
698    Fun: Fn(Option<C>) -> F,
699    O: Send,
700    C: Send,
701    F: futures::Future<Output = Result<Page<O, C>, E>> + Send,
702{
703}
704
/// A stream that repeatedly calls a paging function and yields the items of
/// each returned page in order.
///
/// Usually created through [`PagedFn::stream`].
pub struct PagedStream<O, C, F, E, Fun> {
    /// Paging function, called with the cursor of the page to fetch.
    fun: Fun,
    /// Pending request for the next page.
    fut: Pin<Box<F>>,
    /// Items already received but not yet yielded.
    next: VecDeque<O>,
    /// Whether the most recent page reported that more pages follow.
    has_next_page: bool,
    _data: PhantomData<(E, C)>,
}

impl<O, C, F, E, Fun> PagedStream<O, C, F, E, Fun>
where
    Fun: Fn(Option<C>) -> F,
{
    /// Create a stream over all pages returned by `fun`.
    ///
    /// `fun(None)` is called right away to create the request for the first
    /// page. The returned future is only polled once the stream is polled.
    pub fn new(fun: Fun) -> Self {
        let first_page = Box::pin(fun(None));
        Self {
            fun,
            fut: first_page,
            next: VecDeque::new(),
            has_next_page: true,
            _data: PhantomData,
        }
    }
}
730
731impl<O, C, F, E, Fun> futures::Stream for PagedStream<O, C, F, E, Fun>
732where
733    O: Send,
734    C: Send,
735    F: futures::Future<Output = Result<Page<O, C>, E>> + Send,
736    Fun: Fn(Option<C>) -> F,
737{
738    type Item = Result<O, E>;
739
740    fn poll_next(
741        self: std::pin::Pin<&mut Self>,
742        cx: &mut std::task::Context<'_>,
743    ) -> Poll<Option<Self::Item>> {
744        let this = unsafe { self.get_unchecked_mut() };
745        if this.next.is_empty() && this.has_next_page {
746            match this.fut.as_mut().poll(cx) {
747                Poll::Ready(res) => match res {
748                    Ok(mut page) => {
749                        this.next.extend(page.data);
750                        this.has_next_page = page.has_next_page;
751                        if this.has_next_page {
752                            this.fut.set((this.fun)(page.next_cursor.take()));
753                        }
754                    }
755                    Err(e) => {
756                        this.has_next_page = false;
757                        return Poll::Ready(Some(Err(e)));
758                    }
759                },
760                Poll::Pending => return Poll::Pending,
761            }
762        }
763        Poll::Ready(this.next.pop_front().map(Ok))
764    }
765}
766
#[cfg(test)]
mod test {
    use iota_json_rpc_types::Page;

    use super::*;

    /// Streams a 10 000-item source served in pages of 100. Checks that every
    /// item is yielded once and in order and that the stream then ends. Also
    /// checks that an endpoint error surfaces as a stream error.
    #[tokio::test]
    async fn test_get_all_pages() {
        /// In-memory stand-in for a paginated RPC endpoint.
        struct Endpoint {
            data: Vec<i32>,
        }

        impl Endpoint {
            async fn get_page(&self, cursor: Option<usize>) -> anyhow::Result<Page<i32, usize>> {
                const PAGE_SIZE: usize = 100;
                anyhow::ensure!(cursor.is_none_or(|v| v < self.data.len()), "invalid cursor");
                let start = cursor.unwrap_or_default();
                let end = start + PAGE_SIZE;
                let has_next_page = self.data.len() > end;
                Ok(Page {
                    data: self.data[start..].iter().copied().take(PAGE_SIZE).collect(),
                    next_cursor: has_next_page.then_some(end),
                    has_next_page,
                })
            }
        }

        let endpoint = Endpoint {
            data: (0..10000).collect(),
        };

        let mut stream = PagedFn::stream(async |cursor| endpoint.get_page(cursor).await);

        let leading: Vec<i32> = stream
            .by_ref()
            .take(9999)
            .try_collect()
            .await
            .unwrap();
        assert_eq!(leading, endpoint.data[..9999]);
        assert_eq!(stream.by_ref().try_next().await.unwrap(), Some(9999));
        assert!(stream.try_next().await.unwrap().is_none());

        let mut failing = PagedFn::stream(async |_| endpoint.get_page(Some(99999)).await);

        assert!(failing.try_next().await.is_err());
    }
}