iota_aws_orchestrator/
testbed.rs

// Copyright (c) Mysten Labs, Inc.
// Modifications Copyright (c) 2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0
4
use std::{
    collections::{BTreeMap, HashMap, HashSet},
    time::Duration,
};

use futures::future::try_join_all;
use prettytable::{Table, row};
use tokio::time::{self, Instant};

use super::client::{Instance, InstanceLifecycle, InstanceRole};
use crate::{
    client::ServerProviderClient,
    display,
    error::{TestbedError, TestbedResult},
    settings::Settings,
    ssh::SshConnection,
};
19
20/// Represents a testbed running on a cloud provider.
21pub struct Testbed<C> {
22    /// The testbed's settings.
23    settings: Settings,
24    /// The client interfacing with the cloud provider.
25    client: C,
26    /// List of Node instances.
27    node_instances: Vec<Instance>,
28    /// List of dedicated Client instances.
29    client_instances: Option<Vec<Instance>>,
30    /// Dedicated Metrics Instance
31    metrics_instance: Option<Instance>,
32}
33
34impl<C: ServerProviderClient> Testbed<C> {
35    /// Create a new testbed instance with the specified settings and client.
36    pub async fn new(settings: Settings, client: C) -> TestbedResult<Self> {
37        let public_key = settings.load_ssh_public_key()?;
38        client.register_ssh_public_key(public_key).await?;
39        let node_instances = client.list_instances_by_role(InstanceRole::Node).await?;
40        let client_instances = client.list_instances_by_role(InstanceRole::Client).await?;
41        let metrics_instance = client.list_instances_by_role(InstanceRole::Metrics).await?;
42
43        Ok(Self {
44            settings,
45            client,
46            node_instances,
47            client_instances: if client_instances.is_empty() {
48                None
49            } else {
50                Some(client_instances)
51            },
52            metrics_instance: metrics_instance.into_iter().next(),
53        })
54    }
55
56    /// Return the username to connect to the instances through ssh.
57    pub fn username(&self) -> &'static str {
58        C::USERNAME
59    }
60
61    /// Return the list of instances of the testbed.
62    pub fn instances(&self) -> Vec<Instance> {
63        let mut instances = self.node_instances.clone();
64        if let Some(instance) = &self.metrics_instance {
65            instances.push(instance.clone());
66        }
67        if let Some(client_instances) = &self.client_instances {
68            instances.extend(client_instances.clone());
69        }
70        instances
71    }
72    /// Return the list of Node instances of the testbed.
73    pub fn node_instances(&self) -> Vec<Instance> {
74        self.node_instances.clone()
75    }
76    /// Return the list of Client instances of the testbed.
77    pub fn client_instances(&self) -> Vec<Instance> {
78        match &self.client_instances {
79            Some(instances) => instances.clone(),
80            None => self.node_instances.clone(),
81        }
82    }
83    /// Return the Metrics Instance of the testbed.
84    pub fn metrics_instance(&self) -> Option<Instance> {
85        self.metrics_instance.clone()
86    }
87
88    /// Return the list of provider-specific instance setup commands.
89    pub async fn setup_commands(&self) -> TestbedResult<Vec<String>> {
90        self.client
91            .instance_setup_commands()
92            .await
93            .map_err(TestbedError::from)
94    }
95
96    /// Print the current status of the testbed.
97    pub fn status(&self) {
98        let instances_by_region = self.instances().into_iter().fold(
99            HashMap::new(),
100            |mut acc: HashMap<String, Vec<Instance>>, i| {
101                acc.entry(i.region.clone()).or_default().push(i);
102                acc
103            },
104        );
105
106        let mut table = Table::new();
107        table.set_format(display::default_table_format());
108
109        for (i, (region, instances)) in instances_by_region.iter().enumerate() {
110            table.add_row(row![bH2->region.to_uppercase()]);
111            let mut j = 0;
112            for instance in instances {
113                if j % 5 == 0 {
114                    table.add_row(row![]);
115                }
116                let private_key_file = self.settings.ssh_private_key_file.display();
117                let username = C::USERNAME;
118                let ip = instance.main_ip;
119                let private_ip = instance.private_ip;
120                let role = instance.role.to_string();
121                let lifecycle = instance.lifecycle.to_string();
122                let connect = format!(
123                    "[{role:<7}] [{lifecycle:<8}] [{private_ip:<15}] ssh -i {private_key_file} {username}@{ip}"
124                );
125                if !instance.is_terminated() {
126                    if instance.is_active() {
127                        table.add_row(row![bFg->format!("{j}"), connect]);
128                    } else {
129                        table.add_row(row![bFr->format!("{j}"), connect]);
130                    }
131                    j += 1;
132                }
133            }
134            if i != instances_by_region.len() - 1 {
135                table.add_row(row![]);
136            }
137        }
138
139        display::newline();
140        display::config("Client", &self.client);
141        let repo = &self.settings.repository;
142        display::config("Repo", format!("{} ({})", repo.url, repo.commit));
143        display::newline();
144        table.printstd();
145        display::newline();
146    }
147
148    /// Populate the testbed by creating the specified amount of instances per
149    /// region. The total number of instances created is thus the specified
150    /// amount x the number of regions.
151    pub async fn deploy(
152        &mut self,
153        quantity: usize,
154        skip_monitoring: bool,
155        dedicated_clients: usize,
156        use_spot_instances: bool,
157        id: String,
158    ) -> TestbedResult<()> {
159        display::action(format!("Deploying instances ({quantity} per region)"));
160
161        let mut instances: Vec<Instance> = vec![];
162
163        if !skip_monitoring {
164            let metrics_region = self
165                .settings
166                .regions
167                .first()
168                .expect("At least one region must be present")
169                .clone();
170            let metrics_instance = self
171                .client
172                .create_instance(metrics_region, InstanceRole::Metrics, 1, false, id.clone())
173                .await?;
174            instances.extend(metrics_instance);
175        }
176
177        let node_instances = {
178            // Multi-region case — call create_instance per region in parallel
179            let tasks = self.settings.regions.iter().map(|region| {
180                self.client.create_instance(
181                    region.clone(),
182                    InstanceRole::Node,
183                    quantity,
184                    use_spot_instances,
185                    id.clone(),
186                )
187            });
188
189            // Run them all concurrently, flatten Vec<Vec<Instance>> → Vec<Instance>
190            try_join_all(tasks)
191                .await?
192                .into_iter()
193                .flatten()
194                .collect::<Vec<_>>()
195        };
196        instances.extend(node_instances);
197
198        let client_instances = match dedicated_clients {
199            0 => vec![],
200            instance_quantity => {
201                // Multi-region case — call create_instance per region in parallel
202                let tasks = self.settings.regions.iter().map(|region| {
203                    self.client.create_instance(
204                        region.clone(),
205                        InstanceRole::Client,
206                        instance_quantity,
207                        false,
208                        id.clone(),
209                    )
210                });
211
212                // Run them all concurrently, flatten Vec<Vec<Instance>> → Vec<Instance>
213                try_join_all(tasks)
214                    .await?
215                    .into_iter()
216                    .flatten()
217                    .collect::<Vec<_>>()
218            }
219        };
220
221        instances.extend(client_instances);
222
223        // Wait until the instances are booted.
224        if cfg!(not(test)) {
225            self.wait_until_reachable(instances.iter()).await?;
226        }
227        let node_instances = self
228            .client
229            .list_instances_by_role(InstanceRole::Node)
230            .await?;
231        let client_instances = self
232            .client
233            .list_instances_by_role(InstanceRole::Client)
234            .await?;
235        let metrics_instance = self
236            .client
237            .list_instances_by_role(InstanceRole::Metrics)
238            .await?;
239        self.node_instances = node_instances;
240        self.client_instances = if client_instances.is_empty() {
241            None
242        } else {
243            Some(client_instances)
244        };
245        self.metrics_instance = metrics_instance.into_iter().next();
246
247        display::action("Deployment completed\n\n");
248        display::done();
249        Ok(())
250    }
251
252    /// Destroy all instances of the testbed.
253    pub async fn destroy(&mut self, keep_monitoring: bool, force: bool) -> TestbedResult<()> {
254        let instances_to_destroy = self
255            .instances()
256            .into_iter()
257            .filter(|i| !(keep_monitoring && i.role == InstanceRole::Metrics))
258            .collect::<Vec<_>>();
259        let mut number_of_nodes_to_destroy = 0;
260        let mut number_of_clients_to_destroy = 0;
261        let mut number_of_metrics_to_destroy = 0;
262        for instance in instances_to_destroy.iter() {
263            match instance.role {
264                InstanceRole::Node => {
265                    number_of_nodes_to_destroy += 1;
266                }
267                InstanceRole::Client => {
268                    number_of_clients_to_destroy += 1;
269                }
270                InstanceRole::Metrics => {
271                    number_of_metrics_to_destroy += 1;
272                }
273            }
274        }
275        let confirmation_message = format!(
276            "Confirm you want to destroy the following instances:\n\
277            \n\
278            \tMonitoring Instances: {}\n\
279            \tNode Instances: {}\n\
280            \tClient Instances: {}\n",
281            number_of_metrics_to_destroy, number_of_nodes_to_destroy, number_of_clients_to_destroy,
282        );
283        if cfg!(not(test)) && !force && !display::confirm(confirmation_message) {
284            return Ok(());
285        };
286        display::action("Destroying testbed");
287        self.client
288            .delete_instances(instances_to_destroy.iter())
289            .await?;
290
291        display::done();
292        Ok(())
293    }
294
295    /// Start the specified number of instances in each region. Returns an error
296    /// if there are not enough available instances.
297    pub async fn start(
298        &mut self,
299        quantity: usize,
300        dedicated_clients: usize,
301        skip_monitoring: bool,
302    ) -> TestbedResult<()> {
303        display::action("Booting instances");
304
305        // Gather available instances.
306        let mut available = Vec::new();
307        #[cfg(not(test))]
308        let stopped_node_instances_by_region = self
309            .node_instances()
310            .into_iter()
311            .filter(|i| i.is_stopped())
312            .fold(
313                HashMap::new(),
314                |mut acc: HashMap<String, Vec<Instance>>, i| {
315                    acc.entry(i.region.clone()).or_default().push(i);
316                    acc
317                },
318            );
319        #[cfg(test)]
320        let stopped_node_instances_by_region = self
321            .client
322            .instances()
323            .into_iter()
324            .filter(|i| i.role == InstanceRole::Node)
325            .filter(|i| i.is_stopped())
326            .fold(
327                HashMap::new(),
328                |mut acc: HashMap<String, Vec<Instance>>, i| {
329                    acc.entry(i.region.clone()).or_default().push(i);
330                    acc
331                },
332            );
333        for (_, instances) in stopped_node_instances_by_region {
334            if instances.len() < quantity {
335                return Err(TestbedError::InsufficientCapacity(
336                    quantity - instances.len(),
337                ));
338            }
339            available.extend(instances.into_iter().take(quantity));
340        }
341
342        if !skip_monitoring {
343            if let Some(metrics_instance) = &self.metrics_instance {
344                if metrics_instance.is_stopped() {
345                    available.push(metrics_instance.clone());
346                } else {
347                    return Err(TestbedError::MetricsServerMissing());
348                }
349            }
350        }
351        if dedicated_clients > 0 {
352            if let Some(dedicated_client_nodes) = &self.client_instances {
353                let stopped_client_instances_by_region = dedicated_client_nodes
354                    .iter()
355                    .filter(|i| i.is_stopped())
356                    .fold(
357                        HashMap::new(),
358                        |mut acc: HashMap<String, Vec<Instance>>, i| {
359                            acc.entry(i.region.clone()).or_default().push(i.clone());
360                            acc
361                        },
362                    );
363                for (_, instances) in stopped_client_instances_by_region {
364                    if instances.len() < dedicated_clients {
365                        return Err(TestbedError::InsufficientDedicatedClientCapacity(
366                            dedicated_clients - instances.len(),
367                        ));
368                    }
369                    available.extend(instances.into_iter().take(dedicated_clients));
370                }
371            }
372        }
373
374        // Start instances.
375        self.client.start_instances(available.iter()).await?;
376
377        // Wait until the instances are started.
378        if cfg!(not(test)) {
379            self.wait_until_reachable(available.iter()).await?;
380        }
381        let node_instances = self
382            .client
383            .list_instances_by_role(InstanceRole::Node)
384            .await?;
385        let client_instances = self
386            .client
387            .list_instances_by_role(InstanceRole::Client)
388            .await?;
389        let metrics_instance = self
390            .client
391            .list_instances_by_role(InstanceRole::Metrics)
392            .await?;
393        self.node_instances = node_instances;
394        self.client_instances = if client_instances.is_empty() {
395            None
396        } else {
397            Some(client_instances)
398        };
399        self.metrics_instance = metrics_instance.into_iter().next();
400
401        display::done();
402        Ok(())
403    }
404
405    /// Stop all instances of the testbed.
406    pub async fn stop(&mut self, keep_monitoring: bool) -> TestbedResult<()> {
407        display::action("Stopping instances");
408
409        // Stop all instances.
410        self.client
411            .stop_instances(self.instances().iter().filter(|i| {
412                i.is_active()
413                    && !(i.role == InstanceRole::Metrics && keep_monitoring)
414                    && i.lifecycle == InstanceLifecycle::OnDemand
415            }))
416            .await?;
417
418        // Wait until the instances are stopped.
419        loop {
420            let mut instances = self
421                .client
422                .list_instances_by_role(InstanceRole::Node)
423                .await?;
424            let client_instances = self
425                .client
426                .list_instances_by_role(InstanceRole::Client)
427                .await?;
428            instances.extend(client_instances);
429            if !keep_monitoring {
430                let metrics_instance = self
431                    .client
432                    .list_instances_by_role(InstanceRole::Metrics)
433                    .await?;
434                instances.extend(metrics_instance);
435            }
436
437            if instances.iter().all(|x| x.is_inactive()) {
438                break;
439            }
440        }
441
442        display::done();
443        Ok(())
444    }
445
446    /// Wait until all specified instances are ready to accept ssh connections.
447    async fn wait_until_reachable<'a, I>(&self, instances: I) -> TestbedResult<()>
448    where
449        I: Iterator<Item = &'a Instance> + Clone,
450    {
451        let instance_region_and_ids = instances.fold(
452            HashMap::new(),
453            |mut acc: HashMap<String, Vec<String>>, i| {
454                acc.entry(i.region.clone()).or_default().push(i.id.clone());
455                acc
456            },
457        );
458        let mut interval = time::interval(Duration::from_secs(5));
459        interval.tick().await; // The first tick returns immediately.
460
461        let start = Instant::now();
462        loop {
463            let now = interval.tick().await;
464            let elapsed = now.duration_since(start).as_secs_f64().ceil() as u64;
465            display::status(format!("{elapsed}s"));
466            let instances = self
467                .client
468                .list_instances_by_region_and_ids(&instance_region_and_ids)
469                .await?;
470
471            let futures = instances.iter().map(|instance| {
472                let private_key_file = self.settings.ssh_private_key_file.clone();
473                SshConnection::new(
474                    instance.ssh_address(),
475                    C::USERNAME,
476                    private_key_file,
477                    None,
478                    None,
479                )
480            });
481            if try_join_all(futures).await.is_ok() {
482                break;
483            }
484        }
485        Ok(())
486    }
487}
488
#[cfg(test)]
mod test {
    use crate::{
        client::{InstanceRole, ServerProviderClient, test_client::TestClient},
        settings::Settings,
        testbed::Testbed,
    };

