iota_aws_orchestrator/
testbed.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::time::Duration;
6
7use futures::future::try_join_all;
8use prettytable::{Table, row};
9use tokio::time::{self, Instant};
10
11use super::client::Instance;
12use crate::{
13    client::ServerProviderClient,
14    display,
15    error::{TestbedError, TestbedResult},
16    settings::Settings,
17    ssh::SshConnection,
18};
19
20/// Represents a testbed running on a cloud provider.
21pub struct Testbed<C> {
22    /// The testbed's settings.
23    settings: Settings,
24    /// The client interfacing with the cloud provider.
25    client: C,
26    /// The state of the testbed (reflecting accurately the state of the
27    /// machines).
28    instances: Vec<Instance>,
29}
30
31impl<C: ServerProviderClient> Testbed<C> {
32    /// Create a new testbed instance with the specified settings and client.
33    pub async fn new(settings: Settings, client: C) -> TestbedResult<Self> {
34        let public_key = settings.load_ssh_public_key()?;
35        client.register_ssh_public_key(public_key).await?;
36        let instances = client.list_instances().await?;
37
38        Ok(Self {
39            settings,
40            client,
41            instances,
42        })
43    }
44
45    /// Return the username to connect to the instances through ssh.
46    pub fn username(&self) -> &'static str {
47        C::USERNAME
48    }
49
50    /// Return the list of instances of the testbed.
51    pub fn instances(&self) -> Vec<Instance> {
52        self.instances
53            .iter()
54            .filter(|x| self.settings.filter_instances(x))
55            .cloned()
56            .collect()
57    }
58
59    /// Return the list of provider-specific instance setup commands.
60    pub async fn setup_commands(&self) -> TestbedResult<Vec<String>> {
61        self.client
62            .instance_setup_commands()
63            .await
64            .map_err(TestbedError::from)
65    }
66
67    /// Print the current status of the testbed.
68    pub fn status(&self) {
69        let filtered = self
70            .instances
71            .iter()
72            .filter(|instance| self.settings.filter_instances(instance));
73        let sorted: Vec<(_, Vec<_>)> = self
74            .settings
75            .regions
76            .iter()
77            .map(|region| {
78                (
79                    region,
80                    filtered
81                        .clone()
82                        .filter(|instance| &instance.region == region)
83                        .collect(),
84                )
85            })
86            .collect();
87
88        let mut table = Table::new();
89        table.set_format(display::default_table_format());
90
91        let active = filtered.filter(|x| x.is_active()).count();
92        table.set_titles(row![bH2->format!("Instances ({active})")]);
93        for (i, (region, instances)) in sorted.iter().enumerate() {
94            table.add_row(row![bH2->region.to_uppercase()]);
95            let mut j = 0;
96            for instance in instances {
97                if j % 5 == 0 {
98                    table.add_row(row![]);
99                }
100                let private_key_file = self.settings.ssh_private_key_file.display();
101                let username = C::USERNAME;
102                let ip = instance.main_ip;
103                let connect = format!("ssh -i {private_key_file} {username}@{ip}");
104                if !instance.is_terminated() {
105                    if instance.is_active() {
106                        table.add_row(row![bFg->format!("{j}"), connect]);
107                    } else {
108                        table.add_row(row![bFr->format!("{j}"), connect]);
109                    }
110                    j += 1;
111                }
112            }
113            if i != sorted.len() - 1 {
114                table.add_row(row![]);
115            }
116        }
117
118        display::newline();
119        display::config("Client", &self.client);
120        let repo = &self.settings.repository;
121        display::config("Repo", format!("{} ({})", repo.url, repo.commit));
122        display::newline();
123        table.printstd();
124        display::newline();
125    }
126
127    /// Populate the testbed by creating the specified amount of instances per
128    /// region. The total number of instances created is thus the specified
129    /// amount x the number of regions.
130    pub async fn deploy(&mut self, quantity: usize, region: Option<String>) -> TestbedResult<()> {
131        display::action(format!("Deploying instances ({quantity} per region)"));
132
133        let instances = match region {
134            Some(x) => {
135                try_join_all((0..quantity).map(|_| self.client.create_instance(x.clone()))).await?
136            }
137            None => {
138                try_join_all(self.settings.regions.iter().flat_map(|region| {
139                    (0..quantity).map(|_| self.client.create_instance(region.clone()))
140                }))
141                .await?
142            }
143        };
144
145        // Wait until the instances are booted.
146        if cfg!(not(test)) {
147            self.wait_until_reachable(instances.iter()).await?;
148        }
149        self.instances = self.client.list_instances().await?;
150
151        display::done();
152        Ok(())
153    }
154
155    /// Destroy all instances of the testbed.
156    pub async fn destroy(&mut self) -> TestbedResult<()> {
157        display::action("Destroying testbed");
158
159        try_join_all(
160            self.instances
161                .drain(..)
162                .map(|instance| self.client.delete_instance(instance)),
163        )
164        .await?;
165
166        display::done();
167        Ok(())
168    }
169
170    /// Start the specified number of instances in each region. Returns an error
171    /// if there are not enough available instances.
172    pub async fn start(&mut self, quantity: usize) -> TestbedResult<()> {
173        display::action("Booting instances");
174
175        // Gather available instances.
176        let mut available = Vec::new();
177        for region in &self.settings.regions {
178            available.extend(
179                self.instances
180                    .iter()
181                    .filter(|x| {
182                        x.is_inactive() && &x.region == region && self.settings.filter_instances(x)
183                    })
184                    .take(quantity)
185                    .cloned()
186                    .collect::<Vec<_>>(),
187            );
188        }
189
190        // Start instances.
191        self.client.start_instances(available.iter()).await?;
192
193        // Wait until the instances are started.
194        if cfg!(not(test)) {
195            self.wait_until_reachable(available.iter()).await?;
196        }
197        self.instances = self.client.list_instances().await?;
198
199        display::done();
200        Ok(())
201    }
202
203    /// Stop all instances of the testbed.
204    pub async fn stop(&mut self) -> TestbedResult<()> {
205        display::action("Stopping instances");
206
207        // Stop all instances.
208        self.client
209            .stop_instances(self.instances.iter().filter(|i| i.is_active()))
210            .await?;
211
212        // Wait until the instances are stopped.
213        loop {
214            let instances = self.client.list_instances().await?;
215            if instances.iter().all(|x| x.is_inactive()) {
216                self.instances = instances;
217                break;
218            }
219        }
220
221        display::done();
222        Ok(())
223    }
224
225    /// Wait until all specified instances are ready to accept ssh connections.
226    async fn wait_until_reachable<'a, I>(&self, instances: I) -> TestbedResult<()>
227    where
228        I: Iterator<Item = &'a Instance> + Clone,
229    {
230        let instances_ids: Vec<_> = instances.map(|x| x.id.clone()).collect();
231
232        let mut interval = time::interval(Duration::from_secs(5));
233        interval.tick().await; // The first tick returns immediately.
234
235        let start = Instant::now();
236        loop {
237            let now = interval.tick().await;
238            let elapsed = now.duration_since(start).as_secs_f64().ceil() as u64;
239            display::status(format!("{elapsed}s"));
240
241            let instances = self.client.list_instances().await?;
242            let futures = instances
243                .iter()
244                .filter(|x| instances_ids.contains(&x.id))
245                .map(|instance| {
246                    let private_key_file = self.settings.ssh_private_key_file.clone();
247                    SshConnection::new(
248                        instance.ssh_address(),
249                        C::USERNAME,
250                        private_key_file,
251                        None,
252                        None,
253                    )
254                });
255            if try_join_all(futures).await.is_ok() {
256                break;
257            }
258        }
259        Ok(())
260    }
261}
262
#[cfg(test)]
mod test {
    use crate::{client::test_client::TestClient, settings::Settings, testbed::Testbed};

