iota_sdk/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5//! The IOTA Rust SDK
6//!
//! It aims to provide SDK functionality similar to that of the
//! [TypeScript SDK](https://github.com/iotaledger/iota/tree/main/sdk/typescript/).
9//! IOTA Rust SDK builds on top of the [JSON RPC API](https://docs.iota.org/iota-api-ref)
10//! and therefore many of the return types are the ones specified in
11//! [iota_types].
12//!
//! The API is split into several parts corresponding to different
//! functionalities, as follows:
//! * [CoinReadApi] - provides read-only functions to work with the coins
//! * [EventApi] - provides event-related functions
17//! * [GovernanceApi] - provides functionality related to staking
18//! * [QuorumDriverApi] - provides functionality to execute a transaction block
19//!   and submit it to the fullnode(s)
20//! * [ReadApi] - provides functions for retrieving data about different objects
21//!   and transactions
22//! * <a href="../iota_transaction_builder/struct.TransactionBuilder.html"
23//!   title="struct
24//!   iota_transaction_builder::TransactionBuilder">TransactionBuilder</a> -
25//!   provides functions for building transactions
26//!
27//! # Usage
28//! The main way to interact with the API is through the [IotaClientBuilder],
29//! which returns an [IotaClient] object from which the user can access the
30//! various APIs.
31//!
32//! ## Getting Started
33//! Add the Rust SDK to the project by running `cargo add iota-sdk` in the root
34//! folder of your Rust project.
35//!
36//! The main building block for the IOTA Rust SDK is the [IotaClientBuilder],
37//! which provides a simple and straightforward way of connecting to an IOTA
38//! network and having access to the different available APIs.
39//!
//! Below is a simple example which connects to a running IOTA local network,
//! devnet, testnet, and mainnet.
42//! To successfully run this program, make sure to spin up a local
43//! network with a local validator, a fullnode, and a faucet server
44//! (see [the README](https://github.com/iotaledger/iota/tree/develop/crates/iota-sdk/README.md#prerequisites) for more information).
45//!
46//! ```rust,no_run
47//! use iota_sdk::IotaClientBuilder;
48//!
49//! #[tokio::main]
50//! async fn main() -> Result<(), anyhow::Error> {
51//!     let iota = IotaClientBuilder::default()
52//!         .build("http://127.0.0.1:9000") // provide the IOTA network URL
53//!         .await?;
54//!     println!("IOTA local network version: {:?}", iota.api_version());
55//!
56//!     // local IOTA network, same result as above except using the dedicated function
57//!     let iota_local = IotaClientBuilder::default().build_localnet().await?;
58//!     println!("IOTA local network version: {:?}", iota_local.api_version());
59//!
60//!     // IOTA devnet running at `https://api.devnet.iota.cafe`
61//!     let iota_devnet = IotaClientBuilder::default().build_devnet().await?;
62//!     println!("IOTA devnet version: {:?}", iota_devnet.api_version());
63//!
64//!     // IOTA testnet running at `https://api.testnet.iota.cafe`
65//!     let iota_testnet = IotaClientBuilder::default().build_testnet().await?;
66//!     println!("IOTA testnet version: {:?}", iota_testnet.api_version());
67//!
68//!     // IOTA mainnet running at `https://api.mainnet.iota.cafe`
69//!     let iota_mainnet = IotaClientBuilder::default().build_mainnet().await?;
70//!     println!("IOTA mainnet version: {:?}", iota_mainnet.api_version());
71//!
72//!     Ok(())
73//! }
74//! ```
75//!
76//! ## Examples
77//!
//! For detailed examples, please check the API docs and the examples folder
//! in the [repository](https://github.com/iotaledger/iota/tree/main/crates/iota-sdk/examples).
80
81pub mod apis;
82pub mod error;
83pub mod iota_client_config;
84pub mod json_rpc_error;
85pub mod wallet_context;
86
87use std::{
88    collections::VecDeque,
89    fmt::{Debug, Formatter},
90    marker::PhantomData,
91    pin::Pin,
92    sync::Arc,
93    task::Poll,
94    time::Duration,
95};
96
97use async_trait::async_trait;
98use base64::Engine;
99use futures::{StreamExt, TryStreamExt};
100pub use iota_json as json;
101use iota_json_rpc_api::{
102    CLIENT_SDK_TYPE_HEADER, CLIENT_SDK_VERSION_HEADER, CLIENT_TARGET_API_VERSION_HEADER,
103};
104pub use iota_json_rpc_types as rpc_types;
105use iota_json_rpc_types::{
106    IotaObjectDataFilter, IotaObjectDataOptions, IotaObjectResponse, IotaObjectResponseQuery, Page,
107};
108use iota_transaction_builder::{DataReader, TransactionBuilder};
109pub use iota_types as types;
110use iota_types::base_types::{IotaAddress, ObjectID, ObjectInfo};
111use jsonrpsee::{
112    core::client::ClientT,
113    http_client::{HeaderMap, HeaderValue, HttpClient, HttpClientBuilder},
114    rpc_params,
115    ws_client::{PingConfig, WsClient, WsClientBuilder},
116};
117use move_core_types::language_storage::StructTag;
118use rustls::crypto::{CryptoProvider, ring};
119use serde_json::Value;
120
121use crate::{
122    apis::{CoinReadApi, EventApi, GovernanceApi, QuorumDriverApi, ReadApi},
123    error::{Error, IotaRpcResult},
124};
125
/// Type tag string of the native IOTA coin.
pub const IOTA_COIN_TYPE: &str = "0x2::iota::IOTA";
/// JSON-RPC URL of a local network; used by
/// [`IotaClientBuilder::build_localnet`].
pub const IOTA_LOCAL_NETWORK_URL: &str = "http://127.0.0.1:9000";
/// Same port as [`IOTA_LOCAL_NETWORK_URL`], but using the `0.0.0.0` address
/// instead of the loopback address.
pub const IOTA_LOCAL_NETWORK_URL_0: &str = "http://0.0.0.0:9000";
/// GraphQL URL of a local network.
pub const IOTA_LOCAL_NETWORK_GRAPHQL_URL: &str = "http://127.0.0.1:8000";
/// Gas faucet endpoint of a local network.
pub const IOTA_LOCAL_NETWORK_GAS_URL: &str = "http://127.0.0.1:9123/v1/gas";
/// JSON-RPC URL of the IOTA devnet; used by
/// [`IotaClientBuilder::build_devnet`].
pub const IOTA_DEVNET_URL: &str = "https://api.devnet.iota.cafe";
/// GraphQL URL of the IOTA devnet.
pub const IOTA_DEVNET_GRAPHQL_URL: &str = "https://graphql.devnet.iota.cafe";
/// Gas faucet endpoint of the IOTA devnet.
pub const IOTA_DEVNET_GAS_URL: &str = "https://faucet.devnet.iota.cafe/v1/gas";
/// JSON-RPC URL of the IOTA testnet; used by
/// [`IotaClientBuilder::build_testnet`].
pub const IOTA_TESTNET_URL: &str = "https://api.testnet.iota.cafe";
/// GraphQL URL of the IOTA testnet.
pub const IOTA_TESTNET_GRAPHQL_URL: &str = "https://graphql.testnet.iota.cafe";
/// Gas faucet endpoint of the IOTA testnet.
pub const IOTA_TESTNET_GAS_URL: &str = "https://faucet.testnet.iota.cafe/v1/gas";
/// JSON-RPC URL of the IOTA mainnet; used by
/// [`IotaClientBuilder::build_mainnet`].
pub const IOTA_MAINNET_URL: &str = "https://api.mainnet.iota.cafe";
138
/// Builder for creating an [IotaClient] for connecting to the IOTA network.
///
/// A builder obtained from [`Default::default`] uses a 60 second request
/// timeout and leaves everything else unset: no explicit limit on concurrent
/// requests (so the underlying JSON-RPC client's own default applies), no
/// WebSocket URL, no WebSocket ping interval and no basic-auth credentials.
/// These can be adjusted with [`Self::request_timeout()`],
/// [`Self::max_concurrent_requests()`], [`Self::ws_url()`],
/// [`Self::ws_ping_interval()`] and [`Self::basic_auth()`].
///
/// If you use the WebSocket, consider setting `ws_ping_interval` appropriately
/// to prevent an inactive WS subscription being disconnected due to proxy
/// timeout.
///
/// # Examples
///
/// ```rust,no_run
/// use iota_sdk::IotaClientBuilder;
/// #[tokio::main]
/// async fn main() -> Result<(), anyhow::Error> {
///     let iota = IotaClientBuilder::default()
///         .build("http://127.0.0.1:9000")
///         .await?;
///
///     println!("IOTA local network version: {:?}", iota.api_version());
///     Ok(())
/// }
/// ```
pub struct IotaClientBuilder {
    request_timeout: Duration,
    max_concurrent_requests: Option<usize>,
    ws_url: Option<String>,
    ws_ping_interval: Option<Duration>,
    basic_auth: Option<(String, String)>,
}

