iota_sdk/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5//! The IOTA Rust SDK
6//!
//! It aims to provide SDK functionality similar to that of the SDK available for
//! [TypeScript](https://github.com/iotaledger/iota/tree/main/sdk/typescript/).
//! The IOTA Rust SDK builds on top of the [JSON RPC API](https://docs.iota.org/iota-api-ref)
//! and therefore many of the return types are the ones specified in
//! [iota_types].
12//!
//! The API is split into several parts corresponding to different
//! functionalities, as follows:
15//! * [CoinReadApi] - provides read-only functions to work with the coins
//! * [EventApi] - provides event-related functions
17//! * [GovernanceApi] - provides functionality related to staking
18//! * [QuorumDriverApi] - provides functionality to execute a transaction block
19//!   and submit it to the fullnode(s)
20//! * [ReadApi] - provides functions for retrieving data about different objects
21//!   and transactions
22//! * <a href="../iota_transaction_builder/struct.TransactionBuilder.html"
23//!   title="struct
24//!   iota_transaction_builder::TransactionBuilder">TransactionBuilder</a> -
25//!   provides functions for building transactions
26//!
27//! # Usage
28//! The main way to interact with the API is through the [IotaClientBuilder],
29//! which returns an [IotaClient] object from which the user can access the
30//! various APIs.
31//!
32//! ## Getting Started
33//! Add the Rust SDK to the project by running `cargo add iota-sdk` in the root
34//! folder of your Rust project.
35//!
36//! The main building block for the IOTA Rust SDK is the [IotaClientBuilder],
37//! which provides a simple and straightforward way of connecting to an IOTA
38//! network and having access to the different available APIs.
39//!
//! Below is a simple example which connects to a running IOTA local network,
//! devnet, testnet, and mainnet.
42//! To successfully run this program, make sure to spin up a local
43//! network with a local validator, a fullnode, and a faucet server
44//! (see [the README](https://github.com/iotaledger/iota/tree/develop/crates/iota-sdk/README.md#prerequisites) for more information).
45//!
46//! ```rust,no_run
47//! use iota_sdk::IotaClientBuilder;
48//!
49//! #[tokio::main]
50//! async fn main() -> Result<(), anyhow::Error> {
51//!     let iota = IotaClientBuilder::default()
52//!         .build("http://127.0.0.1:9000") // provide the IOTA network URL
53//!         .await?;
54//!     println!("IOTA local network version: {:?}", iota.api_version());
55//!
56//!     // local IOTA network, same result as above except using the dedicated function
57//!     let iota_local = IotaClientBuilder::default().build_localnet().await?;
58//!     println!("IOTA local network version: {:?}", iota_local.api_version());
59//!
60//!     // IOTA devnet running at `https://api.devnet.iota.cafe`
61//!     let iota_devnet = IotaClientBuilder::default().build_devnet().await?;
62//!     println!("IOTA devnet version: {:?}", iota_devnet.api_version());
63//!
64//!     // IOTA testnet running at `https://api.testnet.iota.cafe`
65//!     let iota_testnet = IotaClientBuilder::default().build_testnet().await?;
66//!     println!("IOTA testnet version: {:?}", iota_testnet.api_version());
67//!
68//!     // IOTA mainnet running at `https://api.mainnet.iota.cafe`
69//!     let iota_mainnet = IotaClientBuilder::default().build_mainnet().await?;
70//!     println!("IOTA mainnet version: {:?}", iota_mainnet.api_version());
71//!
72//!     Ok(())
73//! }
74//! ```
75//!
76//! ## Examples
77//!
78//! For detailed examples, please check the APIs docs and the examples folder
79//! in the [repository](https://github.com/iotaledger/iota/tree/main/crates/iota-sdk/examples).
80
81pub mod apis;
82pub mod error;
83pub mod iota_client_config;
84pub mod json_rpc_error;
85pub mod wallet_context;
86
87use std::{
88    collections::VecDeque,
89    fmt::{Debug, Formatter},
90    marker::PhantomData,
91    pin::Pin,
92    sync::Arc,
93    task::Poll,
94    time::Duration,
95};
96
97use async_trait::async_trait;
98use base64::Engine;
99use futures::{StreamExt, TryStreamExt};
100pub use iota_json as json;
101use iota_json_rpc_api::{
102    CLIENT_SDK_TYPE_HEADER, CLIENT_SDK_VERSION_HEADER, CLIENT_TARGET_API_VERSION_HEADER,
103};
104pub use iota_json_rpc_types as rpc_types;
105use iota_json_rpc_types::{
106    IotaObjectDataFilter, IotaObjectDataOptions, IotaObjectResponse, IotaObjectResponseQuery, Page,
107};
108use iota_transaction_builder::{DataReader, TransactionBuilder};
109pub use iota_types as types;
110use iota_types::base_types::{IotaAddress, ObjectID, ObjectInfo};
111use jsonrpsee::{
112    core::client::ClientT,
113    http_client::{HeaderMap, HeaderValue, HttpClient, HttpClientBuilder},
114    rpc_params,
115    ws_client::{PingConfig, WsClient, WsClientBuilder},
116};
117use move_core_types::language_storage::StructTag;
118use rustls::crypto::{CryptoProvider, ring};
119use serde_json::Value;
120
121use crate::{
122    apis::{CoinReadApi, EventApi, GovernanceApi, QuorumDriverApi, ReadApi},
123    error::{Error, IotaRpcResult},
124};
125
/// Type string of the IOTA coin (`0x2::iota::IOTA`).
pub const IOTA_COIN_TYPE: &str = "0x2::iota::IOTA";
/// JSON-RPC URL of a local network; used by
/// [`IotaClientBuilder::build_localnet`].
pub const IOTA_LOCAL_NETWORK_URL: &str = "http://127.0.0.1:9000";
/// JSON-RPC URL of a local network, addressed via `0.0.0.0` instead of the
/// loopback address.
pub const IOTA_LOCAL_NETWORK_URL_0: &str = "http://0.0.0.0:9000";
/// GraphQL URL of a local network.
pub const IOTA_LOCAL_NETWORK_GRAPHQL_URL: &str = "http://127.0.0.1:8000";
/// Faucet gas-request URL of a local network.
pub const IOTA_LOCAL_NETWORK_GAS_URL: &str = "http://127.0.0.1:9123/v1/gas";
/// JSON-RPC URL of the IOTA devnet; used by
/// [`IotaClientBuilder::build_devnet`].
pub const IOTA_DEVNET_URL: &str = "https://api.devnet.iota.cafe";
/// GraphQL URL of the IOTA devnet.
pub const IOTA_DEVNET_GRAPHQL_URL: &str = "https://graphql.devnet.iota.cafe";
/// Faucet gas-request URL of the IOTA devnet.
pub const IOTA_DEVNET_GAS_URL: &str = "https://faucet.devnet.iota.cafe/v1/gas";
/// JSON-RPC URL of the IOTA testnet; used by
/// [`IotaClientBuilder::build_testnet`].
pub const IOTA_TESTNET_URL: &str = "https://api.testnet.iota.cafe";
/// GraphQL URL of the IOTA testnet.
pub const IOTA_TESTNET_GRAPHQL_URL: &str = "https://graphql.testnet.iota.cafe";
/// Faucet gas-request URL of the IOTA testnet.
pub const IOTA_TESTNET_GAS_URL: &str = "https://faucet.testnet.iota.cafe/v1/gas";
/// JSON-RPC URL of the IOTA mainnet; used by
/// [`IotaClientBuilder::build_mainnet`].
pub const IOTA_MAINNET_URL: &str = "https://api.mainnet.iota.cafe";
138
/// Builder for creating an [IotaClient] for connecting to the IOTA network.
///
/// By default the `request timeout` is set to 60 seconds and no limit on
/// `maximum concurrent requests` is configured, so the defaults of the
/// underlying jsonrpsee HTTP/WebSocket clients apply. Both can be adjusted
/// using [`Self::max_concurrent_requests()`] and [`Self::request_timeout()`].
/// If you use the WebSocket, consider setting `ws_ping_interval` appropriately
/// to prevent an inactive WS subscription from being disconnected due to a
/// proxy timeout.
///
/// # Examples
///
/// ```rust,no_run
/// use iota_sdk::IotaClientBuilder;
///
/// #[tokio::main]
/// async fn main() -> Result<(), anyhow::Error> {
///     let iota = IotaClientBuilder::default()
///         .build("http://127.0.0.1:9000")
///         .await?;
///
///     println!("IOTA local network version: {:?}", iota.api_version());
///     Ok(())
/// }
/// ```
pub struct IotaClientBuilder {
    /// Timeout applied to requests of both the HTTP and the WebSocket client.
    request_timeout: Duration,
    /// Limit on in-flight requests; `None` keeps the jsonrpsee defaults.
    max_concurrent_requests: Option<usize>,
    /// WebSocket endpoint; when `None`, no WebSocket client is created.
    ws_url: Option<String>,
    /// WebSocket ping interval; when `None`, pings are not enabled.
    ws_ping_interval: Option<Duration>,
    /// `(username, password)` sent as an HTTP Basic `authorization` header.
    basic_auth: Option<(String, String)>,
}

