iota_aws_orchestrator/client/
aws.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    collections::HashMap,
7    fmt::{Debug, Display},
8};
9
10use aws_runtime::env_config::file::{EnvConfigFileKind, EnvConfigFiles};
11use aws_sdk_ec2::{
12    config::Region,
13    primitives::Blob,
14    types::{
15        BlockDeviceMapping, EbsBlockDevice, EphemeralNvmeSupport, Filter,
16        InstanceInterruptionBehavior, InstanceMarketOptionsRequest, MarketType, ResourceType,
17        SpotInstanceType, SpotMarketOptions, Tag, TagSpecification, VolumeType,
18        builders::FilterBuilder,
19    },
20};
21use aws_smithy_runtime_api::client::{behavior_version::BehaviorVersion, result::SdkError};
22use serde::Serialize;
23
24use super::{Instance, InstanceLifecycle, InstanceRole, ServerProviderClient};
25use crate::{
26    display,
27    error::{CloudProviderError, CloudProviderResult},
28    settings::Settings,
29};
30
31// Make a request error from an AWS error message.
32impl<T> From<SdkError<T, aws_smithy_runtime_api::client::orchestrator::HttpResponse>>
33    for CloudProviderError
34where
35    T: Debug + std::error::Error + Send + Sync + 'static,
36{
37    fn from(e: SdkError<T, aws_smithy_runtime_api::client::orchestrator::HttpResponse>) -> Self {
38        Self::Request(format!("{:?}", e.into_source()))
39    }
40}
41
42/// A AWS client.
43pub struct AwsClient {
44    /// The settings of the testbed.
45    settings: Settings,
46    /// A list of clients, one per AWS region.
47    clients: HashMap<String, aws_sdk_ec2::Client>,
48}
49
50impl Display for AwsClient {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        write!(f, "AWS EC2 client v{}", aws_sdk_ec2::meta::PKG_VERSION)
53    }
54}
55
56impl AwsClient {
57    const UBUNTU_NAME_PATTERN: &'static str =
58        "ubuntu/images/hvm-ssd-gp3/ubuntu-noble-24.04-amd64-server-*";
59    const CANONICAL_OWNER_ID: &'static str = "099720109477";
60
61    /// Make a new AWS client.
62    pub async fn new(settings: Settings) -> Self {
63        let profile_files = EnvConfigFiles::builder()
64            .with_file(EnvConfigFileKind::Credentials, &settings.token_file)
65            .with_contents(EnvConfigFileKind::Config, "[default]\noutput=json")
66            .build();
67
68        let mut clients = HashMap::new();
69        for region in settings.regions.clone() {
70            let sdk_config = aws_config::defaults(BehaviorVersion::latest())
71                .region(Region::new(region.clone()))
72                .profile_files(profile_files.clone())
73                .load()
74                .await;
75            let client = aws_sdk_ec2::Client::new(&sdk_config);
76            clients.insert(region, client);
77        }
78
79        Self { settings, clients }
80    }
81
82    /// Parse an AWS response and ignore errors if they mean a request is a
83    /// duplicate.
84    fn check_but_ignore_duplicates<T, E>(
85        response: Result<
86            T,
87            SdkError<E, aws_smithy_runtime_api::client::orchestrator::HttpResponse>,
88        >,
89    ) -> CloudProviderResult<()>
90    where
91        E: Debug + std::error::Error + Send + Sync + 'static,
92    {
93        if let Err(e) = response {
94            let error_message = format!("{e:?}");
95            if !error_message.to_lowercase().contains("duplicate") {
96                return Err(e.into());
97            }
98        }
99        Ok(())
100    }
101    fn get_tag_value(instance: &aws_sdk_ec2::types::Instance, key: &str) -> Option<String> {
102        instance
103            .tags()
104            .iter()
105            .find(|tag| tag.key().is_some_and(|k| k == key))
106            .and_then(|tag| tag.value().map(|v| v.to_string()))
107    }
108    /// Convert an AWS instance into an orchestrator instance (used in the rest
109    /// of the codebase).
110    fn make_instance(
111        &self,
112        region: String,
113        aws_instance: &aws_sdk_ec2::types::Instance,
114    ) -> Instance {
115        let role: InstanceRole = Self::get_tag_value(aws_instance, "Role")
116            .expect("AWS instance should have a role")
117            .as_str()
118            .into();
119        let lifecycle: InstanceLifecycle =
120            if let Some(aws_sdk_ec2::types::InstanceLifecycleType::Spot) =
121                aws_instance.instance_lifecycle
122            {
123                InstanceLifecycle::Spot
124            } else {
125                InstanceLifecycle::OnDemand
126            };
127        Instance {
128            id: aws_instance
129                .instance_id()
130                .expect("AWS instance should have an id")
131                .into(),
132            region,
133            main_ip: aws_instance
134                .public_ip_address()
135                .unwrap_or("0.0.0.0") // Stopped instances do not have an ip address.
136                .parse()
137                .expect("AWS instance should have a valid ip"),
138            private_ip: aws_instance
139                .private_ip_address()
140                .unwrap_or("0.0.0.0") // Stopped instances do not have an ip address.
141                .parse()
142                .expect("AWS instance should have a valid ip"),
143            tags: vec![self.settings.testbed_id.clone()],
144            specs: format!(
145                "{:?}",
146                aws_instance
147                    .instance_type()
148                    .expect("AWS instance should have a type")
149            ),
150            status: format!(
151                "{:?}",
152                aws_instance
153                    .state()
154                    .expect("AWS instance should have a state")
155                    .name()
156                    .expect("AWS status should have a name")
157            ),
158            role,
159            lifecycle,
160        }
161    }
162
163    /// Query the image id determining the os of the instances.
164    /// NOTE: The image id changes depending on the region.
165    async fn find_image_id(&self, client: &aws_sdk_ec2::Client) -> CloudProviderResult<String> {
166        // Use a more general filter that doesn't depend on specific build dates
167        let filters = [
168            // Filter for Ubuntu 24.04 LTS
169            FilterBuilder::default()
170                .name("name")
171                .values(Self::UBUNTU_NAME_PATTERN)
172                .build(),
173            // Only look at images from Canonical
174            FilterBuilder::default()
175                .name("owner-id")
176                .values(Self::CANONICAL_OWNER_ID)
177                .build(),
178            // Only want available images
179            FilterBuilder::default()
180                .name("state")
181                .values("available")
182                .build(),
183        ];
184
185        // Query images with these filters
186        let request = client.describe_images().set_filters(Some(filters.to_vec()));
187        let response = request.send().await?;
188
189        // Sort images by creation date (newest first)
190        let mut images = response.images().to_vec();
191        images.sort_by(|a, b| {
192            let a_date = a.creation_date().unwrap_or("");
193            let b_date = b.creation_date().unwrap_or("");
194            b_date.cmp(a_date) // Reverse order to get newest first
195        });
196
197        // Select the newest image
198        let image = images
199            .first()
200            .ok_or_else(|| CloudProviderError::Request("Cannot find Ubuntu 24.04 image".into()))?;
201
202        image
203            .image_id
204            .clone()
205            .ok_or_else(|| CloudProviderError::UnexpectedResponse("Image without ID".into()))
206    }
207
208    /// Create a new security group for the instance (if it doesn't already
209    /// exist).
210    async fn create_security_group(&self, client: &aws_sdk_ec2::Client) -> CloudProviderResult<()> {
211        // Create a security group (if it doesn't already exist).
212        let request = client
213            .create_security_group()
214            .group_name(&self.settings.testbed_id)
215            .description("Allow all traffic (used for benchmarks).");
216
217        let response = request.send().await;
218        Self::check_but_ignore_duplicates(response)?;
219
220        // Authorize all traffic on the security group.
221        for protocol in ["tcp", "udp", "icmp", "icmpv6"] {
222            let mut request = client
223                .authorize_security_group_ingress()
224                .group_name(&self.settings.testbed_id)
225                .ip_protocol(protocol)
226                .cidr_ip("0.0.0.0/0"); // todo - allowing 0.0.0.0 seem a bit wild?
227            if protocol == "icmp" || protocol == "icmpv6" {
228                request = request.from_port(-1).to_port(-1);
229            } else {
230                request = request.from_port(0).to_port(65535);
231            }
232
233            let response = request.send().await;
234            Self::check_but_ignore_duplicates(response)?;
235        }
236        Ok(())
237    }
238
239    /// Return the command to mount the first (standard) NVMe drive.
240    fn nvme_mount_command(&self) -> Vec<String> {
241        let directory = self.settings.working_dir.display();
242        vec![
243            "export NVME_DRIVE=$(nvme list | awk '/NVMe Instance Storage/ {print $1; exit}')"
244                .to_string(),
245            "(sudo mkfs.ext4 -E nodiscard $NVME_DRIVE || true)".to_string(),
246            format!("(sudo mount $NVME_DRIVE {directory} || true)"),
247            format!("sudo chmod 777 -R {directory}"),
248        ]
249    }
250
251    /// Check whether the instance type specified in the settings supports NVMe
252    /// drives.
253    async fn check_nvme_support(&self) -> CloudProviderResult<bool> {
254        // Get the client for the first region. A given instance type should either have
255        // NVMe support in all regions or in none.
256        let client = match self
257            .settings
258            .regions
259            .first()
260            .and_then(|x| self.clients.get(x))
261        {
262            Some(client) => client,
263            None => return Ok(false),
264        };
265
266        // Request storage details for the instance type specified in the settings.
267        let request = client
268            .describe_instance_types()
269            .instance_types(self.settings.node_specs.as_str().into());
270
271        // Send the request.
272        let response = request.send().await?;
273
274        // Return true if the response contains references to NVMe drives.
275        if let Some(info) = response.instance_types().first() {
276            if let Some(info) = info.instance_storage_info() {
277                if info.nvme_support() == Some(&EphemeralNvmeSupport::Required) {
278                    return Ok(true);
279                }
280            }
281        }
282        Ok(false)
283    }
284    fn spot_options() -> InstanceMarketOptionsRequest {
285        InstanceMarketOptionsRequest::builder()
286            // SPOT vs CAPACITY_BLOCK
287            .market_type(MarketType::Spot)
288            .spot_options(
289                SpotMarketOptions::builder()
290                    // One-off Spot request that ends when the instance ends.
291                    .spot_instance_type(SpotInstanceType::OneTime)
292                    // What to do when AWS reclaims capacity.
293                    // For ephemeral test runs, terminate is usually fine.
294                    .instance_interruption_behavior(InstanceInterruptionBehavior::Terminate)
295                    // Usually DON'T set max_price: if omitted, you just pay current Spot price,
296                    // capped by On-Demand in most regions. Setting it can increase
297                    // interruptions.:contentReference[oaicite:4]{index=4}
298                    .build(),
299            )
300            .build()
301    }
302}
303
304#[async_trait::async_trait]
305impl ServerProviderClient for AwsClient {
306    const USERNAME: &'static str = "ubuntu";
307
308    async fn list_instances_by_role(
309        &self,
310        role: InstanceRole,
311    ) -> CloudProviderResult<Vec<Instance>> {
312        let filter_name = Filter::builder()
313            .name("tag:Name")
314            .values(self.settings.testbed_id.clone())
315            .build();
316        let filter_role = Filter::builder()
317            .name("tag:Role")
318            .values(role.to_string())
319            .build();
320        let filter_state = Filter::builder()
321            .name("instance-state-name")
322            .values("pending")
323            .values("running")
324            .values("stopping")
325            .values("stopped")
326            .build();
327
328        let mut instances = Vec::new();
329        for (region, client) in &self.clients {
330            let request = client.describe_instances().set_filters(Some(vec![
331                filter_name.clone(),
332                filter_role.clone(),
333                filter_state.clone(),
334            ]));
335            let response = request.send().await?;
336            for reservation in response.reservations() {
337                for instance in reservation.instances() {
338                    instances.push(self.make_instance(region.clone(), instance));
339                }
340            }
341        }
342
343        Ok(instances)
344    }
345
346    async fn list_instances_by_region_and_ids(
347        &self,
348        ids_by_region: &HashMap<String, Vec<String>>,
349    ) -> CloudProviderResult<Vec<Instance>> {
350        let mut instances = Vec::new();
351        for (region, client) in &self.clients {
352            let request = client
353                .describe_instances()
354                .set_instance_ids(ids_by_region.get(region).cloned());
355            let response = request.send().await?;
356            for reservation in response.reservations() {
357                for instance in reservation.instances() {
358                    instances.push(self.make_instance(region.clone(), instance));
359                }
360            }
361        }
362
363        Ok(instances)
364    }
365
366    async fn start_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
367    where
368        I: Iterator<Item = &'a Instance> + Send,
369    {
370        let mut instance_ids = HashMap::new();
371        for instance in instances {
372            instance_ids
373                .entry(&instance.region)
374                .or_insert_with(Vec::new)
375                .push(instance.id.clone());
376        }
377
378        for (region, client) in &self.clients {
379            let ids = instance_ids.remove(&region.to_string());
380            if ids.is_some() {
381                client
382                    .start_instances()
383                    .set_instance_ids(ids)
384                    .send()
385                    .await?;
386            }
387        }
388        Ok(())
389    }
390
391    async fn stop_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
392    where
393        I: Iterator<Item = &'a Instance> + Send,
394    {
395        let mut instance_ids: HashMap<String, Vec<String>> = HashMap::new();
396        for i in instances {
397            if i.lifecycle == InstanceLifecycle::Spot {
398                return Err(CloudProviderError::FailedToStopSpotInstance(i.id.clone()));
399            }
400            instance_ids
401                .entry(i.region.clone())
402                .or_default()
403                .push(i.id.clone());
404        }
405
406        for (region, ids) in instance_ids {
407            let client = self.clients.get(&region).ok_or_else(|| {
408                CloudProviderError::Request(format!("Undefined region {:?}", region))
409            })?;
410            client
411                .stop_instances()
412                .set_instance_ids(Some(ids))
413                .send()
414                .await?;
415        }
416        Ok(())
417    }
418
419    async fn create_instance<S>(
420        &self,
421        region: S,
422        role: InstanceRole,
423        quantity: usize,
424        use_spot_instances: bool,
425        id: String,
426    ) -> CloudProviderResult<Vec<Instance>>
427    where
428        S: Into<String> + Serialize + Send,
429    {
430        let region = region.into();
431        let testbed_id = &self.settings.testbed_id;
432
433        let client = self
434            .clients
435            .get(&region)
436            .ok_or_else(|| CloudProviderError::Request(format!("Undefined region {region:?}")))?;
437
438        // Create a security group (if needed).
439        self.create_security_group(client).await?;
440
441        // Query the image id.
442        let image_id = self.find_image_id(client).await?;
443
444        // Create a new instance.
445        let tags = TagSpecification::builder()
446            .resource_type(ResourceType::Instance)
447            .tags(Tag::builder().key("Name").value(testbed_id).build())
448            .tags(Tag::builder().key("Role").value(role.to_string()).build())
449            .tags(Tag::builder().key("Id").value(id).build())
450            .build();
451
452        let storage = BlockDeviceMapping::builder()
453            .device_name("/dev/sda1")
454            .ebs(
455                EbsBlockDevice::builder()
456                    .delete_on_termination(true)
457                    .volume_size(500)
458                    .volume_type(VolumeType::Gp2)
459                    .build(),
460            )
461            .build();
462        let instance_type = match role {
463            InstanceRole::Node => &self.settings.node_specs,
464            InstanceRole::Metrics => &self.settings.metrics_specs,
465            InstanceRole::Client => &self.settings.client_specs,
466        };
467
468        let mut base_request = client
469            .run_instances()
470            .image_id(image_id)
471            .instance_type(instance_type.as_str().into())
472            .key_name(testbed_id)
473            .security_groups(&self.settings.testbed_id)
474            .tag_specifications(tags);
475
476        // Only the monitoring device should be EBS backed.
477        if role == InstanceRole::Metrics {
478            base_request = base_request.block_device_mappings(storage);
479        }
480        let mut collected_instances = Vec::new();
481        if use_spot_instances && role == InstanceRole::Node {
482            let start = tokio::time::Instant::now();
483            // 5min try for spot instances
484            let total_runtime = tokio::time::Duration::from_secs(300);
485            while start.elapsed() < total_runtime && collected_instances.len() < quantity {
486                display::status(format!(
487                    "{}s/{}s: {}",
488                    start.elapsed().as_secs(),
489                    total_runtime.as_secs(),
490                    collected_instances.len()
491                ));
492                let needed = (quantity - collected_instances.len()) as i32;
493                let request = base_request
494                    .clone()
495                    .min_count(1)
496                    .max_count(needed)
497                    .instance_market_options(Self::spot_options());
498                let result = request.send().await;
499                let instances = match result {
500                    Ok(response) => response
501                        .instances()
502                        .iter()
503                        .map(|i| self.make_instance(region.clone(), i))
504                        .collect(),
505                    Err(_) => Vec::new(),
506                };
507                collected_instances.extend(instances);
508            }
509        }
510        while collected_instances.len() < quantity {
511            // some instances need to be OnDemand
512            let needed = (quantity - collected_instances.len()) as i32;
513            let request = base_request.clone().min_count(1).max_count(needed);
514            let response = request.send().await?;
515            let on_demand_instances = response
516                .instances()
517                .iter()
518                .map(|instance| self.make_instance(region.clone(), instance))
519                .collect::<Vec<_>>();
520            collected_instances.extend(on_demand_instances);
521            display::status(format!(
522                "collected instances: {}",
523                collected_instances.len()
524            ));
525        }
526        Ok(collected_instances)
527    }
528
529    async fn delete_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
530    where
531        I: Iterator<Item = &'a Instance> + Send,
532    {
533        let map_of_ids_by_region = instances.into_iter().fold(
534            HashMap::new(),
535            |mut acc: HashMap<String, Vec<String>>, i| {
536                acc.entry(i.region.clone()).or_default().push(i.id.clone());
537                acc
538            },
539        );
540        for (region, ids) in map_of_ids_by_region {
541            let client = self.clients.get(&region).ok_or_else(|| {
542                CloudProviderError::Request(format!("Undefined region {:?}", region))
543            })?;
544            client
545                .terminate_instances()
546                .set_instance_ids(Some(ids))
547                .send()
548                .await?;
549        }
550        Ok(())
551    }
552
553    async fn register_ssh_public_key(&self, public_key: String) -> CloudProviderResult<()> {
554        for client in self.clients.values() {
555            let request = client
556                .import_key_pair()
557                .key_name(&self.settings.testbed_id)
558                .public_key_material(Blob::new::<String>(public_key.clone()));
559
560            let response = request.send().await;
561            Self::check_but_ignore_duplicates(response)?;
562        }
563        Ok(())
564    }
565
566    async fn instance_setup_commands(&self) -> CloudProviderResult<Vec<String>> {
567        if self.check_nvme_support().await? {
568            Ok(self.nvme_mount_command())
569        } else {
570            Ok(Vec::new())
571        }
572    }
573    #[cfg(test)]
574    fn instances(&self) -> Vec<Instance> {
575        // Only used under testing by the TestClient, unreachable cause no test
576        // should use AwsClient.
577        unreachable!()
578    }
579}