iota_sdk/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5//! The IOTA Rust SDK
6//!
//! It aims to provide SDK functionality similar to the existing SDK for
//! [TypeScript](https://github.com/iotaledger/iota/tree/main/sdk/typescript/).
9//! IOTA Rust SDK builds on top of the [JSON RPC API](https://docs.iota.org/iota-api-ref)
10//! and therefore many of the return types are the ones specified in
11//! [iota_types].
12//!
//! The API is split into several parts corresponding to different
//! functionalities, as follows:
15//! * [CoinReadApi] - provides read-only functions to work with the coins
//! * [EventApi] - provides event-related functions
17//! * [GovernanceApi] - provides functionality related to staking
18//! * [QuorumDriverApi] - provides functionality to execute a transaction block
19//!   and submit it to the fullnode(s)
20//! * [ReadApi] - provides functions for retrieving data about different objects
21//!   and transactions
22//! * <a href="../iota_transaction_builder/struct.TransactionBuilder.html"
23//!   title="struct
24//!   iota_transaction_builder::TransactionBuilder">TransactionBuilder</a> -
25//!   provides functions for building transactions
26//!
27//! # Usage
28//! The main way to interact with the API is through the [IotaClientBuilder],
29//! which returns an [IotaClient] object from which the user can access the
30//! various APIs.
31//!
32//! ## Getting Started
33//! Add the Rust SDK to the project by running `cargo add iota-sdk` in the root
34//! folder of your Rust project.
35//!
36//! The main building block for the IOTA Rust SDK is the [IotaClientBuilder],
37//! which provides a simple and straightforward way of connecting to an IOTA
38//! network and having access to the different available APIs.
39//!
//! Below is a simple example which connects to a running IOTA local network,
//! devnet, testnet, and mainnet.
42//! To successfully run this program, make sure to spin up a local
43//! network with a local validator, a fullnode, and a faucet server
44//! (see [the README](https://github.com/iotaledger/iota/tree/develop/crates/iota-sdk/README.md#prerequisites) for more information).
45//!
46//! ```rust,no_run
47//! use iota_sdk::IotaClientBuilder;
48//!
49//! #[tokio::main]
50//! async fn main() -> Result<(), anyhow::Error> {
51//!     let iota = IotaClientBuilder::default()
52//!         .build("http://127.0.0.1:9000") // provide the IOTA network URL
53//!         .await?;
54//!     println!("IOTA local network version: {:?}", iota.api_version());
55//!
56//!     // local IOTA network, same result as above except using the dedicated function
57//!     let iota_local = IotaClientBuilder::default().build_localnet().await?;
58//!     println!("IOTA local network version: {:?}", iota_local.api_version());
59//!
60//!     // IOTA devnet running at `https://api.devnet.iota.cafe`
61//!     let iota_devnet = IotaClientBuilder::default().build_devnet().await?;
62//!     println!("IOTA devnet version: {:?}", iota_devnet.api_version());
63//!
64//!     // IOTA testnet running at `https://api.testnet.iota.cafe`
65//!     let iota_testnet = IotaClientBuilder::default().build_testnet().await?;
66//!     println!("IOTA testnet version: {:?}", iota_testnet.api_version());
67//!
68//!     // IOTA mainnet running at `https://api.mainnet.iota.cafe`
69//!     let iota_mainnet = IotaClientBuilder::default().build_mainnet().await?;
70//!     println!("IOTA mainnet version: {:?}", iota_mainnet.api_version());
71//!
72//!     Ok(())
73//! }
74//! ```
75//!
76//! ## Examples
77//!
//! For detailed examples, please check the API docs and the examples folder
//! in the [repository](https://github.com/iotaledger/iota/tree/main/crates/iota-sdk/examples).
80
81pub mod apis;
82pub mod error;
83pub mod iota_client_config;
84pub mod json_rpc_error;
85pub mod wallet_context;
86
87use std::{
88    collections::VecDeque,
89    fmt::{Debug, Formatter},
90    marker::PhantomData,
91    pin::Pin,
92    sync::Arc,
93    task::Poll,
94    time::Duration,
95};
96
97use async_trait::async_trait;
98use base64::Engine;
99use futures::{StreamExt, TryStreamExt};
100pub use iota_json as json;
101use iota_json_rpc_api::{
102    CLIENT_SDK_TYPE_HEADER, CLIENT_SDK_VERSION_HEADER, CLIENT_TARGET_API_VERSION_HEADER,
103};
104pub use iota_json_rpc_types as rpc_types;
105use iota_json_rpc_types::{
106    IotaObjectDataFilter, IotaObjectDataOptions, IotaObjectResponse, IotaObjectResponseQuery, Page,
107};
108use iota_transaction_builder::{DataReader, TransactionBuilder};
109pub use iota_types as types;
110use iota_types::base_types::{IotaAddress, ObjectID, ObjectInfo};
111use jsonrpsee::{
112    core::client::ClientT,
113    http_client::{HeaderMap, HeaderValue, HttpClient, HttpClientBuilder},
114    rpc_params,
115    ws_client::{PingConfig, WsClient, WsClientBuilder},
116};
117use move_core_types::language_storage::StructTag;
118use rustls::crypto::{CryptoProvider, ring};
119use serde_json::Value;
120
121use crate::{
122    apis::{CoinReadApi, EventApi, GovernanceApi, QuorumDriverApi, ReadApi},
123    error::{Error, IotaRpcResult},
124};
125
/// Fully-qualified Move type of the native IOTA coin.
pub const IOTA_COIN_TYPE: &str = "0x2::iota::IOTA";
/// JSON-RPC URL of a local IOTA network; used by
/// [`IotaClientBuilder::build_localnet`].
pub const IOTA_LOCAL_NETWORK_URL: &str = "http://127.0.0.1:9000";
/// Same port as [`IOTA_LOCAL_NETWORK_URL`], addressed via `0.0.0.0` instead of
/// the loopback address.
pub const IOTA_LOCAL_NETWORK_URL_0: &str = "http://0.0.0.0:9000";
/// GraphQL URL of a local IOTA network.
pub const IOTA_LOCAL_NETWORK_GRAPHQL_URL: &str = "http://127.0.0.1:8000";
/// Gas (faucet) endpoint of a local IOTA network.
pub const IOTA_LOCAL_NETWORK_GAS_URL: &str = "http://127.0.0.1:9123/v1/gas";
/// JSON-RPC URL of the IOTA devnet; used by [`IotaClientBuilder::build_devnet`].
pub const IOTA_DEVNET_URL: &str = "https://api.devnet.iota.cafe";
/// GraphQL URL of the IOTA devnet.
pub const IOTA_DEVNET_GRAPHQL_URL: &str = "https://graphql.devnet.iota.cafe";
/// Faucet gas endpoint of the IOTA devnet.
pub const IOTA_DEVNET_GAS_URL: &str = "https://faucet.devnet.iota.cafe/v1/gas";
/// JSON-RPC URL of the IOTA testnet; used by [`IotaClientBuilder::build_testnet`].
pub const IOTA_TESTNET_URL: &str = "https://api.testnet.iota.cafe";
/// GraphQL URL of the IOTA testnet.
pub const IOTA_TESTNET_GRAPHQL_URL: &str = "https://graphql.testnet.iota.cafe";
/// Faucet gas endpoint of the IOTA testnet.
pub const IOTA_TESTNET_GAS_URL: &str = "https://faucet.testnet.iota.cafe/v1/gas";
/// JSON-RPC URL of the IOTA mainnet; used by [`IotaClientBuilder::build_mainnet`].
pub const IOTA_MAINNET_URL: &str = "https://api.mainnet.iota.cafe";
138
139/// Builder for creating an [IotaClient] for connecting to the IOTA network.
140///
141/// By default `maximum concurrent requests` is set to 256 and `request timeout`
142/// is set to 60 seconds. These can be adjusted using
143/// [`Self::max_concurrent_requests()`], and the [`Self::request_timeout()`].
144/// If you use the WebSocket, consider setting `ws_ping_interval` appropriately
145/// to prevent an inactive WS subscription being disconnected due to proxy
146/// timeout.
147///
148/// # Examples
149///
150/// ```rust,no_run
151/// use iota_sdk::IotaClientBuilder;
152///
153/// #[tokio::main]
154/// async fn main() -> Result<(), anyhow::Error> {
155///     let iota = IotaClientBuilder::default()
156///         .build("http://127.0.0.1:9000")
157///         .await?;
158///
159///     println!("IOTA local network version: {:?}", iota.api_version());
160///     Ok(())
161/// }
162/// ```
163pub struct IotaClientBuilder {
164    request_timeout: Duration,
165    max_concurrent_requests: Option<usize>,
166    ws_url: Option<String>,
167    ws_ping_interval: Option<Duration>,
168    basic_auth: Option<(String, String)>,
169    tls_config: Option<rustls::ClientConfig>,
170}
171
172impl Default for IotaClientBuilder {
173    fn default() -> Self {
174        Self {
175            request_timeout: Duration::from_secs(60),
176            max_concurrent_requests: None,
177            ws_url: None,
178            ws_ping_interval: None,
179            basic_auth: None,
180            tls_config: None,
181        }
182    }
183}
184
185impl IotaClientBuilder {
186    /// Set the request timeout to the specified duration.
187    pub fn request_timeout(mut self, request_timeout: Duration) -> Self {
188        self.request_timeout = request_timeout;
189        self
190    }
191
192    /// Set the max concurrent requests allowed.
193    pub fn max_concurrent_requests(mut self, max_concurrent_requests: usize) -> Self {
194        self.max_concurrent_requests = Some(max_concurrent_requests);
195        self
196    }
197
198    /// Set the WebSocket URL for the IOTA network.
199    pub fn ws_url(mut self, url: impl AsRef<str>) -> Self {
200        self.ws_url = Some(url.as_ref().to_string());
201        self
202    }
203
204    /// Set the WebSocket ping interval.
205    pub fn ws_ping_interval(mut self, duration: Duration) -> Self {
206        self.ws_ping_interval = Some(duration);
207        self
208    }
209
210    /// Set the basic auth credentials for the HTTP client.
211    pub fn basic_auth(mut self, username: impl AsRef<str>, password: impl AsRef<str>) -> Self {
212        self.basic_auth = Some((username.as_ref().to_string(), password.as_ref().to_string()));
213        self
214    }
215
216    /// Set a TLS configuration for the HTTP client.
217    pub fn tls_config(mut self, config: rustls::ClientConfig) -> Self {
218        self.tls_config = Some(config);
219        self
220    }
221
222    /// Return an [IotaClient] object connected to the IOTA network accessible
223    /// via the provided URI.
224    ///
225    /// # Examples
226    ///
227    /// ```rust,no_run
228    /// use iota_sdk::IotaClientBuilder;
229    ///
230    /// #[tokio::main]
231    /// async fn main() -> Result<(), anyhow::Error> {
232    ///     let iota = IotaClientBuilder::default()
233    ///         .build("http://127.0.0.1:9000")
234    ///         .await?;
235    ///
236    ///     println!("IOTA local version: {:?}", iota.api_version());
237    ///     Ok(())
238    /// }
239    /// ```
240    pub async fn build(self, http: impl AsRef<str>) -> IotaRpcResult<IotaClient> {
241        if CryptoProvider::get_default().is_none() {
242            ring::default_provider().install_default().ok();
243        }
244
245        let client_version = env!("CARGO_PKG_VERSION");
246        let mut headers = HeaderMap::new();
247        headers.insert(
248            CLIENT_TARGET_API_VERSION_HEADER,
249            // in rust, the client version is the same as the target api version
250            HeaderValue::from_static(client_version),
251        );
252        headers.insert(
253            CLIENT_SDK_VERSION_HEADER,
254            HeaderValue::from_static(client_version),
255        );
256        headers.insert(CLIENT_SDK_TYPE_HEADER, HeaderValue::from_static("rust"));
257
258        if let Some((username, password)) = self.basic_auth {
259            let auth = base64::engine::general_purpose::STANDARD
260                .encode(format!("{}:{}", username, password));
261            headers.insert(
262                "authorization",
263                // reqwest::header::AUTHORIZATION,
264                HeaderValue::from_str(&format!("Basic {}", auth)).unwrap(),
265            );
266        }
267
268        let ws = if let Some(url) = self.ws_url {
269            let mut builder = WsClientBuilder::default()
270                .max_request_size(2 << 30)
271                .set_headers(headers.clone())
272                .request_timeout(self.request_timeout);
273
274            if let Some(duration) = self.ws_ping_interval {
275                builder = builder.enable_ws_ping(PingConfig::new().ping_interval(duration))
276            }
277
278            if let Some(max_concurrent_requests) = self.max_concurrent_requests {
279                builder = builder.max_concurrent_requests(max_concurrent_requests);
280            }
281
282            builder.build(url).await.ok()
283        } else {
284            None
285        };
286
287        let mut http_builder = HttpClientBuilder::default()
288            .max_request_size(2 << 30)
289            .set_headers(headers)
290            .request_timeout(self.request_timeout);
291
292        if let Some(max_concurrent_requests) = self.max_concurrent_requests {
293            http_builder = http_builder.max_concurrent_requests(max_concurrent_requests);
294        }
295
296        if let Some(tls_config) = self.tls_config {
297            http_builder = http_builder.with_custom_cert_store(tls_config);
298        }
299
300        let http = http_builder.build(http)?;
301
302        let info = Self::get_server_info(&http, &ws).await?;
303
304        let rpc = RpcClient { http, ws, info };
305        let api = Arc::new(rpc);
306        let read_api = Arc::new(ReadApi::new(api.clone()));
307        let quorum_driver_api = QuorumDriverApi::new(api.clone());
308        let event_api = EventApi::new(api.clone());
309        let transaction_builder = TransactionBuilder::new(read_api.clone());
310        let coin_read_api = CoinReadApi::new(api.clone());
311        let governance_api = GovernanceApi::new(api.clone());
312
313        Ok(IotaClient {
314            api,
315            transaction_builder,
316            read_api,
317            coin_read_api,
318            event_api,
319            quorum_driver_api,
320            governance_api,
321        })
322    }
323
324    /// Return an [IotaClient] object that is ready to interact with the local
325    /// development network (by default it expects the IOTA network to be up
326    /// and running at `127.0.0.1:9000`).
327    ///
328    /// For connecting to a custom URI, use the `build` function instead.
329    ///
330    /// # Examples
331    ///
332    /// ```rust,no_run
333    /// use iota_sdk::IotaClientBuilder;
334    ///
335    /// #[tokio::main]
336    /// async fn main() -> Result<(), anyhow::Error> {
337    ///     let iota = IotaClientBuilder::default().build_localnet().await?;
338    ///
339    ///     println!("IOTA local version: {:?}", iota.api_version());
340    ///     Ok(())
341    /// }
342    /// ```
343    pub async fn build_localnet(self) -> IotaRpcResult<IotaClient> {
344        self.build(IOTA_LOCAL_NETWORK_URL).await
345    }
346
347    /// Return an [IotaClient] object that is ready to interact with the IOTA
348    /// devnet.
349    ///
350    /// For connecting to a custom URI, use the `build` function instead.
351    ///
352    /// # Examples
353    ///
354    /// ```rust,no_run
355    /// use iota_sdk::IotaClientBuilder;
356    ///
357    /// #[tokio::main]
358    /// async fn main() -> Result<(), anyhow::Error> {
359    ///     let iota = IotaClientBuilder::default().build_devnet().await?;
360    ///
361    ///     println!("{:?}", iota.api_version());
362    ///     Ok(())
363    /// }
364    /// ```
365    pub async fn build_devnet(self) -> IotaRpcResult<IotaClient> {
366        self.build(IOTA_DEVNET_URL).await
367    }
368
369    /// Return an [IotaClient] object that is ready to interact with the IOTA
370    /// testnet.
371    ///
372    /// For connecting to a custom URI, use the `build` function instead.
373    ///
374    /// # Examples
375    ///
376    /// ```rust,no_run
377    /// use iota_sdk::IotaClientBuilder;
378    ///
379    /// #[tokio::main]
380    /// async fn main() -> Result<(), anyhow::Error> {
381    ///     let iota = IotaClientBuilder::default().build_testnet().await?;
382    ///
383    ///     println!("{:?}", iota.api_version());
384    ///     Ok(())
385    /// }
386    /// ```
387    pub async fn build_testnet(self) -> IotaRpcResult<IotaClient> {
388        self.build(IOTA_TESTNET_URL).await
389    }
390
391    /// Returns an [IotaClient] object that is ready to interact with the IOTA
392    /// mainnet.
393    ///
394    /// For connecting to a custom URI, use the `build` function instead.
395    ///
396    /// # Examples
397    ///
398    /// ```rust,no_run
399    /// use iota_sdk::IotaClientBuilder;
400    ///
401    /// #[tokio::main]
402    /// async fn main() -> Result<(), anyhow::Error> {
403    ///     let iota = IotaClientBuilder::default().build_mainnet().await?;
404    ///
405    ///     println!("{:?}", iota.api_version());
406    ///     Ok(())
407    /// }
408    /// ```
409    pub async fn build_mainnet(self) -> IotaRpcResult<IotaClient> {
410        self.build(IOTA_MAINNET_URL).await
411    }
412
413    /// Return the server information as a `ServerInfo` structure.
414    ///
415    /// Fails with an error if it cannot call the RPC discover.
416    async fn get_server_info(
417        http: &HttpClient,
418        ws: &Option<WsClient>,
419    ) -> Result<ServerInfo, Error> {
420        let rpc_spec: Value = http.request("rpc.discover", rpc_params![]).await?;
421        let version = rpc_spec
422            .pointer("/info/version")
423            .and_then(|v| v.as_str())
424            .ok_or_else(|| {
425                Error::Data("Fail parsing server version from rpc.discover endpoint.".into())
426            })?;
427        let rpc_methods = Self::parse_methods(&rpc_spec)?;
428
429        let subscriptions = if let Some(ws) = ws {
430            match ws.request("rpc.discover", rpc_params![]).await {
431                Ok(rpc_spec) => Self::parse_methods(&rpc_spec)?,
432                Err(_) => Vec::new(),
433            }
434        } else {
435            Vec::new()
436        };
437        let iota_system_state_v2_support =
438            rpc_methods.contains(&"iotax_getLatestIotaSystemStateV2".to_string());
439        Ok(ServerInfo {
440            rpc_methods,
441            subscriptions,
442            version: version.to_string(),
443            iota_system_state_v2_support,
444        })
445    }
446
447    fn parse_methods(server_spec: &Value) -> Result<Vec<String>, Error> {
448        let methods = server_spec
449            .pointer("/methods")
450            .and_then(|methods| methods.as_array())
451            .ok_or_else(|| {
452                Error::Data("Fail parsing server information from rpc.discover endpoint.".into())
453            })?;
454
455        Ok(methods
456            .iter()
457            .flat_map(|method| method["name"].as_str())
458            .map(|s| s.into())
459            .collect())
460    }
461}
462
463/// Provides all the necessary abstractions for interacting with the IOTA
464/// network.
465///
466/// # Usage
467///
468/// Use [IotaClientBuilder] to build an [IotaClient].
469///
470/// # Examples
471///
472/// ```rust,no_run
473/// use std::str::FromStr;
474///
475/// use iota_sdk::{IotaClientBuilder, types::base_types::IotaAddress};
476///
477/// #[tokio::main]
478/// async fn main() -> Result<(), anyhow::Error> {
479///     let iota = IotaClientBuilder::default()
480///         .build("http://127.0.0.1:9000")
481///         .await?;
482///
483///     println!("{:?}", iota.available_rpc_methods());
484///     println!("{:?}", iota.available_subscriptions());
485///     println!("{:?}", iota.api_version());
486///
487///     let address = IotaAddress::from_str("0x0000....0000")?;
488///     let owned_objects = iota
489///         .read_api()
490///         .get_owned_objects(address, None, None, None)
491///         .await?;
492///
493///     println!("{:?}", owned_objects);
494///
495///     Ok(())
496/// }
497/// ```
498#[derive(Clone)]
499pub struct IotaClient {
500    api: Arc<RpcClient>,
501    transaction_builder: TransactionBuilder,
502    read_api: Arc<ReadApi>,
503    coin_read_api: CoinReadApi,
504    event_api: EventApi,
505    quorum_driver_api: QuorumDriverApi,
506    governance_api: GovernanceApi,
507}
508
509pub(crate) struct RpcClient {
510    http: HttpClient,
511    ws: Option<WsClient>,
512    info: ServerInfo,
513}
514
515impl Debug for RpcClient {
516    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
517        write!(
518            f,
519            "RPC client. Http: {:?}, Websocket: {:?}",
520            self.http, self.ws
521        )
522    }
523}
524
/// Contains all the useful information regarding the API version, the available
/// RPC calls, and subscriptions.
///
/// Populated from the node's `rpc.discover` response(s) when the client is
/// built.
struct ServerInfo {
    /// Names of the JSON-RPC methods advertised by the HTTP endpoint.
    rpc_methods: Vec<String>,
    /// Names of the methods advertised by the WebSocket endpoint. Empty when
    /// there is no WebSocket client or its `rpc.discover` request failed.
    subscriptions: Vec<String>,
    /// Server version read from `/info/version` of the HTTP `rpc.discover`
    /// response.
    version: String,
    /// Whether `iotax_getLatestIotaSystemStateV2` is listed in `rpc_methods`.
    iota_system_state_v2_support: bool,
}
533
534impl IotaClient {
535    /// Return a list of RPC methods supported by the node the client is
536    /// connected to.
537    pub fn available_rpc_methods(&self) -> &Vec<String> {
538        &self.api.info.rpc_methods
539    }
540
541    /// Return a list of streaming/subscription APIs supported by the node the
542    /// client is connected to.
543    pub fn available_subscriptions(&self) -> &Vec<String> {
544        &self.api.info.subscriptions
545    }
546
547    /// Return the API version information as a string.
548    ///
549    /// The format of this string is `<major>.<minor>.<patch>`, e.g., `1.6.0`,
550    /// and it is retrieved from the OpenRPC specification via the discover
551    /// service method.
552    pub fn api_version(&self) -> &str {
553        &self.api.info.version
554    }
555
556    /// Verify if the API version matches the server version and returns an
557    /// error if they do not match.
558    pub fn check_api_version(&self) -> IotaRpcResult<()> {
559        let server_version = self.api_version();
560        let client_version = env!("CARGO_PKG_VERSION");
561        if server_version != client_version {
562            return Err(Error::ServerVersionMismatch {
563                client_version: client_version.to_string(),
564                server_version: server_version.to_string(),
565            });
566        };
567        Ok(())
568    }
569
570    /// Return a reference to the coin read API.
571    pub fn coin_read_api(&self) -> &CoinReadApi {
572        &self.coin_read_api
573    }
574
575    /// Return a reference to the event API.
576    pub fn event_api(&self) -> &EventApi {
577        &self.event_api
578    }
579
580    /// Return a reference to the governance API.
581    pub fn governance_api(&self) -> &GovernanceApi {
582        &self.governance_api
583    }
584
585    /// Return a reference to the quorum driver API.
586    pub fn quorum_driver_api(&self) -> &QuorumDriverApi {
587        &self.quorum_driver_api
588    }
589
590    /// Return a reference to the read API.
591    pub fn read_api(&self) -> &ReadApi {
592        &self.read_api
593    }
594
595    /// Return a reference to the transaction builder API.
596    pub fn transaction_builder(&self) -> &TransactionBuilder {
597        &self.transaction_builder
598    }
599
600    /// Return a reference to the underlying http client.
601    pub fn http(&self) -> &HttpClient {
602        &self.api.http
603    }
604
605    /// Return a reference to the underlying WebSocket client, if any.
606    pub fn ws(&self) -> Option<&WsClient> {
607        self.api.ws.as_ref()
608    }
609}
610
/// Implements the transaction builder's [`DataReader`] abstraction on top of
/// [`ReadApi`], so that a [`TransactionBuilder`] can look up objects and the
/// reference gas price through the JSON-RPC read API.
///
/// NOTE(review): each trait method below calls a `ReadApi` method with the
/// same name. For `self.method(..)` calls, Rust prefers inherent methods over
/// trait methods, so these calls presumably reach `ReadApi`'s inherent methods
/// rather than recursing. If such an inherent method were ever removed or
/// renamed, the matching call could silently turn into unbounded recursion.
#[async_trait]
impl DataReader for ReadApi {
    /// Returns every object owned by `address` whose type is `object_type`,
    /// fetching all result pages.
    ///
    /// Objects are requested with their previous transaction, type and owner
    /// and then converted into [`ObjectInfo`]. Fails on the first RPC error or
    /// the first failed conversion.
    async fn get_owned_objects(
        &self,
        address: IotaAddress,
        object_type: StructTag,
    ) -> Result<Vec<ObjectInfo>, anyhow::Error> {
        let query = Some(IotaObjectResponseQuery {
            filter: Some(IotaObjectDataFilter::StructType(object_type)),
            options: Some(
                IotaObjectDataOptions::new()
                    .with_previous_transaction()
                    .with_type()
                    .with_owner(),
            ),
        });

        // Drain every page through the 4-argument inherent `get_owned_objects`;
        // the trailing `None` is presumably the page-size limit (confirm
        // against `ReadApi`).
        let result = PagedFn::stream(async |cursor| {
            self.get_owned_objects(address, query.clone(), cursor, None)
                .await
        })
        .map(|v| v?.try_into())
        .try_collect::<Vec<_>>()
        .await?;

        Ok(result)
    }

