iota_sdk/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5//! The IOTA Rust SDK
6//!
//! It aims to provide SDK functionality similar to that of the
//! [TypeScript SDK](https://github.com/iotaledger/iota/tree/main/sdk/typescript/).
//! The IOTA Rust SDK builds on top of the [JSON RPC API](https://docs.iota.org/iota-api-ref),
//! and therefore many of its return types are the ones defined in
//! [iota_types].
12//!
//! The API is split into several parts corresponding to different
//! functionalities, as follows:
15//! * [CoinReadApi] - provides read-only functions to work with the coins
//! * [EventApi] - provides event-related functions
17//! * [GovernanceApi] - provides functionality related to staking
18//! * [QuorumDriverApi] - provides functionality to execute a transaction block
19//!   and submit it to the fullnode(s)
20//! * [ReadApi] - provides functions for retrieving data about different objects
21//!   and transactions
22//! * <a href="../iota_transaction_builder/struct.TransactionBuilder.html"
23//!   title="struct
24//!   iota_transaction_builder::TransactionBuilder">TransactionBuilder</a> -
25//!   provides functions for building transactions
26//!
27//! # Usage
28//! The main way to interact with the API is through the [IotaClientBuilder],
29//! which returns an [IotaClient] object from which the user can access the
30//! various APIs.
31//!
32//! ## Getting Started
33//! Add the Rust SDK to the project by running `cargo add iota-sdk` in the root
34//! folder of your Rust project.
35//!
36//! The main building block for the IOTA Rust SDK is the [IotaClientBuilder],
37//! which provides a simple and straightforward way of connecting to an IOTA
38//! network and having access to the different available APIs.
39//!
//! Below is a simple example which connects to a running IOTA local network,
//! devnet, testnet, and mainnet.
42//! To successfully run this program, make sure to spin up a local
43//! network with a local validator, a fullnode, and a faucet server
44//! (see [the README](https://github.com/iotaledger/iota/tree/develop/crates/iota-sdk/README.md#prerequisites) for more information).
45//!
46//! ```rust,no_run
47//! use iota_sdk::IotaClientBuilder;
48//!
49//! #[tokio::main]
50//! async fn main() -> Result<(), anyhow::Error> {
51//!     let iota = IotaClientBuilder::default()
52//!         .build("http://127.0.0.1:9000") // provide the IOTA network URL
53//!         .await?;
54//!     println!("IOTA local network version: {:?}", iota.api_version());
55//!
56//!     // local IOTA network, same result as above except using the dedicated function
57//!     let iota_local = IotaClientBuilder::default().build_localnet().await?;
58//!     println!("IOTA local network version: {:?}", iota_local.api_version());
59//!
60//!     // IOTA devnet running at `https://api.devnet.iota.cafe`
61//!     let iota_devnet = IotaClientBuilder::default().build_devnet().await?;
62//!     println!("IOTA devnet version: {:?}", iota_devnet.api_version());
63//!
64//!     // IOTA testnet running at `https://api.testnet.iota.cafe`
65//!     let iota_testnet = IotaClientBuilder::default().build_testnet().await?;
66//!     println!("IOTA testnet version: {:?}", iota_testnet.api_version());
67//!
68//!     // IOTA mainnet running at `https://api.mainnet.iota.cafe`
69//!     let iota_mainnet = IotaClientBuilder::default().build_mainnet().await?;
70//!     println!("IOTA mainnet version: {:?}", iota_mainnet.api_version());
71//!
72//!     Ok(())
73//! }
74//! ```
75//!
76//! ## Examples
77//!
//! For detailed examples, please check the API docs and the examples folder
//! in the [repository](https://github.com/iotaledger/iota/tree/main/crates/iota-sdk/examples).
80
81pub mod apis;
82pub mod error;
83pub mod iota_client_config;
84pub mod json_rpc_error;
85pub mod wallet_context;
86
87use std::{
88    collections::{HashMap, VecDeque},
89    fmt::{Debug, Formatter},
90    marker::PhantomData,
91    pin::Pin,
92    str::FromStr,
93    sync::Arc,
94    task::Poll,
95    time::Duration,
96};
97
98use async_trait::async_trait;
99use base64::Engine;
100use futures::TryStreamExt;
101pub use iota_json as json;
102use iota_json_rpc_api::{
103    CLIENT_SDK_TYPE_HEADER, CLIENT_SDK_VERSION_HEADER, CLIENT_TARGET_API_VERSION_HEADER,
104};
105pub use iota_json_rpc_types as rpc_types;
106use iota_json_rpc_types::{
107    IotaObjectDataFilter, IotaObjectDataOptions, IotaObjectResponse, IotaObjectResponseQuery, Page,
108};
109use iota_transaction_builder::{DataReader, TransactionBuilder};
110pub use iota_types as types;
111use iota_types::base_types::{IotaAddress, ObjectID, StructTag};
112use jsonrpsee::{
113    core::client::ClientT,
114    http_client::{HeaderMap, HeaderValue, HttpClient, HttpClientBuilder},
115    rpc_params,
116    ws_client::{PingConfig, WsClient, WsClientBuilder},
117};
118use reqwest::header::HeaderName;
119use rustls::crypto::{CryptoProvider, ring};
120use serde_json::Value;
121
122use crate::{
123    apis::{CoinReadApi, EventApi, GovernanceApi, QuorumDriverApi, ReadApi},
124    error::{Error, IotaRpcResult},
125};
126
/// Fully qualified Move type of the IOTA coin.
pub const IOTA_COIN_TYPE: &str = "0x2::iota::IOTA";