impl Default for IotaClientBuilder {
    /// Returns a builder with a 60 second request timeout and every optional
    /// setting left unset.
    fn default() -> Self {
        const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 60;

        Self {
            basic_auth: None,
            ws_ping_interval: None,
            ws_url: None,
            max_concurrent_requests: None,
            request_timeout: Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS),
        }
    }
}
181
182impl IotaClientBuilder {
183    /// Set the request timeout to the specified duration.
184    pub fn request_timeout(mut self, request_timeout: Duration) -> Self {
185        self.request_timeout = request_timeout;
186        self
187    }
188
189    /// Set the max concurrent requests allowed.
190    pub fn max_concurrent_requests(mut self, max_concurrent_requests: usize) -> Self {
191        self.max_concurrent_requests = Some(max_concurrent_requests);
192        self
193    }
194
195    /// Set the WebSocket URL for the IOTA network.
196    pub fn ws_url(mut self, url: impl AsRef<str>) -> Self {
197        self.ws_url = Some(url.as_ref().to_string());
198        self
199    }
200
201    /// Set the WebSocket ping interval.
202    pub fn ws_ping_interval(mut self, duration: Duration) -> Self {
203        self.ws_ping_interval = Some(duration);
204        self
205    }
206
207    /// Set the basic auth credentials for the HTTP client.
208    pub fn basic_auth(mut self, username: impl AsRef<str>, password: impl AsRef<str>) -> Self {
209        self.basic_auth = Some((username.as_ref().to_string(), password.as_ref().to_string()));
210        self
211    }
212
213    /// Return an [IotaClient] object connected to the IOTA network accessible
214    /// via the provided URI.
215    ///
216    /// # Examples
217    ///
218    /// ```rust,no_run
219    /// use iota_sdk::IotaClientBuilder;
220    ///
221    /// #[tokio::main]
222    /// async fn main() -> Result<(), anyhow::Error> {
223    ///     let iota = IotaClientBuilder::default()
224    ///         .build("http://127.0.0.1:9000")
225    ///         .await?;
226    ///
227    ///     println!("IOTA local version: {:?}", iota.api_version());
228    ///     Ok(())
229    /// }
230    /// ```
231    pub async fn build(self, http: impl AsRef<str>) -> IotaRpcResult<IotaClient> {
232        if CryptoProvider::get_default().is_none() {
233            ring::default_provider().install_default().ok();
234        }
235
236        let client_version = env!("CARGO_PKG_VERSION");
237        let mut headers = HeaderMap::new();
238        headers.insert(
239            CLIENT_TARGET_API_VERSION_HEADER,
240            // in rust, the client version is the same as the target api version
241            HeaderValue::from_static(client_version),
242        );
243        headers.insert(
244            CLIENT_SDK_VERSION_HEADER,
245            HeaderValue::from_static(client_version),
246        );
247        headers.insert(CLIENT_SDK_TYPE_HEADER, HeaderValue::from_static("rust"));
248
249        if let Some((username, password)) = self.basic_auth {
250            let auth = base64::engine::general_purpose::STANDARD
251                .encode(format!("{}:{}", username, password));
252            headers.insert(
253                "authorization",
254                // reqwest::header::AUTHORIZATION,
255                HeaderValue::from_str(&format!("Basic {}", auth)).unwrap(),
256            );
257        }
258
259        let ws = if let Some(url) = self.ws_url {
260            let mut builder = WsClientBuilder::default()
261                .max_request_size(2 << 30)
262                .set_headers(headers.clone())
263                .request_timeout(self.request_timeout);
264
265            if let Some(duration) = self.ws_ping_interval {
266                builder = builder.enable_ws_ping(PingConfig::new().ping_interval(duration))
267            }
268
269            if let Some(max_concurrent_requests) = self.max_concurrent_requests {
270                builder = builder.max_concurrent_requests(max_concurrent_requests);
271            }
272
273            builder.build(url).await.ok()
274        } else {
275            None
276        };
277
278        let mut http_builder = HttpClientBuilder::default()
279            .max_request_size(2 << 30)
280            .set_headers(headers)
281            .request_timeout(self.request_timeout);
282
283        if let Some(max_concurrent_requests) = self.max_concurrent_requests {
284            http_builder = http_builder.max_concurrent_requests(max_concurrent_requests);
285        }
286
287        let http = http_builder.build(http)?;
288
289        let info = Self::get_server_info(&http, &ws).await?;
290
291        let rpc = RpcClient { http, ws, info };
292        let api = Arc::new(rpc);
293        let read_api = Arc::new(ReadApi::new(api.clone()));
294        let quorum_driver_api = QuorumDriverApi::new(api.clone());
295        let event_api = EventApi::new(api.clone());
296        let transaction_builder = TransactionBuilder::new(read_api.clone());
297        let coin_read_api = CoinReadApi::new(api.clone());
298        let governance_api = GovernanceApi::new(api.clone());
299
300        Ok(IotaClient {
301            api,
302            transaction_builder,
303            read_api,
304            coin_read_api,
305            event_api,
306            quorum_driver_api,
307            governance_api,
308        })
309    }
310
311    /// Return an [IotaClient] object that is ready to interact with the local
312    /// development network (by default it expects the IOTA network to be up
313    /// and running at `127.0.0.1:9000`).
314    ///
315    /// For connecting to a custom URI, use the `build` function instead.
316    ///
317    /// # Examples
318    ///
319    /// ```rust,no_run
320    /// use iota_sdk::IotaClientBuilder;
321    ///
322    /// #[tokio::main]
323    /// async fn main() -> Result<(), anyhow::Error> {
324    ///     let iota = IotaClientBuilder::default().build_localnet().await?;
325    ///
326    ///     println!("IOTA local version: {:?}", iota.api_version());
327    ///     Ok(())
328    /// }
329    /// ```
330    pub async fn build_localnet(self) -> IotaRpcResult<IotaClient> {
331        self.build(IOTA_LOCAL_NETWORK_URL).await
332    }
333
334    /// Return an [IotaClient] object that is ready to interact with the IOTA
335    /// devnet.
336    ///
337    /// For connecting to a custom URI, use the `build` function instead.
338    ///
339    /// # Examples
340    ///
341    /// ```rust,no_run
342    /// use iota_sdk::IotaClientBuilder;
343    ///
344    /// #[tokio::main]
345    /// async fn main() -> Result<(), anyhow::Error> {
346    ///     let iota = IotaClientBuilder::default().build_devnet().await?;
347    ///
348    ///     println!("{:?}", iota.api_version());
349    ///     Ok(())
350    /// }
351    /// ```
352    pub async fn build_devnet(self) -> IotaRpcResult<IotaClient> {
353        self.build(IOTA_DEVNET_URL).await
354    }
355
356    /// Return an [IotaClient] object that is ready to interact with the IOTA
357    /// testnet.
358    ///
359    /// For connecting to a custom URI, use the `build` function instead.
360    ///
361    /// # Examples
362    ///
363    /// ```rust,no_run
364    /// use iota_sdk::IotaClientBuilder;
365    ///
366    /// #[tokio::main]
367    /// async fn main() -> Result<(), anyhow::Error> {
368    ///     let iota = IotaClientBuilder::default().build_testnet().await?;
369    ///
370    ///     println!("{:?}", iota.api_version());
371    ///     Ok(())
372    /// }
373    /// ```
374    pub async fn build_testnet(self) -> IotaRpcResult<IotaClient> {
375        self.build(IOTA_TESTNET_URL).await
376    }
377
378    /// Returns an [IotaClient] object that is ready to interact with the IOTA
379    /// mainnet.
380    ///
381    /// For connecting to a custom URI, use the `build` function instead.
382    ///
383    /// # Examples
384    ///
385    /// ```rust,no_run
386    /// use iota_sdk::IotaClientBuilder;
387    ///
388    /// #[tokio::main]
389    /// async fn main() -> Result<(), anyhow::Error> {
390    ///     let iota = IotaClientBuilder::default().build_mainnet().await?;
391    ///
392    ///     println!("{:?}", iota.api_version());
393    ///     Ok(())
394    /// }
395    /// ```
396    pub async fn build_mainnet(self) -> IotaRpcResult<IotaClient> {
397        self.build(IOTA_MAINNET_URL).await
398    }
399
400    /// Return the server information as a `ServerInfo` structure.
401    ///
402    /// Fails with an error if it cannot call the RPC discover.
403    async fn get_server_info(
404        http: &HttpClient,
405        ws: &Option<WsClient>,
406    ) -> Result<ServerInfo, Error> {
407        let rpc_spec: Value = http.request("rpc.discover", rpc_params![]).await?;
408        let version = rpc_spec
409            .pointer("/info/version")
410            .and_then(|v| v.as_str())
411            .ok_or_else(|| {
412                Error::Data("Fail parsing server version from rpc.discover endpoint.".into())
413            })?;
414        let rpc_methods = Self::parse_methods(&rpc_spec)?;
415
416        let subscriptions = if let Some(ws) = ws {
417            match ws.request("rpc.discover", rpc_params![]).await {
418                Ok(rpc_spec) => Self::parse_methods(&rpc_spec)?,
419                Err(_) => Vec::new(),
420            }
421        } else {
422            Vec::new()
423        };
424        let iota_system_state_v2_support =
425            rpc_methods.contains(&"iotax_getLatestIotaSystemStateV2".to_string());
426        Ok(ServerInfo {
427            rpc_methods,
428            subscriptions,
429            version: version.to_string(),
430            iota_system_state_v2_support,
431        })
432    }
433
434    fn parse_methods(server_spec: &Value) -> Result<Vec<String>, Error> {
435        let methods = server_spec
436            .pointer("/methods")
437            .and_then(|methods| methods.as_array())
438            .ok_or_else(|| {
439                Error::Data("Fail parsing server information from rpc.discover endpoint.".into())
440            })?;
441
442        Ok(methods
443            .iter()
444            .flat_map(|method| method["name"].as_str())
445            .map(|s| s.into())
446            .collect())
447    }
448}
449
/// Provides all the necessary abstractions for interacting with the IOTA
/// network.
///
/// Cloning an `IotaClient` is cheap in the sense that the underlying RPC
/// connection (`api`) is shared through an `Arc` rather than re-established.
///
/// # Usage
///
/// Use [IotaClientBuilder] to build an [IotaClient].
///
/// # Examples
///
/// ```rust,no_run
/// use std::str::FromStr;
///
/// use iota_sdk::{IotaClientBuilder, types::base_types::IotaAddress};
///
/// #[tokio::main]
/// async fn main() -> Result<(), anyhow::Error> {
///     let iota = IotaClientBuilder::default()
///         .build("http://127.0.0.1:9000")
///         .await?;
///
///     println!("{:?}", iota.available_rpc_methods());
///     println!("{:?}", iota.available_subscriptions());
///     println!("{:?}", iota.api_version());
///
///     // Placeholder: replace with a real, full-length address.
///     let address = IotaAddress::from_str("0x0000....0000")?;
///     let owned_objects = iota
///         .read_api()
///         .get_owned_objects(address, None, None, None)
///         .await?;
///
///     println!("{:?}", owned_objects);
///
///     Ok(())
/// }
/// ```
#[derive(Clone)]
pub struct IotaClient {
    // Shared low-level RPC connection (HTTP, optional WebSocket, server info).
    api: Arc<RpcClient>,
    // Transaction builder that reads chain data through `read_api`.
    transaction_builder: TransactionBuilder,
    // Kept in an `Arc` because it is shared with `transaction_builder`.
    read_api: Arc<ReadApi>,
    coin_read_api: CoinReadApi,
    event_api: EventApi,
    quorum_driver_api: QuorumDriverApi,
    governance_api: GovernanceApi,
}
495
/// The low-level JSON-RPC connection that is shared (via `Arc`) by all API
/// wrappers of an [IotaClient].
pub(crate) struct RpcClient {
    /// HTTP JSON-RPC client.
    http: HttpClient,
    /// WebSocket client; `None` when no WebSocket URL was configured or the
    /// connection attempt failed while building the client.
    ws: Option<WsClient>,
    /// Server details discovered via `rpc.discover` when the client was built.
    info: ServerInfo,
}
501
502impl Debug for RpcClient {
503    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
504        write!(
505            f,
506            "RPC client. Http: {:?}, Websocket: {:?}",
507            self.http, self.ws
508        )
509    }
510}
511
/// Contains all the useful information regarding the API version, the available
/// RPC calls, and subscriptions.
///
/// Populated once, by `IotaClientBuilder::get_server_info`, when the client is
/// built.
struct ServerInfo {
    /// Method names advertised by `rpc.discover` over HTTP.
    rpc_methods: Vec<String>,
    /// Method names advertised by `rpc.discover` over the WebSocket; empty
    /// when there is no WebSocket client or the WebSocket request failed.
    subscriptions: Vec<String>,
    /// Value of `info.version` in the HTTP `rpc.discover` response.
    version: String,
    /// Whether `rpc_methods` contains `iotax_getLatestIotaSystemStateV2`.
    iota_system_state_v2_support: bool,
}
520
521impl IotaClient {
522    /// Return a list of RPC methods supported by the node the client is
523    /// connected to.
524    pub fn available_rpc_methods(&self) -> &Vec<String> {
525        &self.api.info.rpc_methods
526    }
527
528    /// Return a list of streaming/subscription APIs supported by the node the
529    /// client is connected to.
530    pub fn available_subscriptions(&self) -> &Vec<String> {
531        &self.api.info.subscriptions
532    }
533
534    /// Return the API version information as a string.
535    ///
536    /// The format of this string is `<major>.<minor>.<patch>`, e.g., `1.6.0`,
537    /// and it is retrieved from the OpenRPC specification via the discover
538    /// service method.
539    pub fn api_version(&self) -> &str {
540        &self.api.info.version
541    }
542
543    /// Verify if the API version matches the server version and returns an
544    /// error if they do not match.
545    pub fn check_api_version(&self) -> IotaRpcResult<()> {
546        let server_version = self.api_version();
547        let client_version = env!("CARGO_PKG_VERSION");
548        if server_version != client_version {
549            return Err(Error::ServerVersionMismatch {
550                client_version: client_version.to_string(),
551                server_version: server_version.to_string(),
552            });
553        };
554        Ok(())
555    }
556
557    /// Return a reference to the coin read API.
558    pub fn coin_read_api(&self) -> &CoinReadApi {
559        &self.coin_read_api
560    }
561
562    /// Return a reference to the event API.
563    pub fn event_api(&self) -> &EventApi {
564        &self.event_api
565    }
566
567    /// Return a reference to the governance API.
568    pub fn governance_api(&self) -> &GovernanceApi {
569        &self.governance_api
570    }
571
572    /// Return a reference to the quorum driver API.
573    pub fn quorum_driver_api(&self) -> &QuorumDriverApi {
574        &self.quorum_driver_api
575    }
576
577    /// Return a reference to the read API.
578    pub fn read_api(&self) -> &ReadApi {
579        &self.read_api
580    }
581
582    /// Return a reference to the transaction builder API.
583    pub fn transaction_builder(&self) -> &TransactionBuilder {
584        &self.transaction_builder
585    }
586
587    /// Return a reference to the underlying http client.
588    pub fn http(&self) -> &HttpClient {
589        &self.api.http
590    }
591
592    /// Return a reference to the underlying WebSocket client, if any.
593    pub fn ws(&self) -> Option<&WsClient> {
594        self.api.ws.as_ref()
595    }
596}
597
/// Lets [ReadApi] serve as the `DataReader` backing the transaction builder
/// (`IotaClientBuilder::build` passes `read_api` to `TransactionBuilder::new`).
///
/// Each method forwards to the `ReadApi` inherent method of the same name.
/// Inherent methods take precedence over trait methods during method
/// resolution, so the `self.…` calls below are meant to reach those inherent
/// methods rather than recurse into this impl.
#[async_trait]
impl DataReader for ReadApi {
    /// Fetch every object owned by `address` whose type is `object_type`,
    /// following pagination until the server reports no further pages.
    ///
    /// Objects are requested with previous-transaction, type and owner data and
    /// converted into [ObjectInfo]. The first request or conversion error
    /// aborts the listing and is returned.
    async fn get_owned_objects(
        &self,
        address: IotaAddress,
        object_type: StructTag,
    ) -> Result<Vec<ObjectInfo>, anyhow::Error> {
        let query = Some(IotaObjectResponseQuery {
            filter: Some(IotaObjectDataFilter::StructType(object_type)),
            options: Some(
                IotaObjectDataOptions::new()
                    .with_previous_transaction()
                    .with_type()
                    .with_owner(),
            ),
        });

        // Page through the inherent 4-argument `get_owned_objects` (no explicit
        // page limit); `v?` surfaces request errors, `try_into` converts each
        // response into an `ObjectInfo`.
        let result = PagedFn::stream(async |cursor| {
            self.get_owned_objects(address, query.clone(), cursor, None)
                .await
        })
        .map(|v| v?.try_into())
        .try_collect::<Vec<_>>()
        .await?;

        Ok(result)
    }