    /// Fetches a single object with the given options, converting the error
    /// into `anyhow::Error`.
    async fn get_object_with_options(
        &self,
        object_id: ObjectID,
        options: IotaObjectDataOptions,
    ) -> Result<IotaObjectResponse, anyhow::Error> {
        Ok(self.get_object_with_options(object_id, options).await?)
    }

    /// Return the reference gas price as a u64 or an error otherwise
    async fn get_reference_gas_price(&self) -> Result<u64, anyhow::Error> {
        Ok(self.get_reference_gas_price().await?)
    }
}
652
653/// A helper trait for repeatedly calling an async function which returns pages
654/// of data.
655pub trait PagedFn<O, C, F, E>: Sized + Fn(Option<C>) -> F
656where
657    O: Send,
658    C: Send,
659    F: futures::Future<Output = Result<Page<O, C>, E>> + Send,
660{
661    /// Get all items from the source and collect them into a vector.
662    fn collect<T>(self) -> impl futures::Future<Output = Result<T, E>>
663    where
664        T: Default + Extend<O>,
665    {
666        self.stream().try_collect::<T>()
667    }
668
669    /// Get a stream which will return all items from the source.
670    fn stream(self) -> PagedStream<O, C, F, E, Self> {
671        PagedStream::new(self)
672    }
673}
674
675impl<O, C, F, E, Fun> PagedFn<O, C, F, E> for Fun
676where
677    Fun: Fn(Option<C>) -> F,
678    O: Send,
679    C: Send,
680    F: futures::Future<Output = Result<Page<O, C>, E>> + Send,
681{
682}
683
/// A stream that yields every item of a paginated async source.
///
/// Usually created through [`PagedFn::stream`]. Pages are fetched one at a
/// time, and a new page is only requested after every buffered item of the
/// previous one has been yielded.
pub struct PagedStream<O, C, F, E, Fun> {
    /// Produces the future for a page, given an optional cursor.
    fun: Fun,
    /// Future for the page currently being (or most recently) fetched.
    fut: Pin<Box<F>>,
    /// Items received but not yet yielded.
    next: VecDeque<O>,
    /// Whether another page still has to be fetched through `fut`.
    has_next_page: bool,
    /// Records the `E` and `C` type parameters, which no field stores.
    _data: PhantomData<(E, C)>,
}