// JSON-RPC endpoints (used by the `build_*` shortcuts of `IotaClientBuilder`).

/// JSON-RPC URL of a local network, on the loopback address.
pub const IOTA_LOCAL_NETWORK_URL: &str = "http://127.0.0.1:9000";
/// JSON-RPC URL of a local network, on the `0.0.0.0` address.
pub const IOTA_LOCAL_NETWORK_URL_0: &str = "http://0.0.0.0:9000";
/// JSON-RPC URL of the IOTA devnet.
pub const IOTA_DEVNET_URL: &str = "https://api.devnet.iota.cafe";
/// JSON-RPC URL of the IOTA testnet.
pub const IOTA_TESTNET_URL: &str = "https://api.testnet.iota.cafe";
/// JSON-RPC URL of the IOTA mainnet.
pub const IOTA_MAINNET_URL: &str = "https://api.mainnet.iota.cafe";

// GraphQL endpoints.

/// GraphQL URL of a local network.
pub const IOTA_LOCAL_NETWORK_GRAPHQL_URL: &str = "http://127.0.0.1:9125";
/// GraphQL URL of the IOTA devnet.
pub const IOTA_DEVNET_GRAPHQL_URL: &str = "https://graphql.devnet.iota.cafe";
/// GraphQL URL of the IOTA testnet.
pub const IOTA_TESTNET_GRAPHQL_URL: &str = "https://graphql.testnet.iota.cafe";
/// GraphQL URL of the IOTA mainnet.
pub const IOTA_MAINNET_GRAPHQL_URL: &str = "https://graphql.mainnet.iota.cafe";

// Gas (faucet) request endpoints. No mainnet counterpart is defined here.

