iota_aws_orchestrator/client/
aws.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    collections::HashMap,
7    fmt::{Debug, Display},
8};
9
10use aws_runtime::env_config::file::{EnvConfigFileKind, EnvConfigFiles};
11use aws_sdk_ec2::{
12    config::Region,
13    primitives::Blob,
14    types::{
15        BlockDeviceMapping, EbsBlockDevice, EphemeralNvmeSupport, Filter, ResourceType, Tag,
16        TagSpecification, VolumeType,
17    },
18};
19use aws_smithy_runtime_api::client::{behavior_version::BehaviorVersion, result::SdkError};
20use serde::Serialize;
21
22use super::{Instance, ServerProviderClient};
23use crate::{
24    error::{CloudProviderError, CloudProviderResult},
25    settings::Settings,
26};
27
28// Make a request error from an AWS error message.
29impl<T> From<SdkError<T, aws_smithy_runtime_api::client::orchestrator::HttpResponse>>
30    for CloudProviderError
31where
32    T: Debug + std::error::Error + Send + Sync + 'static,
33{
34    fn from(e: SdkError<T, aws_smithy_runtime_api::client::orchestrator::HttpResponse>) -> Self {
35        Self::Request(format!("{:?}", e.into_source()))
36    }
37}
38
/// An AWS client.
///
/// Holds the testbed settings together with one EC2 client per configured
/// region.
pub struct AwsClient {
    /// The settings of the testbed.
    settings: Settings,
    /// The EC2 clients, keyed by AWS region name (one client per region).
    clients: HashMap<String, aws_sdk_ec2::Client>,
}
46
47impl Display for AwsClient {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        write!(f, "AWS EC2 client v{}", aws_sdk_ec2::meta::PKG_VERSION)
50    }
51}
52
53impl AwsClient {
    /// Image description that `find_image_id` passes as the value of the
    /// `description` filter when querying the images of a region.
    const OS_IMAGE: &'static str =
        "Canonical, Ubuntu, 22.04 LTS, amd64 jammy image build on 2023-02-16";
56
57    /// Make a new AWS client.
58    pub async fn new(settings: Settings) -> Self {
59        let profile_files = EnvConfigFiles::builder()
60            .with_file(EnvConfigFileKind::Credentials, &settings.token_file)
61            .with_contents(EnvConfigFileKind::Config, "[default]\noutput=json")
62            .build();
63
64        let mut clients = HashMap::new();
65        for region in settings.regions.clone() {
66            let sdk_config = aws_config::defaults(BehaviorVersion::latest())
67                .region(Region::new(region.clone()))
68                .profile_files(profile_files.clone())
69                .load()
70                .await;
71            let client = aws_sdk_ec2::Client::new(&sdk_config);
72            clients.insert(region, client);
73        }
74
75        Self { settings, clients }
76    }
77
78    /// Parse an AWS response and ignore errors if they mean a request is a
79    /// duplicate.
80    fn check_but_ignore_duplicates<T, E>(
81        response: Result<
82            T,
83            SdkError<E, aws_smithy_runtime_api::client::orchestrator::HttpResponse>,
84        >,
85    ) -> CloudProviderResult<()>
86    where
87        E: Debug + std::error::Error + Send + Sync + 'static,
88    {
89        if let Err(e) = response {
90            let error_message = format!("{e:?}");
91            if !error_message.to_lowercase().contains("duplicate") {
92                return Err(e.into());
93            }
94        }
95        Ok(())
96    }
97
98    /// Convert an AWS instance into an orchestrator instance (used in the rest
99    /// of the codebase).
100    fn make_instance(
101        &self,
102        region: String,
103        aws_instance: &aws_sdk_ec2::types::Instance,
104    ) -> Instance {
105        Instance {
106            id: aws_instance
107                .instance_id()
108                .expect("AWS instance should have an id")
109                .into(),
110            region,
111            main_ip: aws_instance
112                .public_ip_address()
113                .unwrap_or("0.0.0.0") // Stopped instances do not have an ip address.
114                .parse()
115                .expect("AWS instance should have a valid ip"),
116            tags: vec![self.settings.testbed_id.clone()],
117            specs: format!(
118                "{:?}",
119                aws_instance
120                    .instance_type()
121                    .expect("AWS instance should have a type")
122            ),
123            status: format!(
124                "{:?}",
125                aws_instance
126                    .state()
127                    .expect("AWS instance should have a state")
128                    .name()
129                    .expect("AWS status should have a name")
130            ),
131        }
132    }
133
134    /// Query the image id determining the os of the instances.
135    /// NOTE: The image id changes depending on the region.
136    async fn find_image_id(&self, client: &aws_sdk_ec2::Client) -> CloudProviderResult<String> {
137        // Query all images that match the description.
138        let request = client.describe_images().filters(
139            Filter::builder()
140                .name("description")
141                .values(Self::OS_IMAGE)
142                .build(),
143        );
144        let response = request.send().await?;
145
146        // Parse the response to select the first returned image id.
147        response
148            .images()
149            .first()
150            .ok_or_else(|| CloudProviderError::Request("Cannot find image id".into()))?
151            .image_id
152            .clone()
153            .ok_or_else(|| {
154                CloudProviderError::UnexpectedResponse(
155                    "Received image description without id".into(),
156                )
157            })
158    }
159
160    /// Create a new security group for the instance (if it doesn't already
161    /// exist).
162    async fn create_security_group(&self, client: &aws_sdk_ec2::Client) -> CloudProviderResult<()> {
163        // Create a security group (if it doesn't already exist).
164        let request = client
165            .create_security_group()
166            .group_name(&self.settings.testbed_id)
167            .description("Allow all traffic (used for benchmarks).");
168
169        let response = request.send().await;
170        Self::check_but_ignore_duplicates(response)?;
171
172        // Authorize all traffic on the security group.
173        for protocol in ["tcp", "udp", "icmp", "icmpv6"] {
174            let mut request = client
175                .authorize_security_group_ingress()
176                .group_name(&self.settings.testbed_id)
177                .ip_protocol(protocol)
178                .cidr_ip("0.0.0.0/0"); // todo - allowing 0.0.0.0 seem a bit wild?
179            if protocol == "icmp" || protocol == "icmpv6" {
180                request = request.from_port(-1).to_port(-1);
181            } else {
182                request = request.from_port(0).to_port(65535);
183            }
184
185            let response = request.send().await;
186            Self::check_but_ignore_duplicates(response)?;
187        }
188        Ok(())
189    }
190
191    /// Return the command to mount the first (standard) NVMe drive.
192    fn nvme_mount_command(&self) -> Vec<String> {
193        const DRIVE: &str = "nvme1n1";
194        let directory = self.settings.working_dir.display();
195        vec![
196            format!("(sudo mkfs.ext4 -E nodiscard /dev/{DRIVE} || true)"),
197            format!("(sudo mount /dev/{DRIVE} {directory} || true)"),
198            format!("sudo chmod 777 -R {directory}"),
199        ]
200    }
201
202    /// Check whether the instance type specified in the settings supports NVMe
203    /// drives.
204    async fn check_nvme_support(&self) -> CloudProviderResult<bool> {
205        // Get the client for the first region. A given instance type should either have
206        // NVMe support in all regions or in none.
207        let client = match self
208            .settings
209            .regions
210            .first()
211            .and_then(|x| self.clients.get(x))
212        {
213            Some(client) => client,
214            None => return Ok(false),
215        };
216
217        // Request storage details for the instance type specified in the settings.
218        let request = client
219            .describe_instance_types()
220            .instance_types(self.settings.specs.as_str().into());
221
222        // Send the request.
223        let response = request.send().await?;
224
225        // Return true if the response contains references to NVMe drives.
226        if let Some(info) = response.instance_types().first() {
227            if let Some(info) = info.instance_storage_info() {
228                if info.nvme_support() == Some(&EphemeralNvmeSupport::Required) {
229                    return Ok(true);
230                }
231            }
232        }
233        Ok(false)
234    }
235}
236
237#[async_trait::async_trait]
238impl ServerProviderClient for AwsClient {
    // NOTE(review): presumably the login user of the created instances (the
    // image selected by this client is an Ubuntu image) - confirm against the
    // `ServerProviderClient` trait definition.
    const USERNAME: &'static str = "ubuntu";
240
241    async fn list_instances(&self) -> CloudProviderResult<Vec<Instance>> {
242        let filter = Filter::builder()
243            .name("tag:Name")
244            .values(self.settings.testbed_id.clone())
245            .build();
246
247        let mut instances = Vec::new();
248        for (region, client) in &self.clients {
249            let request = client.describe_instances().filters(filter.clone());
250            for reservation in request.send().await?.reservations() {
251                for instance in reservation.instances() {
252                    instances.push(self.make_instance(region.clone(), instance));
253                }
254            }
255        }
256
257        Ok(instances)
258    }
259
260    async fn start_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
261    where
262        I: Iterator<Item = &'a Instance> + Send,
263    {
264        let mut instance_ids = HashMap::new();
265        for instance in instances {
266            instance_ids
267                .entry(&instance.region)
268                .or_insert_with(Vec::new)
269                .push(instance.id.clone());
270        }
271
272        for (region, client) in &self.clients {
273            let ids = instance_ids.remove(&region.to_string());
274            if ids.is_some() {
275                client
276                    .start_instances()
277                    .set_instance_ids(ids)
278                    .send()
279                    .await?;
280            }
281        }
282        Ok(())
283    }
284
285    async fn stop_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
286    where
287        I: Iterator<Item = &'a Instance> + Send,
288    {
289        let mut instance_ids = HashMap::new();
290        for instance in instances {
291            instance_ids
292                .entry(&instance.region)
293                .or_insert_with(Vec::new)
294                .push(instance.id.clone());
295        }
296
297        for (region, client) in &self.clients {
298            let ids = instance_ids.remove(&region.to_string());
299            if ids.is_some() {
300                client.stop_instances().set_instance_ids(ids).send().await?;
301            }
302        }
303        Ok(())
304    }
305
306    async fn create_instance<S>(&self, region: S) -> CloudProviderResult<Instance>
307    where
308        S: Into<String> + Serialize + Send,
309    {
310        let region = region.into();
311        let testbed_id = &self.settings.testbed_id;
312
313        let client = self
314            .clients
315            .get(&region)
316            .ok_or_else(|| CloudProviderError::Request(format!("Undefined region {region:?}")))?;
317
318        // Create a security group (if needed).
319        self.create_security_group(client).await?;
320
321        // Query the image id.
322        let image_id = self.find_image_id(client).await?;
323
324        // Create a new instance.
325        let tags = TagSpecification::builder()
326            .resource_type(ResourceType::Instance)
327            .tags(Tag::builder().key("Name").value(testbed_id).build())
328            .build();
329
330        let storage = BlockDeviceMapping::builder()
331            .device_name("/dev/sda1")
332            .ebs(
333                EbsBlockDevice::builder()
334                    .delete_on_termination(true)
335                    .volume_size(500)
336                    .volume_type(VolumeType::Gp2)
337                    .build(),
338            )
339            .build();
340
341        let request = client
342            .run_instances()
343            .image_id(image_id)
344            .instance_type(self.settings.specs.as_str().into())
345            .key_name(testbed_id)
346            .min_count(1)
347            .max_count(1)
348            .security_groups(&self.settings.testbed_id)
349            .block_device_mappings(storage)
350            .tag_specifications(tags);
351
352        let response = request.send().await?;
353        let instance = &response
354            .instances()
355            .first()
356            .expect("AWS instances list should contain instances");
357
358        Ok(self.make_instance(region, instance))
359    }
360
361    async fn delete_instance(&self, instance: Instance) -> CloudProviderResult<()> {
362        let client = self.clients.get(&instance.region).ok_or_else(|| {
363            CloudProviderError::Request(format!("Undefined region {:?}", instance.region))
364        })?;
365
366        client
367            .terminate_instances()
368            .set_instance_ids(Some(vec![instance.id.clone()]))
369            .send()
370            .await?;
371
372        Ok(())
373    }
374
375    async fn register_ssh_public_key(&self, public_key: String) -> CloudProviderResult<()> {
376        for client in self.clients.values() {
377            let request = client
378                .import_key_pair()
379                .key_name(&self.settings.testbed_id)
380                .public_key_material(Blob::new::<String>(public_key.clone()));
381
382            let response = request.send().await;
383            Self::check_but_ignore_duplicates(response)?;
384        }
385        Ok(())
386    }
387
388    async fn instance_setup_commands(&self) -> CloudProviderResult<Vec<String>> {
389        if self.check_nvme_support().await? {
390            Ok(self.nvme_mount_command())
391        } else {
392            Ok(Vec::new())
393        }
394    }
395}