impl<O, C, F, E, Fun> PagedStream<O, C, F, E, Fun>
where
    Fun: Fn(Option<C>) -> F,
{
    /// Creates a stream over all pages produced by `fun`.
    ///
    /// `fun` is called right away with `None` to build the future for the
    /// first page. Later pages are requested as the stream is polled.
    pub fn new(fun: Fun) -> Self {
        let first_page = Box::pin(fun(None));
        Self {
            fun,
            fut: first_page,
            next: VecDeque::new(),
            has_next_page: true,
            _data: PhantomData,
        }
    }
}
709
710impl<O, C, F, E, Fun> futures::Stream for PagedStream<O, C, F, E, Fun>
711where
712    O: Send,
713    C: Send,
714    F: futures::Future<Output = Result<Page<O, C>, E>> + Send,
715    Fun: Fn(Option<C>) -> F,
716{
717    type Item = Result<O, E>;
718
719    fn poll_next(
720        self: std::pin::Pin<&mut Self>,
721        cx: &mut std::task::Context<'_>,
722    ) -> Poll<Option<Self::Item>> {
723        let this = unsafe { self.get_unchecked_mut() };
724        if this.next.is_empty() && this.has_next_page {
725            match this.fut.as_mut().poll(cx) {
726                Poll::Ready(res) => match res {
727                    Ok(mut page) => {
728                        this.next.extend(page.data);
729                        this.has_next_page = page.has_next_page;
730                        if this.has_next_page {
731                            this.fut.set((this.fun)(page.next_cursor.take()));
732                        }
733                    }
734                    Err(e) => {
735                        this.has_next_page = false;
736                        return Poll::Ready(Some(Err(e)));
737                    }
738                },
739                Poll::Pending => return Poll::Pending,
740            }
741        }
742        Poll::Ready(this.next.pop_front().map(Ok))
743    }
744}
745
#[cfg(test)]
mod test {
    use iota_json_rpc_types::Page;

