iota_aws_orchestrator/
testbed.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{collections::HashMap, time::Duration};
6
7use futures::future::try_join_all;
8use prettytable::{Table, row};
9use tokio::time::{self, Instant};
10
11use super::client::{Instance, InstanceLifecycle, InstanceRole};
12use crate::{
13    client::ServerProviderClient,
14    display,
15    error::{TestbedError, TestbedResult},
16    settings::Settings,
17    ssh::SshConnection,
18};
19
20/// Represents a testbed running on a cloud provider.
21pub struct Testbed<C> {
22    /// The testbed's settings.
23    settings: Settings,
24    /// The client interfacing with the cloud provider.
25    client: C,
26    /// List of Node instances.
27    node_instances: Vec<Instance>,
28    /// List of dedicated Client instances.
29    client_instances: Option<Vec<Instance>>,
30    /// Dedicated Metrics Instance
31    metrics_instance: Option<Instance>,
32}
33
34impl<C: ServerProviderClient> Testbed<C> {
35    /// Create a new testbed instance with the specified settings and client.
36    pub async fn new(settings: Settings, client: C) -> TestbedResult<Self> {
37        let public_key = settings.load_ssh_public_key()?;
38        client.register_ssh_public_key(public_key).await?;
39        let node_instances = client.list_instances_by_role(InstanceRole::Node).await?;
40        let client_instances = client.list_instances_by_role(InstanceRole::Client).await?;
41        let metrics_instance = client.list_instances_by_role(InstanceRole::Metrics).await?;
42
43        Ok(Self {
44            settings,
45            client,
46            node_instances,
47            client_instances: if client_instances.is_empty() {
48                None
49            } else {
50                Some(client_instances)
51            },
52            metrics_instance: metrics_instance.into_iter().next(),
53        })
54    }
55
56    /// Return the username to connect to the instances through ssh.
57    pub fn username(&self) -> &'static str {
58        C::USERNAME
59    }
60
61    /// Return the list of instances of the testbed.
62    pub fn instances(&self) -> Vec<Instance> {
63        let mut instances = self.node_instances.clone();
64        if let Some(instance) = &self.metrics_instance {
65            instances.push(instance.clone());
66        }
67        if let Some(client_instances) = &self.client_instances {
68            instances.extend(client_instances.clone());
69        }
70        instances
71    }
72    /// Return the list of Node instances of the testbed.
73    pub fn node_instances(&self) -> Vec<Instance> {
74        self.node_instances.clone()
75    }
76    /// Return the list of Client instances of the testbed.
77    pub fn client_instances(&self) -> Vec<Instance> {
78        match &self.client_instances {
79            Some(instances) => instances.clone(),
80            None => self.node_instances.clone(),
81        }
82    }
83    /// Return the Metrics Instance of the testbed.
84    pub fn metrics_instance(&self) -> Option<Instance> {
85        self.metrics_instance.clone()
86    }
87
88    /// Return the list of provider-specific instance setup commands.
89    pub async fn setup_commands(&self) -> TestbedResult<Vec<String>> {
90        self.client
91            .instance_setup_commands()
92            .await
93            .map_err(TestbedError::from)
94    }
95
96    /// Print the current status of the testbed.
97    pub fn status(&self) {
98        let instances_by_region = self.instances().into_iter().fold(
99            HashMap::new(),
100            |mut acc: HashMap<String, Vec<Instance>>, i| {
101                acc.entry(i.region.clone()).or_default().push(i);
102                acc
103            },
104        );
105
106        let mut table = Table::new();
107        table.set_format(display::default_table_format());
108
109        for (i, (region, instances)) in instances_by_region.iter().enumerate() {
110            table.add_row(row![bH2->region.to_uppercase()]);
111            let mut j = 0;
112            for instance in instances {
113                if j % 5 == 0 {
114                    table.add_row(row![]);
115                }
116                let private_key_file = self.settings.ssh_private_key_file.display();
117                let username = C::USERNAME;
118                let ip = instance.main_ip;
119                let private_ip = instance.private_ip;
120                let role = instance.role.to_string();
121                let lifecycle = instance.lifecycle.to_string();
122                let connect = format!(
123                    "[{role:<7}] [{lifecycle:<8}] [{private_ip:<15}] ssh -i {private_key_file} {username}@{ip}"
124                );
125                if !instance.is_terminated() {
126                    if instance.is_active() {
127                        table.add_row(row![bFg->format!("{j}"), connect]);
128                    } else {
129                        table.add_row(row![bFr->format!("{j}"), connect]);
130                    }
131                    j += 1;
132                }
133            }
134            if i != instances_by_region.len() - 1 {
135                table.add_row(row![]);
136            }
137        }
138
139        display::newline();
140        display::config("Client", &self.client);
141        let repo = &self.settings.repository;
142        display::config("Repo", format!("{} ({})", repo.url, repo.commit));
143        display::newline();
144        table.printstd();
145        display::newline();
146    }
147
148    /// Populate the testbed by creating the specified amount of instances per
149    /// region. The total number of instances created is thus the specified
150    /// amount x the number of regions.
151    pub async fn deploy(
152        &mut self,
153        quantity: usize,
154        skip_monitoring: bool,
155        dedicated_clients: usize,
156        use_spot_instances: bool,
157        id: String,
158    ) -> TestbedResult<()> {
159        display::action(format!("Deploying instances ({quantity} per region)"));
160
161        let mut instances: Vec<Instance> = vec![];
162
163        if !skip_monitoring {
164            let metrics_region = self
165                .settings
166                .regions
167                .first()
168                .expect("At least one region must be present")
169                .clone();
170            let metrics_instance = self
171                .client
172                .create_instance(metrics_region, InstanceRole::Metrics, 1, false, id.clone())
173                .await?;
174            instances.extend(metrics_instance);
175        }
176
177        let node_instances = {
178            // Multi-region case — call create_instance per region in parallel
179            let tasks = self.settings.regions.iter().map(|region| {
180                self.client.create_instance(
181                    region.clone(),
182                    InstanceRole::Node,
183                    quantity,
184                    use_spot_instances,
185                    id.clone(),
186                )
187            });
188
189            // Run them all concurrently, flatten Vec<Vec<Instance>> → Vec<Instance>
190            try_join_all(tasks)
191                .await?
192                .into_iter()
193                .flatten()
194                .collect::<Vec<_>>()
195        };
196        instances.extend(node_instances);
197
198        let client_instances = match dedicated_clients {
199            0 => vec![],
200            instance_quantity => {
201                // Multi-region case — call create_instance per region in parallel
202                let tasks = self.settings.regions.iter().map(|region| {
203                    self.client.create_instance(
204                        region.clone(),
205                        InstanceRole::Client,
206                        instance_quantity,
207                        false,
208                        id.clone(),
209                    )
210                });
211
212                // Run them all concurrently, flatten Vec<Vec<Instance>> → Vec<Instance>
213                try_join_all(tasks)
214                    .await?
215                    .into_iter()
216                    .flatten()
217                    .collect::<Vec<_>>()
218            }
219        };
220
221        instances.extend(client_instances);
222
223        // Wait until the instances are booted.
224        if cfg!(not(test)) {
225            self.wait_until_reachable(instances.iter()).await?;
226        }
227        let node_instances = self
228            .client
229            .list_instances_by_role(InstanceRole::Node)
230            .await?;
231        let client_instances = self
232            .client
233            .list_instances_by_role(InstanceRole::Client)
234            .await?;
235        let metrics_instance = self
236            .client
237            .list_instances_by_role(InstanceRole::Metrics)
238            .await?;
239        self.node_instances = node_instances;
240        self.client_instances = if client_instances.is_empty() {
241            None
242        } else {
243            Some(client_instances)
244        };
245        self.metrics_instance = metrics_instance.into_iter().next();
246
247        display::done();
248        Ok(())
249    }
250
251    /// Destroy all instances of the testbed.
252    pub async fn destroy(&mut self, keep_monitoring: bool) -> TestbedResult<()> {
253        let instances_to_destroy = self
254            .instances()
255            .into_iter()
256            .filter(|i| !(keep_monitoring && i.role == InstanceRole::Metrics))
257            .collect::<Vec<_>>();
258        let mut number_of_nodes_to_destroy = 0;
259        let mut number_of_clients_to_destroy = 0;
260        let mut number_of_metrics_to_destroy = 0;
261        for instance in instances_to_destroy.iter() {
262            match instance.role {
263                InstanceRole::Node => {
264                    number_of_nodes_to_destroy += 1;
265                }
266                InstanceRole::Client => {
267                    number_of_clients_to_destroy += 1;
268                }
269                InstanceRole::Metrics => {
270                    number_of_metrics_to_destroy += 1;
271                }
272            }
273        }
274        let confirmation_message = format!(
275            "Confirm you want to destroy the following instances:\n\
276            \n\
277            \tMonitoring Instances: {}\n\
278            \tNode Instances: {}\n\
279            \tClient Instances: {}\n",
280            number_of_metrics_to_destroy, number_of_nodes_to_destroy, number_of_clients_to_destroy,
281        );
282        if cfg!(not(test)) && !display::confirm(confirmation_message) {
283            return Ok(());
284        };
285        display::action("Destroying testbed");
286        self.client
287            .delete_instances(instances_to_destroy.iter())
288            .await?;
289
290        display::done();
291        Ok(())
292    }
293
294    /// Start the specified number of instances in each region. Returns an error
295    /// if there are not enough available instances.
296    pub async fn start(
297        &mut self,
298        quantity: usize,
299        dedicated_clients: usize,
300        skip_monitoring: bool,
301    ) -> TestbedResult<()> {
302        display::action("Booting instances");
303
304        // Gather available instances.
305        let mut available = Vec::new();
306        #[cfg(not(test))]
307        let stopped_node_instances_by_region = self
308            .node_instances()
309            .into_iter()
310            .filter(|i| i.is_stopped())
311            .fold(
312                HashMap::new(),
313                |mut acc: HashMap<String, Vec<Instance>>, i| {
314                    acc.entry(i.region.clone()).or_default().push(i);
315                    acc
316                },
317            );
318        #[cfg(test)]
319        let stopped_node_instances_by_region = self
320            .client
321            .instances()
322            .into_iter()
323            .filter(|i| i.role == InstanceRole::Node)
324            .filter(|i| i.is_stopped())
325            .fold(
326                HashMap::new(),
327                |mut acc: HashMap<String, Vec<Instance>>, i| {
328                    acc.entry(i.region.clone()).or_default().push(i);
329                    acc
330                },
331            );
332        for (_, instances) in stopped_node_instances_by_region {
333            if instances.len() < quantity {
334                return Err(TestbedError::InsufficientCapacity(
335                    quantity - instances.len(),
336                ));
337            }
338            available.extend(instances.into_iter().take(quantity));
339        }
340
341        if !skip_monitoring {
342            if let Some(metrics_instance) = &self.metrics_instance {
343                if metrics_instance.is_stopped() {
344                    available.push(metrics_instance.clone());
345                } else {
346                    return Err(TestbedError::MetricsServerMissing());
347                }
348            }
349        }
350        if dedicated_clients > 0 {
351            if let Some(dedicated_client_nodes) = &self.client_instances {
352                let stopped_client_instances_by_region = dedicated_client_nodes
353                    .iter()
354                    .filter(|i| i.is_stopped())
355                    .fold(
356                        HashMap::new(),
357                        |mut acc: HashMap<String, Vec<Instance>>, i| {
358                            acc.entry(i.region.clone()).or_default().push(i.clone());
359                            acc
360                        },
361                    );
362                for (_, instances) in stopped_client_instances_by_region {
363                    if instances.len() < dedicated_clients {
364                        return Err(TestbedError::InsufficientDedicatedClientCapacity(
365                            dedicated_clients - instances.len(),
366                        ));
367                    }
368                    available.extend(instances.into_iter().take(dedicated_clients));
369                }
370            }
371        }
372
373        // Start instances.
374        self.client.start_instances(available.iter()).await?;
375
376        // Wait until the instances are started.
377        if cfg!(not(test)) {
378            self.wait_until_reachable(available.iter()).await?;
379        }
380        let node_instances = self
381            .client
382            .list_instances_by_role(InstanceRole::Node)
383            .await?;
384        let client_instances = self
385            .client
386            .list_instances_by_role(InstanceRole::Client)
387            .await?;
388        let metrics_instance = self
389            .client
390            .list_instances_by_role(InstanceRole::Metrics)
391            .await?;
392        self.node_instances = node_instances;
393        self.client_instances = if client_instances.is_empty() {
394            None
395        } else {
396            Some(client_instances)
397        };
398        self.metrics_instance = metrics_instance.into_iter().next();
399
400        display::done();
401        Ok(())
402    }
403
404    /// Stop all instances of the testbed.
405    pub async fn stop(&mut self, keep_monitoring: bool) -> TestbedResult<()> {
406        display::action("Stopping instances");
407
408        // Stop all instances.
409        self.client
410            .stop_instances(self.instances().iter().filter(|i| {
411                i.is_active()
412                    && !(i.role == InstanceRole::Metrics && keep_monitoring)
413                    && i.lifecycle == InstanceLifecycle::OnDemand
414            }))
415            .await?;
416
417        // Wait until the instances are stopped.
418        loop {
419            let mut instances = self
420                .client
421                .list_instances_by_role(InstanceRole::Node)
422                .await?;
423            let client_instances = self
424                .client
425                .list_instances_by_role(InstanceRole::Client)
426                .await?;
427            instances.extend(client_instances);
428            if !keep_monitoring {
429                let metrics_instance = self
430                    .client
431                    .list_instances_by_role(InstanceRole::Metrics)
432                    .await?;
433                instances.extend(metrics_instance);
434            }
435
436            if instances.iter().all(|x| x.is_inactive()) {
437                break;
438            }
439        }
440
441        display::done();
442        Ok(())
443    }
444
445    /// Wait until all specified instances are ready to accept ssh connections.
446    async fn wait_until_reachable<'a, I>(&self, instances: I) -> TestbedResult<()>
447    where
448        I: Iterator<Item = &'a Instance> + Clone,
449    {
450        let instance_region_and_ids = instances.fold(
451            HashMap::new(),
452            |mut acc: HashMap<String, Vec<String>>, i| {
453                acc.entry(i.region.clone()).or_default().push(i.id.clone());
454                acc
455            },
456        );
457        let mut interval = time::interval(Duration::from_secs(5));
458        interval.tick().await; // The first tick returns immediately.
459
460        let start = Instant::now();
461        loop {
462            let now = interval.tick().await;
463            let elapsed = now.duration_since(start).as_secs_f64().ceil() as u64;
464            display::status(format!("{elapsed}s"));
465            let instances = self
466                .client
467                .list_instances_by_region_and_ids(&instance_region_and_ids)
468                .await?;
469
470            let futures = instances.iter().map(|instance| {
471                let private_key_file = self.settings.ssh_private_key_file.clone();
472                SshConnection::new(
473                    instance.ssh_address(),
474                    C::USERNAME,
475                    private_key_file,
476                    None,
477                    None,
478                )
479            });
480            if try_join_all(futures).await.is_ok() {
481                break;
482            }
483        }
484        Ok(())
485    }
486}
487
#[cfg(test)]
mod test {
    use crate::{
        client::{InstanceRole, ServerProviderClient, test_client::TestClient},
        settings::Settings,
        testbed::Testbed,
    };

