// iota_aws_orchestrator/client/mod.rs

use std::{
    collections::HashMap,
    fmt::Display,
    net::{Ipv4Addr, SocketAddr},
};

use serde::{Deserialize, Serialize};

use super::error::CloudProviderResult;

pub mod aws;

17#[derive(Debug, Deserialize, Clone, Eq, PartialEq, Hash)]
18pub enum InstanceRole {
19 Node,
20 Client,
21 Metrics,
22}
23
24impl Display for InstanceRole {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 write!(f, "{self:?}")
27 }
28}
29
30impl From<&str> for InstanceRole {
31 fn from(role: &str) -> Self {
32 match role {
33 "Node" => InstanceRole::Node,
34 "Client" => InstanceRole::Client,
35 "Metrics" => InstanceRole::Metrics,
36 _ => unreachable!(),
37 }
38 }
39}
40
41#[derive(Debug, Deserialize, Clone, Eq, PartialEq, Hash)]
42pub enum InstanceLifecycle {
43 Spot,
44 OnDemand,
45}
46
47impl Display for InstanceLifecycle {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 write!(f, "{self:?}")
50 }
51}
/// A machine as reported by a [`ServerProviderClient`].
#[derive(Debug, Deserialize, Clone, Eq, PartialEq, Hash)]
pub struct Instance {
    /// Identifier of the instance; the test client matches on it in
    /// start/stop/delete.
    pub id: String,
    /// Region the instance lives in; the map key used by
    /// `list_instances_by_region_and_ids`.
    pub region: String,
    /// Address used to reach the instance; `ssh_address` targets it on port 22.
    pub main_ip: Ipv4Addr,
    /// Private address of the instance (presumably only routable inside the
    /// provider's network — TODO confirm).
    pub private_ip: Ipv4Addr,
    /// Tags attached to the instance.
    pub tags: Vec<String>,
    /// Machine specification string; the test client copies
    /// `Settings::node_specs` into it.
    pub specs: String,
    /// Raw status string reported by the provider; compared case-insensitively
    /// by `is_active`, `is_stopped` and `is_terminated`.
    pub status: String,
    /// Role the instance was created for.
    pub role: InstanceRole,
    /// Whether the instance is a spot or an on-demand instance.
    pub lifecycle: InstanceLifecycle,
}

75impl Instance {
76 pub fn is_active(&self) -> bool {
78 self.status.to_lowercase() == "running"
79 }
80
81 pub fn is_inactive(&self) -> bool {
83 !self.is_active()
84 }
85
86 pub fn is_stopped(&self) -> bool {
88 self.status.to_lowercase() == "stopped"
89 }
90
91 pub fn is_terminated(&self) -> bool {
94 self.status.to_lowercase() == "terminated"
95 }
96
97 pub fn ssh_address(&self) -> SocketAddr {
99 format!("{}:22", self.main_ip).parse().unwrap()
100 }
101
102 #[cfg(test)]
103 pub fn new_for_test(id: String) -> Self {
104 Self {
105 id,
106 region: Default::default(),
107 main_ip: Ipv4Addr::LOCALHOST,
108 private_ip: Ipv4Addr::LOCALHOST,
109 tags: Default::default(),
110 specs: Default::default(),
111 status: Default::default(),
112 role: InstanceRole::Node,
113 lifecycle: InstanceLifecycle::OnDemand,
114 }
115 }
116}
117
/// Abstraction over a cloud provider that can list, create, start, stop and
/// delete [`Instance`]s.
///
/// Methods are async through `async_trait`; implementors must also be
/// `Display` (the test client prints as `TestClient`).
#[async_trait::async_trait]
pub trait ServerProviderClient: Display {
    /// Login user name for the provider's instances (presumably used for SSH
    /// sessions — confirm with the real implementations). The test client
    /// uses `root`.
    const USERNAME: &'static str;

    /// Lists the instances that have the given `role`.
    ///
    /// NOTE(review): the in-memory test client ignores `role` and returns all
    /// of its instances.
    async fn list_instances_by_role(
        &self,
        role: InstanceRole,
    ) -> CloudProviderResult<Vec<Instance>>;

    /// Lists the instances whose id is listed in `ids_by_region` under the
    /// instance's region (key = region, value = instance ids in that region).
    async fn list_instances_by_region_and_ids(
        &self,
        ids_by_region: &HashMap<String, Vec<String>>,
    ) -> CloudProviderResult<Vec<Instance>>;

    /// Starts the given instances (the test client marks them `running`).
    async fn start_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
    where
        I: Iterator<Item = &'a Instance> + Send;

    /// Stops the given instances (the test client marks them `stopped`).
    async fn stop_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
    where
        I: Iterator<Item = &'a Instance> + Send;

    /// Creates `quantity` instances with `role` in `region` and returns them.
    ///
    /// `use_spot_instances` selects [`InstanceLifecycle::Spot`] instead of
    /// [`InstanceLifecycle::OnDemand`] (as the test client does). The meaning
    /// of `id` is provider-specific — the test client ignores it; TODO confirm.
    async fn create_instance<S>(
        &self,
        region: S,
        role: InstanceRole,
        quantity: usize,
        use_spot_instances: bool,
        id: String,
    ) -> CloudProviderResult<Vec<Instance>>
    where
        S: Into<String> + Serialize + Send;

    /// Deletes the given instances (the test client drops them from its list).
    async fn delete_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
    where
        I: Iterator<Item = &'a Instance> + Send;

    /// Registers an SSH public key with the provider (presumably so new
    /// instances accept it — confirm). The test client does nothing.
    async fn register_ssh_public_key(&self, public_key: String) -> CloudProviderResult<()>;