    /// Build a testbed backed by a fresh `TestClient` using the test settings.
    async fn new_testbed() -> Testbed<TestClient> {
        let settings = Settings::new_for_test();
        let client = TestClient::new(settings.clone());
        Testbed::new(settings, client).await.unwrap()
    }

    #[tokio::test]
    async fn deploy() {
        let mut testbed = new_testbed().await;

        testbed.deploy(5, None).await.unwrap();

        let expected = 5 * testbed.settings.number_of_regions();
        assert_eq!(testbed.instances.len(), expected);
        testbed
            .instances
            .iter()
            .enumerate()
            .for_each(|(i, instance)| assert_eq!(i.to_string(), instance.id));
    }

    #[tokio::test]
    async fn destroy() {
        let mut testbed = new_testbed().await;

        testbed.destroy().await.unwrap();

        assert_eq!(testbed.instances.len(), 0);
    }

    #[tokio::test]
    async fn start() {
        let mut testbed = new_testbed().await;
        testbed.deploy(5, None).await.unwrap();
        testbed.stop().await.unwrap();

        let result = testbed.start(2).await;
        assert!(result.is_ok());

        for region in &testbed.settings.regions {
            // Count active and inactive instances of this region in one pass.
            let (active, inactive) = testbed
                .instances
                .iter()
                .filter(|x| &x.region == region)
                .fold((0usize, 0usize), |(active, inactive), x| {
                    (
                        active + usize::from(x.is_active()),
                        inactive + usize::from(x.is_inactive()),
                    )
                });
            assert_eq!(active, 2);
            assert_eq!(inactive, 3);
        }
    }

    #[tokio::test]
    async fn stop() {
        let mut testbed = new_testbed().await;
        testbed.deploy(5, None).await.unwrap();
        testbed.start(2).await.unwrap();

        testbed.stop().await.unwrap();

        let all_inactive = testbed.instances.iter().all(|x| x.is_inactive());
        assert!(all_inactive)
    }
}