    /// Fetch a single object with the given data options.
    async fn get_object_with_options(
        &self,
        object_id: ObjectID,
        options: IotaObjectDataOptions,
    ) -> Result<IotaObjectResponse, anyhow::Error> {
        Ok(self.get_object_with_options(object_id, options).await?)
    }

    /// Return the reference gas price as a u64 or an error otherwise
    async fn get_reference_gas_price(&self) -> Result<u64, anyhow::Error> {
        Ok(self.get_reference_gas_price().await?)
    }
}
639
640/// A helper trait for repeatedly calling an async function which returns pages
641/// of data.
642pub trait PagedFn<O, C, F, E>: Sized + Fn(Option<C>) -> F
643where
644    O: Send,
645    C: Send,
646    F: futures::Future<Output = Result<Page<O, C>, E>> + Send,
647{
648    /// Get all items from the source and collect them into a vector.
649    fn collect<T>(self) -> impl futures::Future<Output = Result<T, E>>
650    where
651        T: Default + Extend<O>,
652    {
653        self.stream().try_collect::<T>()
654    }
655
656    /// Get a stream which will return all items from the source.
657    fn stream(self) -> PagedStream<O, C, F, E, Self> {
658        PagedStream::new(self)
659    }
660}
661
662impl<O, C, F, E, Fun> PagedFn<O, C, F, E> for Fun
663where
664    Fun: Fn(Option<C>) -> F,
665    O: Send,
666    C: Send,
667    F: futures::Future<Output = Result<Page<O, C>, E>> + Send,
668{
669}
670
/// A stream that repeatedly calls an async function returning pages of data
/// and yields the items of those pages one by one.
pub struct PagedStream<O, C, F, E, Fun> {
    /// Produces the future for the page at the given cursor.
    fun: Fun,
    /// Future of the page currently being fetched.
    fut: Pin<Box<F>>,
    /// Items already received but not yet yielded.
    next: VecDeque<O>,
    /// Whether another page is expected after the ones received so far.
    has_next_page: bool,
    _data: PhantomData<(E, C)>,
}