    /// Deploying 5 instances per region yields sequentially numbered Node instances.
    #[tokio::test]
    async fn deploy() {
        let settings = Settings::new_for_test();
        let mut testbed = Testbed::new(settings.clone(), TestClient::new(settings))
            .await
            .unwrap();

        testbed
            .deploy(5, true, 0, false, "test".to_string())
            .await
            .unwrap();

        let expected_total = 5 * testbed.settings.number_of_regions();
        assert_eq!(testbed.node_instances.len(), expected_total);
        for (expected_id, instance) in (0usize..).zip(&testbed.node_instances) {
            assert_eq!(expected_id.to_string(), instance.id);
        }
    }

    /// Destroying a fresh testbed leaves no Node instances behind.
    #[tokio::test]
    async fn destroy() {
        let settings = Settings::new_for_test();
        let mut testbed = Testbed::new(settings.clone(), TestClient::new(settings))
            .await
            .unwrap();

        testbed.destroy(false).await.unwrap();

        assert!(testbed.node_instances.is_empty());
    }

    /// Starting 2 of 5 stopped Node instances per region leaves 3 inactive.
    #[tokio::test]
    async fn start() {
        let settings = Settings::new_for_test();
        let mut testbed = Testbed::new(settings.clone(), TestClient::new(settings))
            .await
            .unwrap();
        testbed
            .deploy(5, true, 0, false, "test".to_string())
            .await
            .unwrap();
        testbed.stop(false).await.unwrap();

        assert!(testbed.start(2, 0, true).await.is_ok());

        for region in &testbed.settings.regions {
            let nodes_in_region: Vec<_> = testbed
                .client
                .instances()
                .into_iter()
                .filter(|i| i.role == InstanceRole::Node && i.region == *region)
                .collect();

            let active = nodes_in_region.iter().filter(|i| i.is_active()).count();
            assert_eq!(active, 2);

            let inactive = nodes_in_region.iter().filter(|i| i.is_inactive()).count();
            assert_eq!(inactive, 3);
        }
    }

    /// Stopping the testbed makes every instance inactive.
    #[tokio::test]
    async fn stop() {
        let settings = Settings::new_for_test();
        let mut testbed = Testbed::new(settings.clone(), TestClient::new(settings))
            .await
            .unwrap();
        testbed
            .deploy(5, true, 0, false, "test".to_string())
            .await
            .unwrap();
        testbed.start(2, 0, true).await.unwrap();

        testbed.stop(false).await.unwrap();

        let instances = testbed.client.instances();
        assert!(instances.iter().all(|x| x.is_inactive()));
    }
}