/// Gas request endpoint of a local network.
pub const IOTA_LOCAL_NETWORK_GAS_URL: &str = "http://127.0.0.1:9123/v1/gas";
/// Gas request endpoint of the IOTA devnet faucet.
pub const IOTA_DEVNET_GAS_URL: &str = "https://faucet.devnet.iota.cafe/v1/gas";
/// Gas request endpoint of the IOTA testnet faucet.
pub const IOTA_TESTNET_GAS_URL: &str = "https://faucet.testnet.iota.cafe/v1/gas";
140
141/// Builder for creating an [IotaClient] for connecting to the IOTA network.
142///
143/// By default `maximum concurrent requests` is set to 256 and `request timeout`
144/// is set to 60 seconds. These can be adjusted using
145/// [`Self::max_concurrent_requests()`], and the [`Self::request_timeout()`].
146/// If you use the WebSocket, consider setting `ws_ping_interval` appropriately
147/// to prevent an inactive WS subscription being disconnected due to proxy
148/// timeout.
149///
150/// # Examples
151///
152/// ```rust,no_run
153/// use iota_sdk::IotaClientBuilder;
154///
155/// #[tokio::main]
156/// async fn main() -> Result<(), anyhow::Error> {
157///     let iota = IotaClientBuilder::default()
158///         .build("http://127.0.0.1:9000")
159///         .await?;
160///
161///     println!("IOTA local network version: {:?}", iota.api_version());
162///     Ok(())
163/// }
164/// ```
165pub struct IotaClientBuilder {
166    request_timeout: Duration,
167    max_concurrent_requests: Option<usize>,
168    ws_url: Option<String>,
169    ws_ping_interval: Option<Duration>,
170    basic_auth: Option<(String, String)>,
171    tls_config: Option<rustls::ClientConfig>,
172    headers: Option<HashMap<String, String>>,
173}
174
175impl Default for IotaClientBuilder {
176    fn default() -> Self {
177        Self {
178            request_timeout: Duration::from_secs(60),
179            max_concurrent_requests: None,
180            ws_url: None,
181            ws_ping_interval: None,
182            basic_auth: None,
183            tls_config: None,
184            headers: None,
185        }
186    }
187}
188
189impl IotaClientBuilder {
190    /// Set the request timeout to the specified duration.
191    pub fn request_timeout(mut self, request_timeout: Duration) -> Self {
192        self.request_timeout = request_timeout;
193        self
194    }
195
196    /// Set the max concurrent requests allowed.
197    pub fn max_concurrent_requests(mut self, max_concurrent_requests: usize) -> Self {
198        self.max_concurrent_requests = Some(max_concurrent_requests);
199        self
200    }
201
202    /// Set the WebSocket URL for the IOTA network.
203    pub fn ws_url(mut self, url: impl AsRef<str>) -> Self {
204        self.ws_url = Some(url.as_ref().to_string());
205        self
206    }
207
208    /// Set the WebSocket ping interval.
209    pub fn ws_ping_interval(mut self, duration: Duration) -> Self {
210        self.ws_ping_interval = Some(duration);
211        self
212    }
213
214    /// Set the basic auth credentials for the HTTP client.
215    pub fn basic_auth(mut self, username: impl AsRef<str>, password: impl AsRef<str>) -> Self {
216        self.basic_auth = Some((username.as_ref().to_string(), password.as_ref().to_string()));
217        self
218    }
219
220    /// Set custom headers for the HTTP client
221    pub fn custom_headers(mut self, headers: HashMap<String, String>) -> Self {
222        self.headers = Some(headers);
223        self
224    }
225
226    /// Set a TLS configuration for the HTTP client.
227    pub fn tls_config(mut self, config: rustls::ClientConfig) -> Self {
228        self.tls_config = Some(config);
229        self
230    }
231
232    /// Return an [IotaClient] object connected to the IOTA network accessible
233    /// via the provided URI.
234    ///
235    /// # Examples
236    ///
237    /// ```rust,no_run
238    /// use iota_sdk::IotaClientBuilder;
239    ///
240    /// #[tokio::main]
241    /// async fn main() -> Result<(), anyhow::Error> {
242    ///     let iota = IotaClientBuilder::default()
243    ///         .build("http://127.0.0.1:9000")
244    ///         .await?;
245    ///
246    ///     println!("IOTA local version: {:?}", iota.api_version());
247    ///     Ok(())
248    /// }
249    /// ```
250    pub async fn build(self, http: impl AsRef<str>) -> IotaRpcResult<IotaClient> {
251        if CryptoProvider::get_default().is_none() {
252            ring::default_provider().install_default().ok();
253        }
254
255        let client_version = env!("CARGO_PKG_VERSION");
256        let mut headers = HeaderMap::new();
257        headers.insert(
258            CLIENT_TARGET_API_VERSION_HEADER,
259            // in rust, the client version is the same as the target api version
260            HeaderValue::from_static(client_version),
261        );
262        headers.insert(
263            CLIENT_SDK_VERSION_HEADER,
264            HeaderValue::from_static(client_version),
265        );
266        headers.insert(CLIENT_SDK_TYPE_HEADER, HeaderValue::from_static("rust"));
267
268        if let Some((username, password)) = self.basic_auth {
269            let auth =
270                base64::engine::general_purpose::STANDARD.encode(format!("{username}:{password}"));
271            headers.insert(
272                "authorization",
273                // reqwest::header::AUTHORIZATION,
274                HeaderValue::from_str(&format!("Basic {auth}")).unwrap(),
275            );
276        }
277
278        if let Some(custom_headers) = self.headers {
279            for (key, value) in custom_headers {
280                let header_name =
281                    HeaderName::from_str(&key).map_err(|e| Error::CustomHeaders(e.to_string()))?;
282                let header_value = HeaderValue::from_str(&value)
283                    .map_err(|e| Error::CustomHeaders(e.to_string()))?;
284                headers.insert(header_name, header_value);
285            }
286        }
287
288        let ws = if let Some(url) = self.ws_url {
289            let mut builder = WsClientBuilder::default()
290                .max_request_size(2 << 30)
291                .set_headers(headers.clone())
292                .request_timeout(self.request_timeout);
293
294            if let Some(duration) = self.ws_ping_interval {
295                builder = builder.enable_ws_ping(PingConfig::new().ping_interval(duration))
296            }
297
298            if let Some(max_concurrent_requests) = self.max_concurrent_requests {
299                builder = builder.max_concurrent_requests(max_concurrent_requests);
300            }
301
302            builder.build(url).await.ok()
303        } else {
304            None
305        };
306
307        let mut http_builder = HttpClientBuilder::default()
308            .max_request_size(2 << 30)
309            .set_headers(headers)
310            .request_timeout(self.request_timeout);
311
312        if let Some(max_concurrent_requests) = self.max_concurrent_requests {
313            http_builder = http_builder.max_concurrent_requests(max_concurrent_requests);
314        }
315
316        if let Some(tls_config) = self.tls_config {
317            http_builder = http_builder.with_custom_cert_store(tls_config);
318        }
319
320        let http = http_builder.build(http)?;
321
322        let info = Self::get_server_info(&http, &ws).await?;
323
324        let rpc = RpcClient { http, ws, info };
325        let api = Arc::new(rpc);
326        let read_api = Arc::new(ReadApi::new(api.clone()));
327        let quorum_driver_api = QuorumDriverApi::new(api.clone());
328        let event_api = EventApi::new(api.clone());
329        let transaction_builder = TransactionBuilder::new(read_api.clone());
330        let coin_read_api = CoinReadApi::new(api.clone());
331        let governance_api = GovernanceApi::new(api.clone());
332
333        Ok(IotaClient {
334            api,
335            transaction_builder,
336            read_api,
337            coin_read_api,
338            event_api,
339            quorum_driver_api,
340            governance_api,
341        })
342    }
343
344    /// Return an [IotaClient] object that is ready to interact with the local
345    /// development network (by default it expects the IOTA network to be up
346    /// and running at `127.0.0.1:9000`).
347    ///
348    /// For connecting to a custom URI, use the `build` function instead.
349    ///
350    /// # Examples
351    ///
352    /// ```rust,no_run
353    /// use iota_sdk::IotaClientBuilder;
354    ///
355    /// #[tokio::main]
356    /// async fn main() -> Result<(), anyhow::Error> {
357    ///     let iota = IotaClientBuilder::default().build_localnet().await?;
358    ///
359    ///     println!("IOTA local version: {:?}", iota.api_version());
360    ///     Ok(())
361    /// }
362    /// ```
363    pub async fn build_localnet(self) -> IotaRpcResult<IotaClient> {
364        self.build(IOTA_LOCAL_NETWORK_URL).await
365    }
366
367    /// Return an [IotaClient] object that is ready to interact with the IOTA
368    /// devnet.
369    ///
370    /// For connecting to a custom URI, use the `build` function instead.
371    ///
372    /// # Examples
373    ///
374    /// ```rust,no_run
375    /// use iota_sdk::IotaClientBuilder;
376    ///
377    /// #[tokio::main]
378    /// async fn main() -> Result<(), anyhow::Error> {
379    ///     let iota = IotaClientBuilder::default().build_devnet().await?;
380    ///
381    ///     println!("{:?}", iota.api_version());
382    ///     Ok(())
383    /// }
384    /// ```
385    pub async fn build_devnet(self) -> IotaRpcResult<IotaClient> {
386        self.build(IOTA_DEVNET_URL).await
387    }
388
389    /// Return an [IotaClient] object that is ready to interact with the IOTA
390    /// testnet.
391    ///
392    /// For connecting to a custom URI, use the `build` function instead.
393    ///
394    /// # Examples
395    ///
396    /// ```rust,no_run
397    /// use iota_sdk::IotaClientBuilder;
398    ///
399    /// #[tokio::main]
400    /// async fn main() -> Result<(), anyhow::Error> {
401    ///     let iota = IotaClientBuilder::default().build_testnet().await?;
402    ///
403    ///     println!("{:?}", iota.api_version());
404    ///     Ok(())
405    /// }
406    /// ```
407    pub async fn build_testnet(self) -> IotaRpcResult<IotaClient> {
408        self.build(IOTA_TESTNET_URL).await
409    }
410
411    /// Returns an [IotaClient] object that is ready to interact with the IOTA
412    /// mainnet.
413    ///
414    /// For connecting to a custom URI, use the `build` function instead.
415    ///
416    /// # Examples
417    ///
418    /// ```rust,no_run
419    /// use iota_sdk::IotaClientBuilder;
420    ///
421    /// #[tokio::main]
422    /// async fn main() -> Result<(), anyhow::Error> {
423    ///     let iota = IotaClientBuilder::default().build_mainnet().await?;
424    ///
425    ///     println!("{:?}", iota.api_version());
426    ///     Ok(())
427    /// }
428    /// ```
429    pub async fn build_mainnet(self) -> IotaRpcResult<IotaClient> {
430        self.build(IOTA_MAINNET_URL).await
431    }
432
433    /// Return the server information as a `ServerInfo` structure.
434    ///
435    /// Fails with an error if it cannot call the RPC discover.
436    async fn get_server_info(
437        http: &HttpClient,
438        ws: &Option<WsClient>,
439    ) -> Result<ServerInfo, Error> {
440        let rpc_spec: Value = http.request("rpc.discover", rpc_params![]).await?;
441        let version = rpc_spec
442            .pointer("/info/version")
443            .and_then(|v| v.as_str())
444            .ok_or_else(|| {
445                Error::Data("Fail parsing server version from rpc.discover endpoint.".into())
446            })?;
447        let rpc_methods = Self::parse_methods(&rpc_spec)?;
448
449        let subscriptions = if let Some(ws) = ws {
450            match ws.request("rpc.discover", rpc_params![]).await {
451                Ok(rpc_spec) => Self::parse_methods(&rpc_spec)?,
452                Err(_) => Vec::new(),
453            }
454        } else {
455            Vec::new()
456        };
457        let iota_system_state_v2_support =
458            rpc_methods.contains(&"iotax_getLatestIotaSystemStateV2".to_string());
459        Ok(ServerInfo {
460            rpc_methods,
461            subscriptions,
462            version: version.to_string(),
463            iota_system_state_v2_support,
464        })
465    }
466
467    fn parse_methods(server_spec: &Value) -> Result<Vec<String>, Error> {
468        let methods = server_spec
469            .pointer("/methods")
470            .and_then(|methods| methods.as_array())
471            .ok_or_else(|| {
472                Error::Data("Fail parsing server information from rpc.discover endpoint.".into())
473            })?;
474
475        Ok(methods
476            .iter()
477            .flat_map(|method| method["name"].as_str())
478            .map(|s| s.into())
479            .collect())
480    }
481}
482
483/// Provides all the necessary abstractions for interacting with the IOTA
484/// network.
485///
486/// # Usage
487///
488/// Use [IotaClientBuilder] to build an [IotaClient].
489///
490/// # Examples
491///
492/// ```rust,no_run
493/// use std::str::FromStr;
494///
495/// use iota_sdk::{IotaClientBuilder, types::base_types::IotaAddress};
496///
497/// #[tokio::main]
498/// async fn main() -> Result<(), anyhow::Error> {
499///     let iota = IotaClientBuilder::default()
500///         .build("http://127.0.0.1:9000")
501///         .await?;
502///
503///     println!("{:?}", iota.available_rpc_methods());
504///     println!("{:?}", iota.available_subscriptions());
505///     println!("{:?}", iota.api_version());
506///
507///     let address = IotaAddress::from_str("0x0000....0000")?;
508///     let owned_objects = iota
509///         .read_api()
510///         .get_owned_objects(address, None, None, None)
511///         .await?;
512///
513///     println!("{:?}", owned_objects);
514///
515///     Ok(())
516/// }
517/// ```
518#[derive(Clone)]
519pub struct IotaClient {
520    api: Arc<RpcClient>,
521    transaction_builder: TransactionBuilder,
522    read_api: Arc<ReadApi>,
523    coin_read_api: CoinReadApi,
524    event_api: EventApi,
525    quorum_driver_api: QuorumDriverApi,
526    governance_api: GovernanceApi,
527}
528
529pub(crate) struct RpcClient {
530    http: HttpClient,
531    ws: Option<WsClient>,
532    info: ServerInfo,
533}
534
535impl Debug for RpcClient {
536    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
537        write!(
538            f,
539            "RPC client. Http: {:?}, Websocket: {:?}",
540            self.http, self.ws
541        )
542    }
543}
544
/// Information about the connected node, gathered from its `rpc.discover`
/// responses: the API version, the available RPC methods, and the available
/// subscriptions.
struct ServerInfo {
    /// Version string reported under `info.version` by the HTTP discovery call.
    version: String,
    /// Method names advertised by the HTTP endpoint.
    rpc_methods: Vec<String>,
    /// Method names advertised by the WebSocket endpoint (empty without one).
    subscriptions: Vec<String>,
    /// Whether `iotax_getLatestIotaSystemStateV2` is among `rpc_methods`.
    iota_system_state_v2_support: bool,
}
553
554impl IotaClient {
555    /// Return a list of RPC methods supported by the node the client is
556    /// connected to.
557    pub fn available_rpc_methods(&self) -> &Vec<String> {
558        &self.api.info.rpc_methods
559    }
560
561    /// Return a list of streaming/subscription APIs supported by the node the
562    /// client is connected to.
563    pub fn available_subscriptions(&self) -> &Vec<String> {
564        &self.api.info.subscriptions
565    }
566
567    /// Return the API version information as a string.
568    ///
569    /// The format of this string is `<major>.<minor>.<patch>`, e.g., `1.6.0`,
570    /// and it is retrieved from the OpenRPC specification via the discover
571    /// service method.
572    pub fn api_version(&self) -> &str {
573        &self.api.info.version
574    }
575
576    /// Verify if the API version matches the server version and returns an
577    /// error if they do not match.
578    pub fn check_api_version(&self) -> IotaRpcResult<()> {
579        let server_version = self.api_version();
580        let client_version = env!("CARGO_PKG_VERSION");
581        if server_version != client_version {
582            return Err(Error::ServerVersionMismatch {
583                client_version: client_version.to_string(),
584                server_version: server_version.to_string(),
585            });
586        };
587        Ok(())
588    }
589
590    /// Return a reference to the coin read API.
591    pub fn coin_read_api(&self) -> &CoinReadApi {
592        &self.coin_read_api
593    }
594
595    /// Return a reference to the event API.
596    pub fn event_api(&self) -> &EventApi {
597        &self.event_api
598    }
599
600    /// Return a reference to the governance API.
601    pub fn governance_api(&self) -> &GovernanceApi {
602        &self.governance_api
603    }
604
605    /// Return a reference to the quorum driver API.
606    pub fn quorum_driver_api(&self) -> &QuorumDriverApi {
607        &self.quorum_driver_api
608    }
609
610    /// Return a reference to the read API.
611    pub fn read_api(&self) -> &ReadApi {
612        &self.read_api
613    }
614
615    /// Return a reference to the transaction builder API.
616    pub fn transaction_builder(&self) -> &TransactionBuilder {
617        &self.transaction_builder
618    }
619
620    /// Return a reference to the underlying http client.
621    pub fn http(&self) -> &HttpClient {
622        &self.api.http
623    }
624
625    /// Return a reference to the underlying WebSocket client, if any.
626    pub fn ws(&self) -> Option<&WsClient> {
627        self.api.ws.as_ref()
628    }
629}
630
631#[async_trait]
632impl DataReader for ReadApi {
633    async fn get_owned_objects(
634        &self,
635        address: IotaAddress,
636        object_type: StructTag,
637        cursor: Option<ObjectID>,
638        limit: Option<usize>,
639        options: IotaObjectDataOptions,
640    ) -> Result<iota_json_rpc_types::ObjectsPage, anyhow::Error> {
641        let query = Some(IotaObjectResponseQuery {
642            filter: Some(IotaObjectDataFilter::StructType(object_type)),
643            options: Some(options),
644        });
645
646        Ok(self
647            .get_owned_objects(address, query, cursor, limit)
648            .await?)
649    }
650
651    async fn get_object_with_options(
652        &self,
653        object_id: ObjectID,
654        options: IotaObjectDataOptions,
655    ) -> Result<IotaObjectResponse, anyhow::Error> {
656        Ok(self.get_object_with_options(object_id, options).await?)
657    }
658
659    /// Return the reference gas price as a u64 or an error otherwise
660    async fn get_reference_gas_price(&self) -> Result<u64, anyhow::Error> {
661        Ok(self.get_reference_gas_price().await?)
662    }
663}
664
665/// A helper trait for repeatedly calling an async function which returns pages
666/// of data.
667pub trait PagedFn<O, C, F, E>: Sized + Fn(Option<C>) -> F
668where
669    O: Send,
670    C: Send,
671    F: futures::Future<Output = Result<Page<O, C>, E>> + Send,
672{
673    /// Get all items from the source and collect them into a vector.
674    fn collect<T>(self) -> impl futures::Future<Output = Result<T, E>>
675    where
676        T: Default + Extend<O>,
677    {
678        self.stream().try_collect::<T>()
679    }
680
681    /// Get a stream which will return all items from the source.
682    fn stream(self) -> PagedStream<O, C, F, E, Self> {
683        PagedStream::new(self)
684    }
685}
686
687impl<O, C, F, E, Fun> PagedFn<O, C, F, E> for Fun
688where
689    Fun: Fn(Option<C>) -> F,
690    O: Send,
691    C: Send,
692    F: futures::Future<Output = Result<Page<O, C>, E>> + Send,
693{
694}
695
/// A stream over every item of a paginated source, produced by repeatedly
/// calling an async function that returns one page of data at a time.
pub struct PagedStream<O, C, F, E, Fun> {
    /// Items of the last fetched page that have not been yielded yet.
    next: VecDeque<O>,
    /// Request for the next page (created eagerly, polled on demand).
    fut: Pin<Box<F>>,
    /// Whether another page should be requested once `next` is drained.
    has_next_page: bool,
    /// Function producing the request for the page at a given cursor.
    fun: Fun,
    /// Marks the error and cursor types, which are not stored directly.
    _data: PhantomData<(E, C)>,
}
705
706impl<O, C, F, E, Fun> PagedStream<O, C, F, E, Fun>
707where
708    Fun: Fn(Option<C>) -> F,
709{
710    pub fn new(fun: Fun) -> Self {
711        let fut = fun(None);
712        Self {
713            fun,
714            fut: Box::pin(fut),
715            next: Default::default(),
716            has_next_page: true,
717            _data: PhantomData,
718        }
719    }
720}
721
722impl<O, C, F, E, Fun> futures::Stream for PagedStream<O, C, F, E, Fun>
723where
724    O: Send,
725    C: Send,
726    F: futures::Future<Output = Result<Page<O, C>, E>> + Send,
727    Fun: Fn(Option<C>) -> F,
728{
729    type Item = Result<O, E>;
730
731    fn poll_next(
732        self: std::pin::Pin<&mut Self>,
733        cx: &mut std::task::Context<'_>,
734    ) -> Poll<Option<Self::Item>> {
735        let this = unsafe { self.get_unchecked_mut() };
736        if this.next.is_empty() && this.has_next_page {
737            match this.fut.as_mut().poll(cx) {
738                Poll::Ready(res) => match res {
739                    Ok(mut page) => {
740                        this.next.extend(page.data);
741                        this.has_next_page = page.has_next_page;
742                        if this.has_next_page {
743                            this.fut.set((this.fun)(page.next_cursor.take()));
744                        }
745                    }
746                    Err(e) => {
747                        this.has_next_page = false;
748                        return Poll::Ready(Some(Err(e)));
749                    }
750                },
751                Poll::Pending => return Poll::Pending,
752            }
753        }
754        Poll::Ready(this.next.pop_front().map(Ok))
755    }
756}
757
#[cfg(test)]
mod test {
    use futures::StreamExt;
    use iota_json_rpc_types::Page;