    /// Build a testbed backed by a fresh `TestClient`.
    async fn new_testbed() -> Testbed<TestClient> {
        let settings = Settings::new_for_test();
        let client = TestClient::new(settings.clone());
        Testbed::new(settings, client).await.unwrap()
    }

    /// Build a testbed and deploy five node instances per region (no
    /// monitoring, no dedicated clients, no spot instances).
    async fn deployed_testbed() -> Testbed<TestClient> {
        let mut testbed = new_testbed().await;
        testbed
            .deploy(5, true, 0, false, "test".to_string())
            .await
            .unwrap();
        testbed
    }

    #[tokio::test]
    async fn deploy() {
        let testbed = deployed_testbed().await;

        let expected = 5 * testbed.settings.number_of_regions();
        assert_eq!(testbed.node_instances.len(), expected);
        // Node ids are expected to be sequential, starting at 0.
        for (index, instance) in testbed.node_instances.iter().enumerate() {
            assert_eq!(index.to_string(), instance.id);
        }
    }

    #[tokio::test]
    async fn destroy() {
        let mut testbed = new_testbed().await;

        testbed.destroy(false, true).await.unwrap();

        assert!(testbed.node_instances.is_empty());
    }

    #[tokio::test]
    async fn start() {
        let mut testbed = deployed_testbed().await;
        testbed.stop(false).await.unwrap();

        let outcome = testbed.start(2, 0, true).await;
        assert!(outcome.is_ok());

        let nodes: Vec<_> = testbed
            .client
            .instances()
            .into_iter()
            .filter(|i| i.role == InstanceRole::Node)
            .collect();
        for region in &testbed.settings.regions {
            let active = nodes
                .iter()
                .filter(|i| &i.region == region && i.is_active())
                .count();
            let inactive = nodes
                .iter()
                .filter(|i| &i.region == region && i.is_inactive())
                .count();
            assert_eq!(active, 2);
            assert_eq!(inactive, 3);
        }
    }

    #[tokio::test]
    async fn stop() {
        let mut testbed = deployed_testbed().await;
        testbed.start(2, 0, true).await.unwrap();

        testbed.stop(false).await.unwrap();

        let instances = testbed.client.instances();
        assert!(instances.iter().all(|instance| instance.is_inactive()));
    }
}