iota_aws_orchestrator/client/
aws.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    collections::HashMap,
7    fmt::{Debug, Display},
8};
9
10use aws_runtime::env_config::file::{EnvConfigFileKind, EnvConfigFiles};
11use aws_sdk_ec2::{
12    config::Region,
13    primitives::Blob,
14    types::{
15        BlockDeviceMapping, EbsBlockDevice, EphemeralNvmeSupport, Filter, ResourceType, Tag,
16        TagSpecification, VolumeType, builders::FilterBuilder,
17    },
18};
19use aws_smithy_runtime_api::client::{behavior_version::BehaviorVersion, result::SdkError};
20use serde::Serialize;
21
22use super::{Instance, InstanceRole, ServerProviderClient};
23use crate::{
24    error::{CloudProviderError, CloudProviderResult},
25    settings::Settings,
26};
27
28// Make a request error from an AWS error message.
29impl<T> From<SdkError<T, aws_smithy_runtime_api::client::orchestrator::HttpResponse>>
30    for CloudProviderError
31where
32    T: Debug + std::error::Error + Send + Sync + 'static,
33{
34    fn from(e: SdkError<T, aws_smithy_runtime_api::client::orchestrator::HttpResponse>) -> Self {
35        Self::Request(format!("{:?}", e.into_source()))
36    }
37}
38
39/// A AWS client.
40pub struct AwsClient {
41    /// The settings of the testbed.
42    settings: Settings,
43    /// A list of clients, one per AWS region.
44    clients: HashMap<String, aws_sdk_ec2::Client>,
45}
46
47impl Display for AwsClient {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        write!(f, "AWS EC2 client v{}", aws_sdk_ec2::meta::PKG_VERSION)
50    }
51}
52
53impl AwsClient {
54    const UBUNTU_NAME_PATTERN: &'static str =
55        "ubuntu/images/hvm-ssd-gp3/ubuntu-noble-24.04-amd64-server-*";
56    const CANONICAL_OWNER_ID: &'static str = "099720109477";
57
58    /// Make a new AWS client.
59    pub async fn new(settings: Settings) -> Self {
60        let profile_files = EnvConfigFiles::builder()
61            .with_file(EnvConfigFileKind::Credentials, &settings.token_file)
62            .with_contents(EnvConfigFileKind::Config, "[default]\noutput=json")
63            .build();
64
65        let mut clients = HashMap::new();
66        for region in settings.regions.clone() {
67            let sdk_config = aws_config::defaults(BehaviorVersion::latest())
68                .region(Region::new(region.clone()))
69                .profile_files(profile_files.clone())
70                .load()
71                .await;
72            let client = aws_sdk_ec2::Client::new(&sdk_config);
73            clients.insert(region, client);
74        }
75
76        Self { settings, clients }
77    }
78
79    /// Parse an AWS response and ignore errors if they mean a request is a
80    /// duplicate.
81    fn check_but_ignore_duplicates<T, E>(
82        response: Result<
83            T,
84            SdkError<E, aws_smithy_runtime_api::client::orchestrator::HttpResponse>,
85        >,
86    ) -> CloudProviderResult<()>
87    where
88        E: Debug + std::error::Error + Send + Sync + 'static,
89    {
90        if let Err(e) = response {
91            let error_message = format!("{e:?}");
92            if !error_message.to_lowercase().contains("duplicate") {
93                return Err(e.into());
94            }
95        }
96        Ok(())
97    }
98    fn get_tag_value(instance: &aws_sdk_ec2::types::Instance, key: &str) -> Option<String> {
99        instance
100            .tags()
101            .iter()
102            .find(|tag| tag.key().is_some_and(|k| k == key))
103            .and_then(|tag| tag.value().map(|v| v.to_string()))
104    }
105    /// Convert an AWS instance into an orchestrator instance (used in the rest
106    /// of the codebase).
107    fn make_instance(
108        &self,
109        region: String,
110        aws_instance: &aws_sdk_ec2::types::Instance,
111    ) -> Instance {
112        let role: InstanceRole = Self::get_tag_value(aws_instance, "Role")
113            .expect("AWS instance should have a role")
114            .as_str()
115            .into();
116        Instance {
117            id: aws_instance
118                .instance_id()
119                .expect("AWS instance should have an id")
120                .into(),
121            region,
122            main_ip: aws_instance
123                .public_ip_address()
124                .unwrap_or("0.0.0.0") // Stopped instances do not have an ip address.
125                .parse()
126                .expect("AWS instance should have a valid ip"),
127            private_ip: aws_instance
128                .private_ip_address()
129                .unwrap_or("0.0.0.0") // Stopped instances do not have an ip address.
130                .parse()
131                .expect("AWS instance should have a valid ip"),
132            tags: vec![self.settings.testbed_id.clone()],
133            specs: format!(
134                "{:?}",
135                aws_instance
136                    .instance_type()
137                    .expect("AWS instance should have a type")
138            ),
139            status: format!(
140                "{:?}",
141                aws_instance
142                    .state()
143                    .expect("AWS instance should have a state")
144                    .name()
145                    .expect("AWS status should have a name")
146            ),
147            role,
148        }
149    }
150
151    /// Query the image id determining the os of the instances.
152    /// NOTE: The image id changes depending on the region.
153    async fn find_image_id(&self, client: &aws_sdk_ec2::Client) -> CloudProviderResult<String> {
154        // Use a more general filter that doesn't depend on specific build dates
155        let filters = [
156            // Filter for Ubuntu 24.04 LTS
157            FilterBuilder::default()
158                .name("name")
159                .values(Self::UBUNTU_NAME_PATTERN)
160                .build(),
161            // Only look at images from Canonical
162            FilterBuilder::default()
163                .name("owner-id")
164                .values(Self::CANONICAL_OWNER_ID)
165                .build(),
166            // Only want available images
167            FilterBuilder::default()
168                .name("state")
169                .values("available")
170                .build(),
171        ];
172
173        // Query images with these filters
174        let request = client.describe_images().set_filters(Some(filters.to_vec()));
175        let response = request.send().await?;
176
177        // Sort images by creation date (newest first)
178        let mut images = response.images().to_vec();
179        images.sort_by(|a, b| {
180            let a_date = a.creation_date().unwrap_or("");
181            let b_date = b.creation_date().unwrap_or("");
182            b_date.cmp(a_date) // Reverse order to get newest first
183        });
184
185        // Select the newest image
186        let image = images
187            .first()
188            .ok_or_else(|| CloudProviderError::Request("Cannot find Ubuntu 24.04 image".into()))?;
189
190        image
191            .image_id
192            .clone()
193            .ok_or_else(|| CloudProviderError::UnexpectedResponse("Image without ID".into()))
194    }
195
196    /// Create a new security group for the instance (if it doesn't already
197    /// exist).
198    async fn create_security_group(&self, client: &aws_sdk_ec2::Client) -> CloudProviderResult<()> {
199        // Create a security group (if it doesn't already exist).
200        let request = client
201            .create_security_group()
202            .group_name(&self.settings.testbed_id)
203            .description("Allow all traffic (used for benchmarks).");
204
205        let response = request.send().await;
206        Self::check_but_ignore_duplicates(response)?;
207
208        // Authorize all traffic on the security group.
209        for protocol in ["tcp", "udp", "icmp", "icmpv6"] {
210            let mut request = client
211                .authorize_security_group_ingress()
212                .group_name(&self.settings.testbed_id)
213                .ip_protocol(protocol)
214                .cidr_ip("0.0.0.0/0"); // todo - allowing 0.0.0.0 seem a bit wild?
215            if protocol == "icmp" || protocol == "icmpv6" {
216                request = request.from_port(-1).to_port(-1);
217            } else {
218                request = request.from_port(0).to_port(65535);
219            }
220
221            let response = request.send().await;
222            Self::check_but_ignore_duplicates(response)?;
223        }
224        Ok(())
225    }
226
227    /// Return the command to mount the first (standard) NVMe drive.
228    fn nvme_mount_command(&self) -> Vec<String> {
229        const DRIVE: &str = "nvme1n1";
230        let directory = self.settings.working_dir.display();
231        vec![
232            format!("(sudo mkfs.ext4 -E nodiscard /dev/{DRIVE} || true)"),
233            format!("(sudo mount /dev/{DRIVE} {directory} || true)"),
234            format!("sudo chmod 777 -R {directory}"),
235        ]
236    }
237
238    /// Check whether the instance type specified in the settings supports NVMe
239    /// drives.
240    async fn check_nvme_support(&self) -> CloudProviderResult<bool> {
241        // Get the client for the first region. A given instance type should either have
242        // NVMe support in all regions or in none.
243        let client = match self
244            .settings
245            .regions
246            .first()
247            .and_then(|x| self.clients.get(x))
248        {
249            Some(client) => client,
250            None => return Ok(false),
251        };
252
253        // Request storage details for the instance type specified in the settings.
254        let request = client
255            .describe_instance_types()
256            .instance_types(self.settings.node_specs.as_str().into());
257
258        // Send the request.
259        let response = request.send().await?;
260
261        // Return true if the response contains references to NVMe drives.
262        if let Some(info) = response.instance_types().first() {
263            if let Some(info) = info.instance_storage_info() {
264                if info.nvme_support() == Some(&EphemeralNvmeSupport::Required) {
265                    return Ok(true);
266                }
267            }
268        }
269        Ok(false)
270    }
271}
272
273#[async_trait::async_trait]
274impl ServerProviderClient for AwsClient {
275    const USERNAME: &'static str = "ubuntu";
276
277    async fn list_instances_by_role(
278        &self,
279        role: InstanceRole,
280    ) -> CloudProviderResult<Vec<Instance>> {
281        let filter_name = Filter::builder()
282            .name("tag:Name")
283            .values(self.settings.testbed_id.clone())
284            .build();
285        let filter_role = Filter::builder()
286            .name("tag:Role")
287            .values(role.to_string())
288            .build();
289        let filter_state = Filter::builder()
290            .name("instance-state-name")
291            .values("pending")
292            .values("running")
293            .values("stopping")
294            .values("stopped")
295            .build();
296
297        let mut instances = Vec::new();
298        for (region, client) in &self.clients {
299            let request = client.describe_instances().set_filters(Some(vec![
300                filter_name.clone(),
301                filter_role.clone(),
302                filter_state.clone(),
303            ]));
304            let response = request.send().await?;
305            for reservation in response.reservations() {
306                for instance in reservation.instances() {
307                    instances.push(self.make_instance(region.clone(), instance));
308                }
309            }
310        }
311
312        Ok(instances)
313    }
314
315    async fn list_instances_by_region_and_ids(
316        &self,
317        regions_and_ids: Vec<(String, String)>,
318    ) -> CloudProviderResult<Vec<Instance>> {
319        let mut instances = Vec::new();
320        let ids_by_region: HashMap<String, Vec<String>> =
321            regions_and_ids
322                .into_iter()
323                .fold(HashMap::new(), |mut acc, (k, v)| {
324                    acc.entry(k).or_default().push(v);
325                    acc
326                });
327        for (region, client) in &self.clients {
328            let request = client
329                .describe_instances()
330                .set_instance_ids(ids_by_region.get(region).cloned());
331            let response = request.send().await?;
332            for reservation in response.reservations() {
333                for instance in reservation.instances() {
334                    instances.push(self.make_instance(region.clone(), instance));
335                }
336            }
337        }
338
339        Ok(instances)
340    }
341
342    async fn start_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
343    where
344        I: Iterator<Item = &'a Instance> + Send,
345    {
346        let mut instance_ids = HashMap::new();
347        for instance in instances {
348            instance_ids
349                .entry(&instance.region)
350                .or_insert_with(Vec::new)
351                .push(instance.id.clone());
352        }
353
354        for (region, client) in &self.clients {
355            let ids = instance_ids.remove(&region.to_string());
356            if ids.is_some() {
357                client
358                    .start_instances()
359                    .set_instance_ids(ids)
360                    .send()
361                    .await?;
362            }
363        }
364        Ok(())
365    }
366
367    async fn stop_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
368    where
369        I: Iterator<Item = &'a Instance> + Send,
370    {
371        let mut instance_ids = HashMap::new();
372        for instance in instances {
373            instance_ids
374                .entry(&instance.region)
375                .or_insert_with(Vec::new)
376                .push(instance.id.clone());
377        }
378
379        for (region, client) in &self.clients {
380            let ids = instance_ids.remove(&region.to_string());
381            if ids.is_some() {
382                client.stop_instances().set_instance_ids(ids).send().await?;
383            }
384        }
385        Ok(())
386    }
387
388    async fn create_instance<S>(
389        &self,
390        region: S,
391        role: InstanceRole,
392    ) -> CloudProviderResult<Instance>
393    where
394        S: Into<String> + Serialize + Send,
395    {
396        let region = region.into();
397        let testbed_id = &self.settings.testbed_id;
398
399        let client = self
400            .clients
401            .get(&region)
402            .ok_or_else(|| CloudProviderError::Request(format!("Undefined region {region:?}")))?;
403
404        // Create a security group (if needed).
405        self.create_security_group(client).await?;
406
407        // Query the image id.
408        let image_id = self.find_image_id(client).await?;
409
410        // Create a new instance.
411        let tags = TagSpecification::builder()
412            .resource_type(ResourceType::Instance)
413            .tags(Tag::builder().key("Name").value(testbed_id).build())
414            .tags(Tag::builder().key("Role").value(role.to_string()).build())
415            .build();
416
417        let storage = BlockDeviceMapping::builder()
418            .device_name("/dev/sda1")
419            .ebs(
420                EbsBlockDevice::builder()
421                    .delete_on_termination(true)
422                    .volume_size(500)
423                    .volume_type(VolumeType::Gp2)
424                    .build(),
425            )
426            .build();
427        let instance_type = match role {
428            InstanceRole::Node => &self.settings.node_specs,
429            InstanceRole::Metrics => &self.settings.metrics_specs,
430            InstanceRole::Client => &self.settings.client_specs,
431        };
432        let request = client
433            .run_instances()
434            .image_id(image_id)
435            .instance_type(instance_type.as_str().into())
436            .key_name(testbed_id)
437            .min_count(1)
438            .max_count(1)
439            .security_groups(&self.settings.testbed_id)
440            .block_device_mappings(storage)
441            .tag_specifications(tags);
442
443        let response = request.send().await?;
444        let instance = &response
445            .instances()
446            .first()
447            .expect("AWS instances list should contain instances");
448
449        Ok(self.make_instance(region, instance))
450    }
451
452    async fn delete_instance(&self, instance: Instance) -> CloudProviderResult<()> {
453        let client = self.clients.get(&instance.region).ok_or_else(|| {
454            CloudProviderError::Request(format!("Undefined region {:?}", instance.region))
455        })?;
456
457        client
458            .terminate_instances()
459            .set_instance_ids(Some(vec![instance.id.clone()]))
460            .send()
461            .await?;
462
463        Ok(())
464    }
465
466    async fn register_ssh_public_key(&self, public_key: String) -> CloudProviderResult<()> {
467        for client in self.clients.values() {
468            let request = client
469                .import_key_pair()
470                .key_name(&self.settings.testbed_id)
471                .public_key_material(Blob::new::<String>(public_key.clone()));
472
473            let response = request.send().await;
474            Self::check_but_ignore_duplicates(response)?;
475        }
476        Ok(())
477    }
478
479    async fn instance_setup_commands(&self) -> CloudProviderResult<Vec<String>> {
480        if self.check_nvme_support().await? {
481            Ok(self.nvme_mount_command())
482        } else {
483            Ok(Vec::new())
484        }
485    }
486    #[cfg(test)]
487    fn instances(&self) -> Vec<Instance> {
488        // Only used under testing by the TestClient, unreachable cause no test
489        // should use AwsClient.
490        unreachable!()
491    }
492}