impl Default for IotaClientBuilder {
    /// A 60 second request timeout; every other option is left unset.
    fn default() -> Self {
        Self {
            request_timeout: Duration::from_secs(60),
            max_concurrent_requests: None,
            ws_url: None,
            ws_ping_interval: None,
            basic_auth: None,
        }
    }
}
182
183impl IotaClientBuilder {
184    /// Set the request timeout to the specified duration.
185    pub fn request_timeout(mut self, request_timeout: Duration) -> Self {
186        self.request_timeout = request_timeout;
187        self
188    }
189
190    /// Set the max concurrent requests allowed.
191    pub fn max_concurrent_requests(mut self, max_concurrent_requests: usize) -> Self {
192        self.max_concurrent_requests = Some(max_concurrent_requests);
193        self
194    }
195
196    /// Set the WebSocket URL for the IOTA network.
197    pub fn ws_url(mut self, url: impl AsRef<str>) -> Self {
198        self.ws_url = Some(url.as_ref().to_string());
199        self
200    }
201
202    /// Set the WebSocket ping interval.
203    pub fn ws_ping_interval(mut self, duration: Duration) -> Self {
204        self.ws_ping_interval = Some(duration);
205        self
206    }
207
208    /// Set the basic auth credentials for the HTTP client.
209    pub fn basic_auth(mut self, username: impl AsRef<str>, password: impl AsRef<str>) -> Self {
210        self.basic_auth = Some((username.as_ref().to_string(), password.as_ref().to_string()));
211        self
212    }
213
214    /// Return an [IotaClient] object connected to the IOTA network accessible
215    /// via the provided URI.
216    ///
217    /// # Examples
218    ///
219    /// ```rust,no_run
220    /// use iota_sdk::IotaClientBuilder;
221    ///
222    /// #[tokio::main]
223    /// async fn main() -> Result<(), anyhow::Error> {
224    ///     let iota = IotaClientBuilder::default()
225    ///         .build("http://127.0.0.1:9000")
226    ///         .await?;
227    ///
228    ///     println!("IOTA local version: {:?}", iota.api_version());
229    ///     Ok(())
230    /// }
231    /// ```
232    pub async fn build(self, http: impl AsRef<str>) -> IotaRpcResult<IotaClient> {
233        if CryptoProvider::get_default().is_none() {
234            ring::default_provider().install_default().ok();
235        }
236
237        let client_version = env!("CARGO_PKG_VERSION");
238        let mut headers = HeaderMap::new();
239        headers.insert(
240            CLIENT_TARGET_API_VERSION_HEADER,
241            // in rust, the client version is the same as the target api version
242            HeaderValue::from_static(client_version),
243        );
244        headers.insert(
245            CLIENT_SDK_VERSION_HEADER,
246            HeaderValue::from_static(client_version),
247        );
248        headers.insert(CLIENT_SDK_TYPE_HEADER, HeaderValue::from_static("rust"));
249
250        if let Some((username, password)) = self.basic_auth {
251            let auth = base64::engine::general_purpose::STANDARD
252                .encode(format!("{}:{}", username, password));
253            headers.insert(
254                "authorization",
255                // reqwest::header::AUTHORIZATION,
256                HeaderValue::from_str(&format!("Basic {}", auth)).unwrap(),
257            );
258        }
259
260        let ws = if let Some(url) = self.ws_url {
261            let mut builder = WsClientBuilder::default()
262                .max_request_size(2 << 30)
263                .set_headers(headers.clone())
264                .request_timeout(self.request_timeout);
265
266            if let Some(duration) = self.ws_ping_interval {
267                builder = builder.enable_ws_ping(PingConfig::new().ping_interval(duration))
268            }
269
270            if let Some(max_concurrent_requests) = self.max_concurrent_requests {
271                builder = builder.max_concurrent_requests(max_concurrent_requests);
272            }
273
274            builder.build(url).await.ok()
275        } else {
276            None
277        };
278
279        let mut http_builder = HttpClientBuilder::default()
280            .max_request_size(2 << 30)
281            .set_headers(headers)
282            .request_timeout(self.request_timeout);
283
284        if let Some(max_concurrent_requests) = self.max_concurrent_requests {
285            http_builder = http_builder.max_concurrent_requests(max_concurrent_requests);
286        }
287
288        let http = http_builder.build(http)?;
289
290        let info = Self::get_server_info(&http, &ws).await?;
291
292        let rpc = RpcClient { http, ws, info };
293        let api = Arc::new(rpc);
294        let read_api = Arc::new(ReadApi::new(api.clone()));
295        let quorum_driver_api = QuorumDriverApi::new(api.clone());
296        let event_api = EventApi::new(api.clone());
297        let transaction_builder = TransactionBuilder::new(read_api.clone());
298        let coin_read_api = CoinReadApi::new(api.clone());
299        let governance_api = GovernanceApi::new(api.clone());
300
301        Ok(IotaClient {
302            api,
303            transaction_builder,
304            read_api,
305            coin_read_api,
306            event_api,
307            quorum_driver_api,
308            governance_api,
309        })
310    }
311
312    /// Return an [IotaClient] object that is ready to interact with the local
313    /// development network (by default it expects the IOTA network to be up
314    /// and running at `127.0.0.1:9000`).
315    ///
316    /// For connecting to a custom URI, use the `build` function instead.
317    ///
318    /// # Examples
319    ///
320    /// ```rust,no_run
321    /// use iota_sdk::IotaClientBuilder;
322    ///
323    /// #[tokio::main]
324    /// async fn main() -> Result<(), anyhow::Error> {
325    ///     let iota = IotaClientBuilder::default().build_localnet().await?;
326    ///
327    ///     println!("IOTA local version: {:?}", iota.api_version());
328    ///     Ok(())
329    /// }
330    /// ```
331    pub async fn build_localnet(self) -> IotaRpcResult<IotaClient> {
332        self.build(IOTA_LOCAL_NETWORK_URL).await
333    }
334
335    /// Return an [IotaClient] object that is ready to interact with the IOTA
336    /// devnet.
337    ///
338    /// For connecting to a custom URI, use the `build` function instead.
339    ///
340    /// # Examples
341    ///
342    /// ```rust,no_run
343    /// use iota_sdk::IotaClientBuilder;
344    ///
345    /// #[tokio::main]
346    /// async fn main() -> Result<(), anyhow::Error> {
347    ///     let iota = IotaClientBuilder::default().build_devnet().await?;
348    ///
349    ///     println!("{:?}", iota.api_version());
350    ///     Ok(())
351    /// }
352    /// ```
353    pub async fn build_devnet(self) -> IotaRpcResult<IotaClient> {
354        self.build(IOTA_DEVNET_URL).await
355    }
356
357    /// Return an [IotaClient] object that is ready to interact with the IOTA
358    /// testnet.
359    ///
360    /// For connecting to a custom URI, use the `build` function instead.
361    ///
362    /// # Examples
363    ///
364    /// ```rust,no_run
365    /// use iota_sdk::IotaClientBuilder;
366    ///
367    /// #[tokio::main]
368    /// async fn main() -> Result<(), anyhow::Error> {
369    ///     let iota = IotaClientBuilder::default().build_testnet().await?;
370    ///
371    ///     println!("{:?}", iota.api_version());
372    ///     Ok(())
373    /// }
374    /// ```
375    pub async fn build_testnet(self) -> IotaRpcResult<IotaClient> {
376        self.build(IOTA_TESTNET_URL).await
377    }
378
379    /// Returns an [IotaClient] object that is ready to interact with the IOTA
380    /// mainnet.
381    ///
382    /// For connecting to a custom URI, use the `build` function instead.
383    ///
384    /// # Examples
385    ///
386    /// ```rust,no_run
387    /// use iota_sdk::IotaClientBuilder;
388    ///
389    /// #[tokio::main]
390    /// async fn main() -> Result<(), anyhow::Error> {
391    ///     let iota = IotaClientBuilder::default().build_mainnet().await?;
392    ///
393    ///     println!("{:?}", iota.api_version());
394    ///     Ok(())
395    /// }
396    /// ```
397    pub async fn build_mainnet(self) -> IotaRpcResult<IotaClient> {
398        self.build(IOTA_MAINNET_URL).await
399    }
400
401    /// Return the server information as a `ServerInfo` structure.
402    ///
403    /// Fails with an error if it cannot call the RPC discover.
404    async fn get_server_info(
405        http: &HttpClient,
406        ws: &Option<WsClient>,
407    ) -> Result<ServerInfo, Error> {
408        let rpc_spec: Value = http.request("rpc.discover", rpc_params![]).await?;
409        let version = rpc_spec
410            .pointer("/info/version")
411            .and_then(|v| v.as_str())
412            .ok_or_else(|| {
413                Error::Data("Fail parsing server version from rpc.discover endpoint.".into())
414            })?;
415        let rpc_methods = Self::parse_methods(&rpc_spec)?;
416
417        let subscriptions = if let Some(ws) = ws {
418            match ws.request("rpc.discover", rpc_params![]).await {
419                Ok(rpc_spec) => Self::parse_methods(&rpc_spec)?,
420                Err(_) => Vec::new(),
421            }
422        } else {
423            Vec::new()
424        };
425        let iota_system_state_v2_support =
426            rpc_methods.contains(&"iotax_getLatestIotaSystemStateV2".to_string());
427        Ok(ServerInfo {
428            rpc_methods,
429            subscriptions,
430            version: version.to_string(),
431            iota_system_state_v2_support,
432        })
433    }
434
435    fn parse_methods(server_spec: &Value) -> Result<Vec<String>, Error> {
436        let methods = server_spec
437            .pointer("/methods")
438            .and_then(|methods| methods.as_array())
439            .ok_or_else(|| {
440                Error::Data("Fail parsing server information from rpc.discover endpoint.".into())
441            })?;
442
443        Ok(methods
444            .iter()
445            .flat_map(|method| method["name"].as_str())
446            .map(|s| s.into())
447            .collect())
448    }
449}
450
451/// Provides all the necessary abstractions for interacting with the IOTA
452/// network.
453///
454/// # Usage
455///
456/// Use [IotaClientBuilder] to build an [IotaClient].
457///
458/// # Examples
459///
460/// ```rust,no_run
461/// use std::str::FromStr;
462///
463/// use iota_sdk::{IotaClientBuilder, types::base_types::IotaAddress};
464///
465/// #[tokio::main]
466/// async fn main() -> Result<(), anyhow::Error> {
467///     let iota = IotaClientBuilder::default()
468///         .build("http://127.0.0.1:9000")
469///         .await?;
470///
471///     println!("{:?}", iota.available_rpc_methods());
472///     println!("{:?}", iota.available_subscriptions());
473///     println!("{:?}", iota.api_version());
474///
475///     let address = IotaAddress::from_str("0x0000....0000")?;
476///     let owned_objects = iota
477///         .read_api()
478///         .get_owned_objects(address, None, None, None)
479///         .await?;
480///
481///     println!("{:?}", owned_objects);
482///
483///     Ok(())
484/// }
485/// ```
486#[derive(Clone)]
487pub struct IotaClient {
488    api: Arc<RpcClient>,
489    transaction_builder: TransactionBuilder,
490    read_api: Arc<ReadApi>,
491    coin_read_api: CoinReadApi,
492    event_api: EventApi,
493    quorum_driver_api: QuorumDriverApi,
494    governance_api: GovernanceApi,
495}
496
497pub(crate) struct RpcClient {
498    http: HttpClient,
499    ws: Option<WsClient>,
500    info: ServerInfo,
501}
502
503impl Debug for RpcClient {
504    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
505        write!(
506            f,
507            "RPC client. Http: {:?}, Websocket: {:?}",
508            self.http, self.ws
509        )
510    }
511}
512
/// Capabilities reported by the connected node through `rpc.discover`: its
/// API version, the RPC methods it serves, and its subscriptions.
struct ServerInfo {
    /// Version string read from `/info/version` of the HTTP discover response.
    version: String,
    /// Method names listed by the HTTP discover response.
    rpc_methods: Vec<String>,
    /// Method names listed by the WebSocket discover response, if any.
    subscriptions: Vec<String>,
    /// Whether `iotax_getLatestIotaSystemStateV2` is among `rpc_methods`.
    iota_system_state_v2_support: bool,
}
521
522impl IotaClient {
523    /// Return a list of RPC methods supported by the node the client is
524    /// connected to.
525    pub fn available_rpc_methods(&self) -> &Vec<String> {
526        &self.api.info.rpc_methods
527    }
528
529    /// Return a list of streaming/subscription APIs supported by the node the
530    /// client is connected to.
531    pub fn available_subscriptions(&self) -> &Vec<String> {
532        &self.api.info.subscriptions
533    }
534
535    /// Return the API version information as a string.
536    ///
537    /// The format of this string is `<major>.<minor>.<patch>`, e.g., `1.6.0`,
538    /// and it is retrieved from the OpenRPC specification via the discover
539    /// service method.
540    pub fn api_version(&self) -> &str {
541        &self.api.info.version
542    }
543
544    /// Verify if the API version matches the server version and returns an
545    /// error if they do not match.
546    pub fn check_api_version(&self) -> IotaRpcResult<()> {
547        let server_version = self.api_version();
548        let client_version = env!("CARGO_PKG_VERSION");
549        if server_version != client_version {
550            return Err(Error::ServerVersionMismatch {
551                client_version: client_version.to_string(),
552                server_version: server_version.to_string(),
553            });
554        };
555        Ok(())
556    }
557
558    /// Return a reference to the coin read API.
559    pub fn coin_read_api(&self) -> &CoinReadApi {
560        &self.coin_read_api
561    }
562
563    /// Return a reference to the event API.
564    pub fn event_api(&self) -> &EventApi {
565        &self.event_api
566    }
567
568    /// Return a reference to the governance API.
569    pub fn governance_api(&self) -> &GovernanceApi {
570        &self.governance_api
571    }
572
573    /// Return a reference to the quorum driver API.
574    pub fn quorum_driver_api(&self) -> &QuorumDriverApi {
575        &self.quorum_driver_api
576    }
577
578    /// Return a reference to the read API.
579    pub fn read_api(&self) -> &ReadApi {
580        &self.read_api
581    }
582
583    /// Return a reference to the transaction builder API.
584    pub fn transaction_builder(&self) -> &TransactionBuilder {
585        &self.transaction_builder
586    }
587
588    /// Return a reference to the underlying http client.
589    pub fn http(&self) -> &HttpClient {
590        &self.api.http
591    }
592
593    /// Return a reference to the underlying WebSocket client, if any.
594    pub fn ws(&self) -> Option<&WsClient> {
595        self.api.ws.as_ref()
596    }
597}
598
599#[async_trait]
600impl DataReader for ReadApi {
601    async fn get_owned_objects(
602        &self,
603        address: IotaAddress,
604        object_type: StructTag,
605    ) -> Result<Vec<ObjectInfo>, anyhow::Error> {
606        let query = Some(IotaObjectResponseQuery {
607            filter: Some(IotaObjectDataFilter::StructType(object_type)),
608            options: Some(
609                IotaObjectDataOptions::new()
610                    .with_previous_transaction()
611                    .with_type()
612                    .with_owner(),
613            ),
614        });
615
616        let result = PagedFn::stream(async |cursor| {
617            self.get_owned_objects(address, query.clone(), cursor, None)
618                .await
619        })
620        .map(|v| v?.try_into())
621        .try_collect::<Vec<_>>()
622        .await?;
623
624        Ok(result)
625    }
626
627    async fn get_object_with_options(
628        &self,
629        object_id: ObjectID,
630        options: IotaObjectDataOptions,
631    ) -> Result<IotaObjectResponse, anyhow::Error> {
632        Ok(self.get_object_with_options(object_id, options).await?)
633    }
634
635    /// Return the reference gas price as a u64 or an error otherwise
636    async fn get_reference_gas_price(&self) -> Result<u64, anyhow::Error> {
637        Ok(self.get_reference_gas_price().await?)
638    }
639}
640
641/// A helper trait for repeatedly calling an async function which returns pages
642/// of data.
643pub trait PagedFn<O, C, F, E>: Sized + Fn(Option<C>) -> F
644where
645    O: Send,
646    C: Send,
647    F: futures::Future<Output = Result<Page<O, C>, E>> + Send,
648{
649    /// Get all items from the source and collect them into a vector.
650    fn collect<T>(self) -> impl futures::Future<Output = Result<T, E>>
651    where
652        T: Default + Extend<O>,
653    {
654        self.stream().try_collect::<T>()
655    }
656
657    /// Get a stream which will return all items from the source.
658    fn stream(self) -> PagedStream<O, C, F, E, Self> {
659        PagedStream::new(self)
660    }
661}
662
663impl<O, C, F, E, Fun> PagedFn<O, C, F, E> for Fun
664where
665    Fun: Fn(Option<C>) -> F,
666    O: Send,
667    C: Send,
668    F: futures::Future<Output = Result<Page<O, C>, E>> + Send,
669{
670}
671
672/// A stream which repeatedly calls an async function which returns a page of
673/// data.
674pub struct PagedStream<O, C, F, E, Fun> {
675    fun: Fun,
676    fut: Pin<Box<F>>,
677    next: VecDeque<O>,
678    has_next_page: bool,
679    _data: PhantomData<(E, C)>,
680}
681
682impl<O, C, F, E, Fun> PagedStream<O, C, F, E, Fun>
683where
684    Fun: Fn(Option<C>) -> F,
685{
686    pub fn new(fun: Fun) -> Self {
687        let fut = fun(None);
688        Self {
689            fun,
690            fut: Box::pin(fut),
691            next: Default::default(),
692            has_next_page: true,
693            _data: PhantomData,
694        }
695    }
696}
697
698impl<O, C, F, E, Fun> futures::Stream for PagedStream<O, C, F, E, Fun>
699where
700    O: Send,
701    C: Send,
702    F: futures::Future<Output = Result<Page<O, C>, E>> + Send,
703    Fun: Fn(Option<C>) -> F,
704{
705    type Item = Result<O, E>;
706
707    fn poll_next(
708        self: std::pin::Pin<&mut Self>,
709        cx: &mut std::task::Context<'_>,
710    ) -> Poll<Option<Self::Item>> {
711        let this = unsafe { self.get_unchecked_mut() };
712        if this.next.is_empty() && this.has_next_page {
713            match this.fut.as_mut().poll(cx) {
714                Poll::Ready(res) => match res {
715                    Ok(mut page) => {
716                        this.next.extend(page.data);
717                        this.has_next_page = page.has_next_page;
718                        if this.has_next_page {
719                            this.fut.set((this.fun)(page.next_cursor.take()));
720                        }
721                    }
722                    Err(e) => {
723                        this.has_next_page = false;
724                        return Poll::Ready(Some(Err(e)));
725                    }
726                },
727                Poll::Pending => return Poll::Pending,
728            }
729        }
730        Poll::Ready(this.next.pop_front().map(Ok))
731    }
732}
733
#[cfg(test)]
mod test {
    use iota_json_rpc_types::Page;