impl<O, C, F, E, Fun> PagedStream<O, C, F, E, Fun>
where
    Fun: Fn(Option<C>) -> F,
{
    /// Create a stream over `fun`.
    ///
    /// `fun(None)` is called right away to create the future for the first
    /// page; like any future, it does nothing until the stream is polled.
    pub fn new(fun: Fun) -> Self {
        let first_page = Box::pin(fun(None));
        Self {
            fut: first_page,
            next: VecDeque::new(),
            has_next_page: true,
            _data: PhantomData,
            fun,
        }
    }
}
696
697impl<O, C, F, E, Fun> futures::Stream for PagedStream<O, C, F, E, Fun>
698where
699    O: Send,
700    C: Send,
701    F: futures::Future<Output = Result<Page<O, C>, E>> + Send,
702    Fun: Fn(Option<C>) -> F,
703{
704    type Item = Result<O, E>;
705
706    fn poll_next(
707        self: std::pin::Pin<&mut Self>,
708        cx: &mut std::task::Context<'_>,
709    ) -> Poll<Option<Self::Item>> {
710        let this = unsafe { self.get_unchecked_mut() };
711        if this.next.is_empty() && this.has_next_page {
712            match this.fut.as_mut().poll(cx) {
713                Poll::Ready(res) => match res {
714                    Ok(mut page) => {
715                        this.next.extend(page.data);
716                        this.has_next_page = page.has_next_page;
717                        if this.has_next_page {
718                            this.fut.set((this.fun)(page.next_cursor.take()));
719                        }
720                    }
721                    Err(e) => {
722                        this.has_next_page = false;
723                        return Poll::Ready(Some(Err(e)));
724                    }
725                },
726                Poll::Pending => return Poll::Pending,
727            }
728        }
729        Poll::Ready(this.next.pop_front().map(Ok))
730    }
731}
732
#[cfg(test)]
mod test {
    use iota_json_rpc_types::Page;