    use super::*;

    /// In-memory paginated data source used to exercise [`PagedFn`].
    struct Endpoint {
        data: Vec<i32>,
    }

    impl Endpoint {
        /// Number of items served per page.
        const PAGE_SIZE: usize = 100;

        /// Return the page starting at `cursor` (the beginning when `None`);
        /// cursors at or past the end of the data are rejected.
        async fn get_page(&self, cursor: Option<usize>) -> anyhow::Result<Page<i32, usize>> {
            anyhow::ensure!(cursor.is_none_or(|v| v < self.data.len()), "invalid cursor");
            let start = cursor.unwrap_or_default();
            let end = self.data.len().min(start + Self::PAGE_SIZE);
            let has_next_page = self.data.len() > start + Self::PAGE_SIZE;
            Ok(Page {
                data: self.data[start..end].to_vec(),
                next_cursor: has_next_page.then_some(start + Self::PAGE_SIZE),
                has_next_page,
            })
        }
    }

    #[tokio::test]
    async fn test_get_all_pages() {
        let endpoint = Endpoint {
            data: (0..10000).collect(),
        };

        let mut stream = PagedFn::stream(async |cursor| endpoint.get_page(cursor).await);

        // Everything but the final item, pulled through a borrowed stream.
        let head = stream
            .by_ref()
            .take(9999)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();
        assert_eq!(head, endpoint.data[..9999]);

        // The final item, then the end of the stream.
        assert_eq!(stream.by_ref().try_next().await.unwrap(), Some(9999));
        assert!(stream.try_next().await.unwrap().is_none());

        // A request that always fails surfaces its error.
        let mut failing = PagedFn::stream(async |_| endpoint.get_page(Some(99999)).await);
        assert!(failing.try_next().await.is_err());
    }
}