// iota_aws_orchestrator/client/mod.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    collections::HashMap,
7    fmt::Display,
8    net::{Ipv4Addr, SocketAddr},
9};
10
11use serde::{Deserialize, Serialize};
12
13use super::error::CloudProviderResult;
14
15pub mod aws;
16
17#[derive(Debug, Deserialize, Clone, Eq, PartialEq, Hash)]
18pub enum InstanceRole {
19    Node,
20    Client,
21    Metrics,
22}
23
24impl Display for InstanceRole {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        write!(f, "{self:?}")
27    }
28}
29
30impl From<&str> for InstanceRole {
31    fn from(role: &str) -> Self {
32        match role {
33            "Node" => InstanceRole::Node,
34            "Client" => InstanceRole::Client,
35            "Metrics" => InstanceRole::Metrics,
36            _ => unreachable!(),
37        }
38    }
39}
40
41#[derive(Debug, Deserialize, Clone, Eq, PartialEq, Hash)]
42pub enum InstanceLifecycle {
43    Spot,
44    OnDemand,
45}
46
47impl Display for InstanceLifecycle {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        write!(f, "{self:?}")
50    }
51}
52/// Represents a cloud provider instance.
53#[derive(Debug, Deserialize, Clone, Eq, PartialEq, Hash)]
54pub struct Instance {
55    /// The unique identifier of the instance.
56    pub id: String,
57    /// The region where the instance runs.
58    pub region: String,
59    /// The public ip address of the instance (accessible from anywhere).
60    pub main_ip: Ipv4Addr,
61    /// The public ip address of the instance (accessible from the same VPC).
62    pub private_ip: Ipv4Addr,
63    /// The list of tags associated with the instance.
64    pub tags: Vec<String>,
65    /// The specs of the instance.
66    pub specs: String,
67    /// The current status of the instance.
68    pub status: String,
69    // The role of the instance. "Node" | "Client" | "Metrics"
70    pub role: InstanceRole,
71    // The lifecycle of the instance. "Spot" | "OnDemand"
72    pub lifecycle: InstanceLifecycle,
73}
74
75impl Instance {
76    /// Return whether the instance is active and running.
77    pub fn is_active(&self) -> bool {
78        self.status.to_lowercase() == "running"
79    }
80
81    /// Return whether the instance is inactive and not ready for use.
82    pub fn is_inactive(&self) -> bool {
83        !self.is_active()
84    }
85
86    // Return whether the instance is able to be started
87    pub fn is_stopped(&self) -> bool {
88        self.status.to_lowercase() == "stopped"
89    }
90
91    /// Return whether the instance is terminated and in the process of being
92    /// deleted.
93    pub fn is_terminated(&self) -> bool {
94        self.status.to_lowercase() == "terminated"
95    }
96
97    /// Return the ssh address to connect to the instance.
98    pub fn ssh_address(&self) -> SocketAddr {
99        format!("{}:22", self.main_ip).parse().unwrap()
100    }
101
102    #[cfg(test)]
103    pub fn new_for_test(id: String) -> Self {
104        Self {
105            id,
106            region: Default::default(),
107            main_ip: Ipv4Addr::LOCALHOST,
108            private_ip: Ipv4Addr::LOCALHOST,
109            tags: Default::default(),
110            specs: Default::default(),
111            status: Default::default(),
112            role: InstanceRole::Node,
113            lifecycle: InstanceLifecycle::OnDemand,
114        }
115    }
116}
117
118#[async_trait::async_trait]
119pub trait ServerProviderClient: Display {
120    /// The username used to connect to the instances.
121    const USERNAME: &'static str;
122
123    /// List all existing instances (regardless of their status) filtered by
124    /// role.
125    async fn list_instances_by_role(
126        &self,
127        role: InstanceRole,
128    ) -> CloudProviderResult<Vec<Instance>>;
129
130    async fn list_instances_by_region_and_ids(
131        &self,
132        ids_by_region: &HashMap<String, Vec<String>>,
133    ) -> CloudProviderResult<Vec<Instance>>;
134
135    /// Start the specified instances.
136    async fn start_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
137    where
138        I: Iterator<Item = &'a Instance> + Send;
139
140    /// Halt/Stop the specified instances. We may still be billed for stopped
141    /// instances.
142    async fn stop_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
143    where
144        I: Iterator<Item = &'a Instance> + Send;
145
146    /// Create an instance in a specific region.
147    async fn create_instance<S>(
148        &self,
149        region: S,
150        role: InstanceRole,
151        quantity: usize,
152        use_spot_instances: bool,
153        id: String,
154    ) -> CloudProviderResult<Vec<Instance>>
155    where
156        S: Into<String> + Serialize + Send;
157
158    /// Delete a specific instance. Calling this function ensures we are no
159    /// longer billed for the specified instance.
160    async fn delete_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
161    where
162        I: Iterator<Item = &'a Instance> + Send;
163
164    /// Authorize the provided ssh public key to access machines.
165    async fn register_ssh_public_key(&self, public_key: String) -> CloudProviderResult<()>;
166
167    /// Return provider-specific commands to setup the instance.
168    async fn instance_setup_commands(&self) -> CloudProviderResult<Vec<String>>;
169
170    #[cfg(test)]
171    fn instances(&self) -> Vec<Instance>;
172}
173
#[cfg(test)]
pub mod test_client {
    use std::{
        collections::{HashMap, HashSet},
        fmt::Display,
        net::Ipv4Addr,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Mutex,
        },
    };

    use serde::Serialize;

    use super::{Instance, InstanceLifecycle, InstanceRole, ServerProviderClient};
    use crate::{error::CloudProviderResult, settings::Settings};

    /// In-memory [`ServerProviderClient`] for unit tests: no provider calls
    /// are made, instances are kept in a `Vec` behind a mutex.
    pub struct TestClient {
        settings: Settings,
        instances: Mutex<Vec<Instance>>,
        /// Next instance id to hand out. Monotonic, so ids are never reused,
        /// even after `delete_instances` shrinks the list.
        next_id: AtomicUsize,
    }

    impl TestClient {
        pub fn new(settings: Settings) -> Self {
            Self {
                settings,
                instances: Mutex::new(Vec::new()),
                next_id: AtomicUsize::new(0),
            }
        }

        /// Set `status` on every stored instance whose id is among `instances`.
        fn set_status<'a, I>(&self, instances: I, status: &str)
        where
            I: Iterator<Item = &'a Instance>,
        {
            let ids: HashSet<&str> = instances.map(|x| x.id.as_str()).collect();
            let mut guard = self.instances.lock().unwrap();
            for instance in guard.iter_mut().filter(|x| ids.contains(x.id.as_str())) {
                instance.status = status.to_string();
            }
        }
    }

    impl Display for TestClient {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "TestClient")
        }
    }

    #[async_trait::async_trait]
    impl ServerProviderClient for TestClient {
        const USERNAME: &'static str = "root";

        /// Returns every stored instance; the role filter is intentionally
        /// not applied by this test double.
        async fn list_instances_by_role(
            &self,
            _role: InstanceRole,
        ) -> CloudProviderResult<Vec<Instance>> {
            let guard = self.instances.lock().unwrap();
            Ok(guard.clone())
        }

        async fn list_instances_by_region_and_ids(
            &self,
            ids_by_region: &HashMap<String, Vec<String>>,
        ) -> CloudProviderResult<Vec<Instance>> {
            let guard = self.instances.lock().unwrap();
            let instances_by_ids = guard
                .iter()
                .filter(|x| match ids_by_region.get(&x.region) {
                    Some(ids) => ids.contains(&x.id),
                    None => false,
                })
                .cloned()
                .collect::<Vec<_>>();
            Ok(instances_by_ids)
        }

        async fn start_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
        where
            I: Iterator<Item = &'a Instance> + Send,
        {
            self.set_status(instances, "running");
            Ok(())
        }

        async fn stop_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
        where
            I: Iterator<Item = &'a Instance> + Send,
        {
            self.set_status(instances, "stopped");
            Ok(())
        }

        async fn create_instance<S>(
            &self,
            region: S,
            role: InstanceRole,
            quantity: usize,
            use_spot_instances: bool,
            _id: String,
        ) -> CloudProviderResult<Vec<Instance>>
        where
            S: Into<String> + Serialize + Send,
        {
            let region: String = region.into();
            let lifecycle = if use_spot_instances {
                InstanceLifecycle::Spot
            } else {
                InstanceLifecycle::OnDemand
            };

            let mut guard = self.instances.lock().unwrap();
            let mut created = Vec::with_capacity(quantity);
            for _ in 0..quantity {
                let id = self.next_id.fetch_add(1, Ordering::SeqCst);
                // Map the id onto an address: ids 0..=255 give 0.0.0.<id> as
                // before, and larger ids no longer fail to parse.
                let ip = Ipv4Addr::from(
                    u32::try_from(id).expect("test client ran out of IPv4 addresses"),
                );
                let instance = Instance {
                    id: id.to_string(),
                    region: region.clone(),
                    main_ip: ip,
                    private_ip: ip,
                    tags: Vec::new(),
                    specs: self.settings.node_specs.clone(),
                    status: "running".into(),
                    role: role.clone(),
                    lifecycle: lifecycle.clone(),
                };
                guard.push(instance.clone());
                created.push(instance);
            }

            Ok(created)
        }

        async fn delete_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
        where
            I: Iterator<Item = &'a Instance> + Send,
        {
            let ids_to_delete: HashSet<&str> = instances.map(|x| x.id.as_str()).collect();
            let mut guard = self.instances.lock().unwrap();
            guard.retain(|x| !ids_to_delete.contains(x.id.as_str()));
            Ok(())
        }

        async fn register_ssh_public_key(&self, _public_key: String) -> CloudProviderResult<()> {
            Ok(())
        }

        async fn instance_setup_commands(&self) -> CloudProviderResult<Vec<String>> {
            Ok(Vec::new())
        }

        fn instances(&self) -> Vec<Instance> {
            self.instances.lock().unwrap().clone()
        }
    }
}