    /// Returns provider-specific commands for setting up an instance
    /// (presumably run after creation — confirm with callers). The test
    /// client returns none.
    async fn instance_setup_commands(&self) -> CloudProviderResult<Vec<String>>;

    /// Test-only snapshot of every instance the client currently tracks.
    #[cfg(test)]
    fn instances(&self) -> Vec<Instance>;
}

#[cfg(test)]
pub mod test_client {
    //! In-memory [`ServerProviderClient`] for unit tests: instances live in a
    //! vector behind a mutex and no provider is ever contacted.

    use std::{
        collections::{HashMap, HashSet},
        fmt::Display,
        net::Ipv4Addr,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Mutex,
        },
    };

    use serde::Serialize;

    use super::{Instance, InstanceLifecycle, InstanceRole, ServerProviderClient};
    use crate::{error::CloudProviderResult, settings::Settings};

    /// Fake provider backed by an in-memory instance list.
    pub struct TestClient {
        settings: Settings,
        instances: Mutex<Vec<Instance>>,
        /// Next id handed out by `create_instance`. A monotonic counter (rather
        /// than the current list length) keeps ids unique after deletions.
        next_id: AtomicUsize,
    }

    impl TestClient {
        /// Creates a client that tracks no instances yet.
        pub fn new(settings: Settings) -> Self {
            Self {
                settings,
                instances: Mutex::new(Vec::new()),
                next_id: AtomicUsize::new(0),
            }
        }

        /// Sets `status` on every stored instance whose id is among `targets`.
        fn set_status<'a, I>(&self, targets: I, status: &str)
        where
            I: Iterator<Item = &'a Instance>,
        {
            let ids: HashSet<String> = targets.map(|x| x.id.clone()).collect();
            let mut guard = self.instances.lock().unwrap();
            for instance in guard.iter_mut().filter(|x| ids.contains(&x.id)) {
                instance.status = status.into();
            }
        }

        /// Deterministic fake address for instance number `id`: `0.0.0.id` for
        /// ids below 256, carrying into the higher octets beyond that (the
        /// previous string-based form panicked there).
        fn fake_ip(id: usize) -> Ipv4Addr {
            let id = u32::try_from(id).expect("test client ran out of IPv4 addresses");
            Ipv4Addr::from(id)
        }
    }

    impl Display for TestClient {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "TestClient")
        }
    }

    #[async_trait::async_trait]
    impl ServerProviderClient for TestClient {
        const USERNAME: &'static str = "root";

        /// Returns every tracked instance; the role filter is deliberately
        /// ignored by this fake.
        async fn list_instances_by_role(
            &self,
            _role: InstanceRole,
        ) -> CloudProviderResult<Vec<Instance>> {
            let guard = self.instances.lock().unwrap();
            Ok(guard.clone())
        }

        /// Returns the tracked instances whose id is listed under their region.
        async fn list_instances_by_region_and_ids(
            &self,
            ids_by_region: &HashMap<String, Vec<String>>,
        ) -> CloudProviderResult<Vec<Instance>> {
            let guard = self.instances.lock().unwrap();
            let selected = guard
                .iter()
                .filter(|x| {
                    ids_by_region
                        .get(&x.region)
                        .is_some_and(|ids| ids.contains(&x.id))
                })
                .cloned()
                .collect();
            Ok(selected)
        }

        /// Marks the given instances as `running`.
        async fn start_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
        where
            I: Iterator<Item = &'a Instance> + Send,
        {
            self.set_status(instances, "running");
            Ok(())
        }

        /// Marks the given instances as `stopped`.
        async fn stop_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
        where
            I: Iterator<Item = &'a Instance> + Send,
        {
            self.set_status(instances, "stopped");
            Ok(())
        }

        /// Creates `quantity` running instances with fresh, unique ids and
        /// returns copies of them. The `_id` argument is ignored.
        async fn create_instance<S>(
            &self,
            region: S,
            role: InstanceRole,
            quantity: usize,
            use_spot_instances: bool,
            _id: String,
        ) -> CloudProviderResult<Vec<Instance>>
        where
            S: Into<String> + Serialize + Send,
        {
            let region = region.into();
            let lifecycle = if use_spot_instances {
                InstanceLifecycle::Spot
            } else {
                InstanceLifecycle::OnDemand
            };

            // Held for the whole creation so ids appear in the list in order.
            let mut guard = self.instances.lock().unwrap();
            let created: Vec<Instance> = (0..quantity)
                .map(|_| {
                    // `fetch_add` is atomic under any ordering, and the counter is
                    // only bumped while `instances` is locked, so `Relaxed` suffices.
                    let id = self.next_id.fetch_add(1, Ordering::Relaxed);
                    let ip = Self::fake_ip(id);
                    Instance {
                        id: id.to_string(),
                        region: region.clone(),
                        main_ip: ip,
                        private_ip: ip,
                        tags: Vec::new(),
                        specs: self.settings.node_specs.clone(),
                        status: "running".into(),
                        role: role.clone(),
                        lifecycle: lifecycle.clone(),
                    }
                })
                .collect();
            guard.extend(created.iter().cloned());

            Ok(created)
        }

        /// Forgets the given instances.
        async fn delete_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
        where
            I: Iterator<Item = &'a Instance> + Send,
        {
            let doomed: HashSet<String> = instances.map(|x| x.id.clone()).collect();
            self.instances
                .lock()
                .unwrap()
                .retain(|x| !doomed.contains(&x.id));
            Ok(())
        }

        async fn register_ssh_public_key(&self, _public_key: String) -> CloudProviderResult<()> {
            Ok(())
        }

        async fn instance_setup_commands(&self) -> CloudProviderResult<Vec<String>> {
            Ok(Vec::new())
        }

        fn instances(&self) -> Vec<Instance> {
            self.instances.lock().unwrap().clone()
        }
    }
}