    use super::*;

    /// Drains a 10 000-item endpoint served in pages of 100 through
    /// `PagedFn::stream`, checking item order, partial consumption via
    /// `by_ref`, end-of-stream, and error propagation.
    #[tokio::test]
    async fn test_get_all_pages() {
        let data = (0..10000).collect::<Vec<_>>();
        // In-memory stand-in for a paginated RPC endpoint.
        struct Endpoint {
            data: Vec<i32>,
        }

        impl Endpoint {
            /// Return up to `PAGE_SIZE` items starting at `cursor` (0 when
            /// `None`); errors on a cursor past the end of `data`.
            async fn get_page(&self, cursor: Option<usize>) -> anyhow::Result<Page<i32, usize>> {
                const PAGE_SIZE: usize = 100;
                anyhow::ensure!(cursor.is_none_or(|v| v < self.data.len()), "invalid cursor");
                let index = cursor.unwrap_or_default();
                let data = self.data[index..]
                    .iter()
                    .copied()
                    .take(PAGE_SIZE)
                    .collect::<Vec<_>>();
                let has_next_page = self.data.len() > index + PAGE_SIZE;
                Ok(Page {
                    data,
                    next_cursor: has_next_page.then_some(index + PAGE_SIZE),
                    has_next_page,
                })
            }
        }

        let endpoint = Endpoint { data };

        let mut stream = PagedFn::stream(async |cursor| endpoint.get_page(cursor).await);

        // Consume all but the last item, spanning many page boundaries.
        assert_eq!(
            stream
                .by_ref()
                .take(9999)
                .try_collect::<Vec<_>>()
                .await
                .unwrap(),
            endpoint.data[..9999]
        );
        // The remaining item is still available, after which the stream ends.
        assert_eq!(stream.by_ref().try_next().await.unwrap(), Some(9999));
        assert!(stream.try_next().await.unwrap().is_none());

        // A page function that always fails must surface its error.
        let mut bad_stream = PagedFn::stream(async |_| endpoint.get_page(Some(99999)).await);

        assert!(bad_stream.try_next().await.is_err());
    }
}