iota_aws_orchestrator/client/
aws.rs1use std::{
6 collections::HashMap,
7 fmt::{Debug, Display},
8};
9
10use aws_runtime::env_config::file::{EnvConfigFileKind, EnvConfigFiles};
11use aws_sdk_ec2::{
12 config::Region,
13 primitives::Blob,
14 types::{
15 BlockDeviceMapping, EbsBlockDevice, EphemeralNvmeSupport, Filter,
16 InstanceInterruptionBehavior, InstanceMarketOptionsRequest, MarketType, ResourceType,
17 SpotInstanceType, SpotMarketOptions, Tag, TagSpecification, VolumeType,
18 builders::FilterBuilder,
19 },
20};
21use aws_smithy_runtime_api::client::{behavior_version::BehaviorVersion, result::SdkError};
22use serde::Serialize;
23
24use super::{Instance, InstanceLifecycle, InstanceRole, ServerProviderClient};
25use crate::{
26 display,
27 error::{CloudProviderError, CloudProviderResult},
28 settings::Settings,
29};
30
31impl<T> From<SdkError<T, aws_smithy_runtime_api::client::orchestrator::HttpResponse>>
33 for CloudProviderError
34where
35 T: Debug + std::error::Error + Send + Sync + 'static,
36{
37 fn from(e: SdkError<T, aws_smithy_runtime_api::client::orchestrator::HttpResponse>) -> Self {
38 Self::Request(format!("{:?}", e.into_source()))
39 }
40}
41
42pub struct AwsClient {
44 settings: Settings,
46 clients: HashMap<String, aws_sdk_ec2::Client>,
48}
49
50impl Display for AwsClient {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 write!(f, "AWS EC2 client v{}", aws_sdk_ec2::meta::PKG_VERSION)
53 }
54}
55
56impl AwsClient {
57 const UBUNTU_NAME_PATTERN: &'static str =
58 "ubuntu/images/hvm-ssd-gp3/ubuntu-noble-24.04-amd64-server-*";
59 const CANONICAL_OWNER_ID: &'static str = "099720109477";
60
61 pub async fn new(settings: Settings) -> Self {
63 let profile_files = EnvConfigFiles::builder()
64 .with_file(EnvConfigFileKind::Credentials, &settings.token_file)
65 .with_contents(EnvConfigFileKind::Config, "[default]\noutput=json")
66 .build();
67
68 let mut clients = HashMap::new();
69 for region in settings.regions.clone() {
70 let sdk_config = aws_config::defaults(BehaviorVersion::latest())
71 .region(Region::new(region.clone()))
72 .profile_files(profile_files.clone())
73 .load()
74 .await;
75 let client = aws_sdk_ec2::Client::new(&sdk_config);
76 clients.insert(region, client);
77 }
78
79 Self { settings, clients }
80 }
81
82 fn check_but_ignore_duplicates<T, E>(
85 response: Result<
86 T,
87 SdkError<E, aws_smithy_runtime_api::client::orchestrator::HttpResponse>,
88 >,
89 ) -> CloudProviderResult<()>
90 where
91 E: Debug + std::error::Error + Send + Sync + 'static,
92 {
93 if let Err(e) = response {
94 let error_message = format!("{e:?}");
95 if !error_message.to_lowercase().contains("duplicate") {
96 return Err(e.into());
97 }
98 }
99 Ok(())
100 }
101 fn get_tag_value(instance: &aws_sdk_ec2::types::Instance, key: &str) -> Option<String> {
102 instance
103 .tags()
104 .iter()
105 .find(|tag| tag.key().is_some_and(|k| k == key))
106 .and_then(|tag| tag.value().map(|v| v.to_string()))
107 }
108 fn make_instance(
111 &self,
112 region: String,
113 aws_instance: &aws_sdk_ec2::types::Instance,
114 ) -> Instance {
115 let role: InstanceRole = Self::get_tag_value(aws_instance, "Role")
116 .expect("AWS instance should have a role")
117 .as_str()
118 .into();
119 let lifecycle: InstanceLifecycle =
120 if let Some(aws_sdk_ec2::types::InstanceLifecycleType::Spot) =
121 aws_instance.instance_lifecycle
122 {
123 InstanceLifecycle::Spot
124 } else {
125 InstanceLifecycle::OnDemand
126 };
127 Instance {
128 id: aws_instance
129 .instance_id()
130 .expect("AWS instance should have an id")
131 .into(),
132 region,
133 main_ip: aws_instance
134 .public_ip_address()
135 .unwrap_or("0.0.0.0") .parse()
137 .expect("AWS instance should have a valid ip"),
138 private_ip: aws_instance
139 .private_ip_address()
140 .unwrap_or("0.0.0.0") .parse()
142 .expect("AWS instance should have a valid ip"),
143 tags: vec![self.settings.testbed_id.clone()],
144 specs: format!(
145 "{:?}",
146 aws_instance
147 .instance_type()
148 .expect("AWS instance should have a type")
149 ),
150 status: format!(
151 "{:?}",
152 aws_instance
153 .state()
154 .expect("AWS instance should have a state")
155 .name()
156 .expect("AWS status should have a name")
157 ),
158 role,
159 lifecycle,
160 }
161 }
162
163 async fn find_image_id(&self, client: &aws_sdk_ec2::Client) -> CloudProviderResult<String> {
166 let filters = [
168 FilterBuilder::default()
170 .name("name")
171 .values(Self::UBUNTU_NAME_PATTERN)
172 .build(),
173 FilterBuilder::default()
175 .name("owner-id")
176 .values(Self::CANONICAL_OWNER_ID)
177 .build(),
178 FilterBuilder::default()
180 .name("state")
181 .values("available")
182 .build(),
183 ];
184
185 let request = client.describe_images().set_filters(Some(filters.to_vec()));
187 let response = request.send().await?;
188
189 let mut images = response.images().to_vec();
191 images.sort_by(|a, b| {
192 let a_date = a.creation_date().unwrap_or("");
193 let b_date = b.creation_date().unwrap_or("");
194 b_date.cmp(a_date) });
196
197 let image = images
199 .first()
200 .ok_or_else(|| CloudProviderError::Request("Cannot find Ubuntu 24.04 image".into()))?;
201
202 image
203 .image_id
204 .clone()
205 .ok_or_else(|| CloudProviderError::UnexpectedResponse("Image without ID".into()))
206 }
207
208 async fn create_security_group(&self, client: &aws_sdk_ec2::Client) -> CloudProviderResult<()> {
211 let request = client
213 .create_security_group()
214 .group_name(&self.settings.testbed_id)
215 .description("Allow all traffic (used for benchmarks).");
216
217 let response = request.send().await;
218 Self::check_but_ignore_duplicates(response)?;
219
220 for protocol in ["tcp", "udp", "icmp", "icmpv6"] {
222 let mut request = client
223 .authorize_security_group_ingress()
224 .group_name(&self.settings.testbed_id)
225 .ip_protocol(protocol)
226 .cidr_ip("0.0.0.0/0"); if protocol == "icmp" || protocol == "icmpv6" {
228 request = request.from_port(-1).to_port(-1);
229 } else {
230 request = request.from_port(0).to_port(65535);
231 }
232
233 let response = request.send().await;
234 Self::check_but_ignore_duplicates(response)?;
235 }
236 Ok(())
237 }
238
239 fn nvme_mount_command(&self) -> Vec<String> {
241 let directory = self.settings.working_dir.display();
242 vec![
243 "export NVME_DRIVE=$(nvme list | awk '/NVMe Instance Storage/ {print $1; exit}')"
244 .to_string(),
245 "(sudo mkfs.ext4 -E nodiscard $NVME_DRIVE || true)".to_string(),
246 format!("(sudo mount $NVME_DRIVE {directory} || true)"),
247 format!("sudo chmod 777 -R {directory}"),
248 ]
249 }
250
251 async fn check_nvme_support(&self) -> CloudProviderResult<bool> {
254 let client = match self
257 .settings
258 .regions
259 .first()
260 .and_then(|x| self.clients.get(x))
261 {
262 Some(client) => client,
263 None => return Ok(false),
264 };
265
266 let request = client
268 .describe_instance_types()
269 .instance_types(self.settings.node_specs.as_str().into());
270
271 let response = request.send().await?;
273
274 if let Some(info) = response.instance_types().first() {
276 if let Some(info) = info.instance_storage_info() {
277 if info.nvme_support() == Some(&EphemeralNvmeSupport::Required) {
278 return Ok(true);
279 }
280 }
281 }
282 Ok(false)
283 }
284 fn spot_options() -> InstanceMarketOptionsRequest {
285 InstanceMarketOptionsRequest::builder()
286 .market_type(MarketType::Spot)
288 .spot_options(
289 SpotMarketOptions::builder()
290 .spot_instance_type(SpotInstanceType::OneTime)
292 .instance_interruption_behavior(InstanceInterruptionBehavior::Terminate)
295 .build(),
299 )
300 .build()
301 }
302}
303
304#[async_trait::async_trait]
305impl ServerProviderClient for AwsClient {
306 const USERNAME: &'static str = "ubuntu";
307
308 async fn list_instances_by_role(
309 &self,
310 role: InstanceRole,
311 ) -> CloudProviderResult<Vec<Instance>> {
312 let filter_name = Filter::builder()
313 .name("tag:Name")
314 .values(self.settings.testbed_id.clone())
315 .build();
316 let filter_role = Filter::builder()
317 .name("tag:Role")
318 .values(role.to_string())
319 .build();
320 let filter_state = Filter::builder()
321 .name("instance-state-name")
322 .values("pending")
323 .values("running")
324 .values("stopping")
325 .values("stopped")
326 .build();
327
328 let mut instances = Vec::new();
329 for (region, client) in &self.clients {
330 let request = client.describe_instances().set_filters(Some(vec![
331 filter_name.clone(),
332 filter_role.clone(),
333 filter_state.clone(),
334 ]));
335 let response = request.send().await?;
336 for reservation in response.reservations() {
337 for instance in reservation.instances() {
338 instances.push(self.make_instance(region.clone(), instance));
339 }
340 }
341 }
342
343 Ok(instances)
344 }
345
346 async fn list_instances_by_region_and_ids(
347 &self,
348 ids_by_region: &HashMap<String, Vec<String>>,
349 ) -> CloudProviderResult<Vec<Instance>> {
350 let mut instances = Vec::new();
351 for (region, client) in &self.clients {
352 let request = client
353 .describe_instances()
354 .set_instance_ids(ids_by_region.get(region).cloned());
355 let response = request.send().await?;
356 for reservation in response.reservations() {
357 for instance in reservation.instances() {
358 instances.push(self.make_instance(region.clone(), instance));
359 }
360 }
361 }
362
363 Ok(instances)
364 }
365
366 async fn start_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
367 where
368 I: Iterator<Item = &'a Instance> + Send,
369 {
370 let mut instance_ids = HashMap::new();
371 for instance in instances {
372 instance_ids
373 .entry(&instance.region)
374 .or_insert_with(Vec::new)
375 .push(instance.id.clone());
376 }
377
378 for (region, client) in &self.clients {
379 let ids = instance_ids.remove(®ion.to_string());
380 if ids.is_some() {
381 client
382 .start_instances()
383 .set_instance_ids(ids)
384 .send()
385 .await?;
386 }
387 }
388 Ok(())
389 }
390
391 async fn stop_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
392 where
393 I: Iterator<Item = &'a Instance> + Send,
394 {
395 let mut instance_ids: HashMap<String, Vec<String>> = HashMap::new();
396 for i in instances {
397 if i.lifecycle == InstanceLifecycle::Spot {
398 return Err(CloudProviderError::FailedToStopSpotInstance(i.id.clone()));
399 }
400 instance_ids
401 .entry(i.region.clone())
402 .or_default()
403 .push(i.id.clone());
404 }
405
406 for (region, ids) in instance_ids {
407 let client = self.clients.get(®ion).ok_or_else(|| {
408 CloudProviderError::Request(format!("Undefined region {:?}", region))
409 })?;
410 client
411 .stop_instances()
412 .set_instance_ids(Some(ids))
413 .send()
414 .await?;
415 }
416 Ok(())
417 }
418
419 async fn create_instance<S>(
420 &self,
421 region: S,
422 role: InstanceRole,
423 quantity: usize,
424 use_spot_instances: bool,
425 id: String,
426 ) -> CloudProviderResult<Vec<Instance>>
427 where
428 S: Into<String> + Serialize + Send,
429 {
430 let region = region.into();
431 let testbed_id = &self.settings.testbed_id;
432
433 let client = self
434 .clients
435 .get(®ion)
436 .ok_or_else(|| CloudProviderError::Request(format!("Undefined region {region:?}")))?;
437
438 self.create_security_group(client).await?;
440
441 let image_id = self.find_image_id(client).await?;
443
444 let tags = TagSpecification::builder()
446 .resource_type(ResourceType::Instance)
447 .tags(Tag::builder().key("Name").value(testbed_id).build())
448 .tags(Tag::builder().key("Role").value(role.to_string()).build())
449 .tags(Tag::builder().key("Id").value(id).build())
450 .build();
451
452 let storage = BlockDeviceMapping::builder()
453 .device_name("/dev/sda1")
454 .ebs(
455 EbsBlockDevice::builder()
456 .delete_on_termination(true)
457 .volume_size(500)
458 .volume_type(VolumeType::Gp2)
459 .build(),
460 )
461 .build();
462 let instance_type = match role {
463 InstanceRole::Node => &self.settings.node_specs,
464 InstanceRole::Metrics => &self.settings.metrics_specs,
465 InstanceRole::Client => &self.settings.client_specs,
466 };
467
468 let mut base_request = client
469 .run_instances()
470 .image_id(image_id)
471 .instance_type(instance_type.as_str().into())
472 .key_name(testbed_id)
473 .security_groups(&self.settings.testbed_id)
474 .tag_specifications(tags);
475
476 if role == InstanceRole::Metrics {
478 base_request = base_request.block_device_mappings(storage);
479 }
480 let mut collected_instances = Vec::new();
481 if use_spot_instances && role == InstanceRole::Node {
482 let start = tokio::time::Instant::now();
483 let total_runtime = tokio::time::Duration::from_secs(300);
485 while start.elapsed() < total_runtime && collected_instances.len() < quantity {
486 display::status(format!(
487 "{}s/{}s: {}",
488 start.elapsed().as_secs(),
489 total_runtime.as_secs(),
490 collected_instances.len()
491 ));
492 let needed = (quantity - collected_instances.len()) as i32;
493 let request = base_request
494 .clone()
495 .min_count(1)
496 .max_count(needed)
497 .instance_market_options(Self::spot_options());
498 let result = request.send().await;
499 let instances = match result {
500 Ok(response) => response
501 .instances()
502 .iter()
503 .map(|i| self.make_instance(region.clone(), i))
504 .collect(),
505 Err(_) => Vec::new(),
506 };
507 collected_instances.extend(instances);
508 }
509 }
510 while collected_instances.len() < quantity {
511 let needed = (quantity - collected_instances.len()) as i32;
513 let request = base_request.clone().min_count(1).max_count(needed);
514 let response = request.send().await?;
515 let on_demand_instances = response
516 .instances()
517 .iter()
518 .map(|instance| self.make_instance(region.clone(), instance))
519 .collect::<Vec<_>>();
520 collected_instances.extend(on_demand_instances);
521 display::status(format!(
522 "collected instances: {}",
523 collected_instances.len()
524 ));
525 }
526 Ok(collected_instances)
527 }
528
529 async fn delete_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
530 where
531 I: Iterator<Item = &'a Instance> + Send,
532 {
533 let map_of_ids_by_region = instances.into_iter().fold(
534 HashMap::new(),
535 |mut acc: HashMap<String, Vec<String>>, i| {
536 acc.entry(i.region.clone()).or_default().push(i.id.clone());
537 acc
538 },
539 );
540 for (region, ids) in map_of_ids_by_region {
541 let client = self.clients.get(®ion).ok_or_else(|| {
542 CloudProviderError::Request(format!("Undefined region {:?}", region))
543 })?;
544 client
545 .terminate_instances()
546 .set_instance_ids(Some(ids))
547 .send()
548 .await?;
549 }
550 Ok(())
551 }
552
553 async fn register_ssh_public_key(&self, public_key: String) -> CloudProviderResult<()> {
554 for client in self.clients.values() {
555 let request = client
556 .import_key_pair()
557 .key_name(&self.settings.testbed_id)
558 .public_key_material(Blob::new::<String>(public_key.clone()));
559
560 let response = request.send().await;
561 Self::check_but_ignore_duplicates(response)?;
562 }
563 Ok(())
564 }
565
566 async fn instance_setup_commands(&self) -> CloudProviderResult<Vec<String>> {
567 if self.check_nvme_support().await? {
568 Ok(self.nvme_mount_command())
569 } else {
570 Ok(Vec::new())
571 }
572 }
573 #[cfg(test)]
574 fn instances(&self) -> Vec<Instance> {
575 unreachable!()
578 }
579}