    use super::*;

    /// `PagedFn::stream` must walk every page of a paginated source in order,
    /// end once the source is exhausted, and surface errors from the source.
    #[tokio::test]
    async fn test_get_all_pages() {
        /// In-memory paginated source whose cursor is the index of the first
        /// item of the requested page.
        struct Endpoint {
            data: Vec<i32>,
        }

        impl Endpoint {
            async fn get_page(&self, cursor: Option<usize>) -> anyhow::Result<Page<i32, usize>> {
                const PAGE_SIZE: usize = 100;
                let total = self.data.len();
                anyhow::ensure!(cursor.is_none_or(|start| start < total), "invalid cursor");
                let start = cursor.unwrap_or_default();
                let end = total.min(start + PAGE_SIZE);
                let has_next_page = end < total;
                Ok(Page {
                    data: self.data[start..end].to_vec(),
                    next_cursor: has_next_page.then_some(end),
                    has_next_page,
                })
            }
        }

        let endpoint = Endpoint {
            data: (0..10_000).collect(),
        };

        let mut stream = PagedFn::stream(async |cursor| endpoint.get_page(cursor).await);

        // Pull all but the last item through `by_ref` so the stream stays usable.
        let head = stream
            .by_ref()
            .take(9999)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();
        assert_eq!(head, endpoint.data[..9999]);
        // Then the final item, then the end of the stream.
        assert_eq!(stream.by_ref().try_next().await.unwrap(), Some(9999));
        assert!(stream.try_next().await.unwrap().is_none());

        // A source that always fails yields the error as its first item.
        let mut bad_stream = PagedFn::stream(async |_| endpoint.get_page(Some(99999)).await);
        assert!(bad_stream.try_next().await.is_err());
    }
}