    use super::*;

    /// Fake paginated endpoint. The cursor is the index of the first item of
    /// the requested page.
    struct Endpoint {
        items: Vec<i32>,
    }

    impl Endpoint {
        const PAGE_SIZE: usize = 100;

        async fn get_page(&self, cursor: Option<usize>) -> anyhow::Result<Page<i32, usize>> {
            let len = self.items.len();
            if let Some(start) = cursor {
                anyhow::ensure!(start < len, "invalid cursor");
            }
            let start = cursor.unwrap_or(0);
            let end = len.min(start + Self::PAGE_SIZE);
            let has_next_page = end < len;
            Ok(Page {
                data: self.items[start..end].to_vec(),
                next_cursor: if has_next_page { Some(end) } else { None },
                has_next_page,
            })
        }
    }

    #[tokio::test]
    async fn test_get_all_pages() {
        let endpoint = Endpoint {
            items: (0..10000).collect(),
        };

        let mut stream = PagedFn::stream(async |cursor| endpoint.get_page(cursor).await);

        let all_but_last = stream
            .by_ref()
            .take(9999)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();
        assert_eq!(all_but_last, endpoint.items[..9999]);
        assert_eq!(stream.by_ref().try_next().await.unwrap(), Some(9999));
        assert!(stream.try_next().await.unwrap().is_none());

        let mut failing = PagedFn::stream(async |_| endpoint.get_page(Some(99999)).await);
        assert!(failing.try_next().await.is_err());
    }
}
785}