iota_aws_orchestrator/client/
aws.rs1use std::{
6 collections::HashMap,
7 fmt::{Debug, Display},
8};
9
10use aws_runtime::env_config::file::{EnvConfigFileKind, EnvConfigFiles};
11use aws_sdk_ec2::{
12 config::Region,
13 primitives::Blob,
14 types::{
15 BlockDeviceMapping, EbsBlockDevice, EphemeralNvmeSupport, Filter,
16 InstanceInterruptionBehavior, InstanceMarketOptionsRequest, MarketType, ResourceType,
17 SpotInstanceType, SpotMarketOptions, Tag, TagSpecification, VolumeType,
18 builders::FilterBuilder,
19 },
20};
21use aws_smithy_runtime_api::client::{behavior_version::BehaviorVersion, result::SdkError};
22use serde::Serialize;
23
24use super::{Instance, InstanceLifecycle, InstanceRole, ServerProviderClient};
25use crate::{
26 display,
27 error::{CloudProviderError, CloudProviderResult},
28 settings::Settings,
29};
30
31impl<T> From<SdkError<T, aws_smithy_runtime_api::client::orchestrator::HttpResponse>>
33 for CloudProviderError
34where
35 T: Debug + std::error::Error + Send + Sync + 'static,
36{
37 fn from(e: SdkError<T, aws_smithy_runtime_api::client::orchestrator::HttpResponse>) -> Self {
38 Self::Request(format!("{:?}", e.into_source()))
39 }
40}
41
42pub struct AwsClient {
44 settings: Settings,
46 clients: HashMap<String, aws_sdk_ec2::Client>,
48}
49
50impl Display for AwsClient {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 write!(f, "AWS EC2 client v{}", aws_sdk_ec2::meta::PKG_VERSION)
53 }
54}
55
56impl AwsClient {
57 const UBUNTU_NAME_PATTERN: &'static str =
58 "ubuntu/images/hvm-ssd-gp3/ubuntu-noble-24.04-amd64-server-*";
59 const CANONICAL_OWNER_ID: &'static str = "099720109477";
60
61 pub async fn new(settings: Settings) -> Self {
63 let profile_files = EnvConfigFiles::builder()
64 .with_file(EnvConfigFileKind::Credentials, &settings.token_file)
65 .with_contents(EnvConfigFileKind::Config, "[default]\noutput=json")
66 .build();
67
68 let mut clients = HashMap::new();
69 for region in settings.regions.clone() {
70 let sdk_config = aws_config::defaults(BehaviorVersion::latest())
71 .region(Region::new(region.clone()))
72 .profile_files(profile_files.clone())
73 .load()
74 .await;
75 let client = aws_sdk_ec2::Client::new(&sdk_config);
76 clients.insert(region, client);
77 }
78
79 Self { settings, clients }
80 }
81
82 fn check_but_ignore_duplicates<T, E>(
85 response: Result<
86 T,
87 SdkError<E, aws_smithy_runtime_api::client::orchestrator::HttpResponse>,
88 >,
89 ) -> CloudProviderResult<()>
90 where
91 E: Debug + std::error::Error + Send + Sync + 'static,
92 {
93 if let Err(e) = response {
94 let error_message = format!("{e:?}");
95 if !error_message.to_lowercase().contains("duplicate") {
96 return Err(e.into());
97 }
98 }
99 Ok(())
100 }
101 fn get_tag_value(instance: &aws_sdk_ec2::types::Instance, key: &str) -> Option<String> {
102 instance
103 .tags()
104 .iter()
105 .find(|tag| tag.key().is_some_and(|k| k == key))
106 .and_then(|tag| tag.value().map(|v| v.to_string()))
107 }
108 fn make_instance(
111 &self,
112 region: String,
113 aws_instance: &aws_sdk_ec2::types::Instance,
114 ) -> Instance {
115 let role: InstanceRole = Self::get_tag_value(aws_instance, "Role")
116 .expect("AWS instance should have a role")
117 .as_str()
118 .into();
119 let lifecycle: InstanceLifecycle =
120 if let Some(aws_sdk_ec2::types::InstanceLifecycleType::Spot) =
121 aws_instance.instance_lifecycle
122 {
123 InstanceLifecycle::Spot
124 } else {
125 InstanceLifecycle::OnDemand
126 };
127 Instance {
128 id: aws_instance
129 .instance_id()
130 .expect("AWS instance should have an id")
131 .into(),
132 region,
133 main_ip: aws_instance
134 .public_ip_address()
135 .unwrap_or("0.0.0.0") .parse()
137 .expect("AWS instance should have a valid ip"),
138 private_ip: aws_instance
139 .private_ip_address()
140 .unwrap_or("0.0.0.0") .parse()
142 .expect("AWS instance should have a valid ip"),
143 tags: vec![self.settings.testbed_id.clone()],
144 specs: format!(
145 "{:?}",
146 aws_instance
147 .instance_type()
148 .expect("AWS instance should have a type")
149 ),
150 status: format!(
151 "{:?}",
152 aws_instance
153 .state()
154 .expect("AWS instance should have a state")
155 .name()
156 .expect("AWS status should have a name")
157 ),
158 role,
159 lifecycle,
160 }
161 }
162
163 async fn find_image_id(&self, client: &aws_sdk_ec2::Client) -> CloudProviderResult<String> {
166 let filters = [
168 FilterBuilder::default()
170 .name("name")
171 .values(Self::UBUNTU_NAME_PATTERN)
172 .build(),
173 FilterBuilder::default()
175 .name("owner-id")
176 .values(Self::CANONICAL_OWNER_ID)
177 .build(),
178 FilterBuilder::default()
180 .name("state")
181 .values("available")
182 .build(),
183 ];
184
185 let request = client.describe_images().set_filters(Some(filters.to_vec()));
187 let response = request.send().await?;
188
189 let mut images = response.images().to_vec();
191 images.sort_by(|a, b| {
192 let a_date = a.creation_date().unwrap_or("");
193 let b_date = b.creation_date().unwrap_or("");
194 b_date.cmp(a_date) });
196
197 let image = images
199 .first()
200 .ok_or_else(|| CloudProviderError::Request("Cannot find Ubuntu 24.04 image".into()))?;
201
202 image
203 .image_id
204 .clone()
205 .ok_or_else(|| CloudProviderError::UnexpectedResponse("Image without ID".into()))
206 }
207
208 async fn create_security_group(&self, client: &aws_sdk_ec2::Client) -> CloudProviderResult<()> {
211 let request = client
213 .create_security_group()
214 .group_name(&self.settings.testbed_id)
215 .description("Allow all traffic (used for benchmarks).");
216
217 let response = request.send().await;
218 Self::check_but_ignore_duplicates(response)?;
219
220 for protocol in ["tcp", "udp", "icmp", "icmpv6"] {
222 let mut request = client
223 .authorize_security_group_ingress()
224 .group_name(&self.settings.testbed_id)
225 .ip_protocol(protocol)
226 .cidr_ip("0.0.0.0/0"); if protocol == "icmp" || protocol == "icmpv6" {
228 request = request.from_port(-1).to_port(-1);
229 } else {
230 request = request.from_port(0).to_port(65535);
231 }
232
233 let response = request.send().await;
234 Self::check_but_ignore_duplicates(response)?;
235 }
236 Ok(())
237 }
238
239 fn nvme_mount_command(&self) -> Vec<String> {
241 let directory = self.settings.working_dir.display();
242 vec![
243 "export NVME_DRIVE=$(nvme list | awk '/NVMe Instance Storage/ {print $1; exit}')"
244 .to_string(),
245 "(sudo mkfs.ext4 -E nodiscard $NVME_DRIVE || true)".to_string(),
246 format!("(sudo mount $NVME_DRIVE {directory} || true)"),
247 format!("sudo chmod 777 -R {directory}"),
248 ]
249 }
250
251 async fn check_nvme_support(&self) -> CloudProviderResult<bool> {
254 let client = match self
257 .settings
258 .regions
259 .first()
260 .and_then(|x| self.clients.get(x))
261 {
262 Some(client) => client,
263 None => return Ok(false),
264 };
265
266 let request = client
268 .describe_instance_types()
269 .instance_types(self.settings.node_specs.as_str().into());
270
271 let response = request.send().await?;
273
274 if let Some(info) = response.instance_types().first() {
276 if let Some(info) = info.instance_storage_info() {
277 if info.nvme_support() == Some(&EphemeralNvmeSupport::Required) {
278 return Ok(true);
279 }
280 }
281 }
282 Ok(false)
283 }
284 fn spot_options() -> InstanceMarketOptionsRequest {
285 InstanceMarketOptionsRequest::builder()
286 .market_type(MarketType::Spot)
288 .spot_options(
289 SpotMarketOptions::builder()
290 .spot_instance_type(SpotInstanceType::OneTime)
292 .instance_interruption_behavior(InstanceInterruptionBehavior::Terminate)
295 .build(),
299 )
300 .build()
301 }
302}
303
304#[async_trait::async_trait]
305impl ServerProviderClient for AwsClient {
306 const USERNAME: &'static str = "ubuntu";
307
308 async fn list_instances_by_role(
309 &self,
310 role: InstanceRole,
311 ) -> CloudProviderResult<Vec<Instance>> {
312 let filter_name = Filter::builder()
313 .name("tag:Name")
314 .values(self.settings.testbed_id.clone())
315 .build();
316 let filter_role = Filter::builder()
317 .name("tag:Role")
318 .values(role.to_string())
319 .build();
320 let filter_state = Filter::builder()
321 .name("instance-state-name")
322 .values("pending")
323 .values("running")
324 .values("stopping")
325 .values("stopped")
326 .build();
327
328 let mut instances = Vec::new();
329 for (region, client) in &self.clients {
330 let request = client.describe_instances().set_filters(Some(vec![
331 filter_name.clone(),
332 filter_role.clone(),
333 filter_state.clone(),
334 ]));
335 let response = request.send().await?;
336 for reservation in response.reservations() {
337 for instance in reservation.instances() {
338 instances.push(self.make_instance(region.clone(), instance));
339 }
340 }
341 }
342 instances.sort_by_key(|i| i.main_ip);
343
344 Ok(instances)
345 }
346
347 async fn list_instances_by_region_and_ids(
348 &self,
349 ids_by_region: &HashMap<String, Vec<String>>,
350 ) -> CloudProviderResult<Vec<Instance>> {
351 let mut instances = Vec::new();
352 for (region, client) in &self.clients {
353 let request = client
354 .describe_instances()
355 .set_instance_ids(ids_by_region.get(region).cloned());
356 let response = request.send().await?;
357 for reservation in response.reservations() {
358 for instance in reservation.instances() {
359 instances.push(self.make_instance(region.clone(), instance));
360 }
361 }
362 }
363
364 Ok(instances)
365 }
366
367 async fn start_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
368 where
369 I: Iterator<Item = &'a Instance> + Send,
370 {
371 let mut instance_ids = HashMap::new();
372 for instance in instances {
373 instance_ids
374 .entry(&instance.region)
375 .or_insert_with(Vec::new)
376 .push(instance.id.clone());
377 }
378
379 for (region, client) in &self.clients {
380 let ids = instance_ids.remove(®ion.to_string());
381 if ids.is_some() {
382 client
383 .start_instances()
384 .set_instance_ids(ids)
385 .send()
386 .await?;
387 }
388 }
389 Ok(())
390 }
391
392 async fn stop_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
393 where
394 I: Iterator<Item = &'a Instance> + Send,
395 {
396 let mut instance_ids: HashMap<String, Vec<String>> = HashMap::new();
397 for i in instances {
398 if i.lifecycle == InstanceLifecycle::Spot {
399 return Err(CloudProviderError::FailedToStopSpotInstance(i.id.clone()));
400 }
401 instance_ids
402 .entry(i.region.clone())
403 .or_default()
404 .push(i.id.clone());
405 }
406
407 for (region, ids) in instance_ids {
408 let client = self.clients.get(®ion).ok_or_else(|| {
409 CloudProviderError::Request(format!("Undefined region {:?}", region))
410 })?;
411 client
412 .stop_instances()
413 .set_instance_ids(Some(ids))
414 .send()
415 .await?;
416 }
417 Ok(())
418 }
419
420 async fn create_instance<S>(
421 &self,
422 region: S,
423 role: InstanceRole,
424 quantity: usize,
425 use_spot_instances: bool,
426 id: String,
427 ) -> CloudProviderResult<Vec<Instance>>
428 where
429 S: Into<String> + Serialize + Send,
430 {
431 let region = region.into();
432 let testbed_id = &self.settings.testbed_id;
433
434 let client = self
435 .clients
436 .get(®ion)
437 .ok_or_else(|| CloudProviderError::Request(format!("Undefined region {region:?}")))?;
438
439 self.create_security_group(client).await?;
441
442 let image_id = self.find_image_id(client).await?;
444
445 let tags = TagSpecification::builder()
447 .resource_type(ResourceType::Instance)
448 .tags(Tag::builder().key("Name").value(testbed_id).build())
449 .tags(Tag::builder().key("Role").value(role.to_string()).build())
450 .tags(Tag::builder().key("Id").value(id).build())
451 .build();
452
453 let storage = BlockDeviceMapping::builder()
454 .device_name("/dev/sda1")
455 .ebs(
456 EbsBlockDevice::builder()
457 .delete_on_termination(true)
458 .volume_size(500)
459 .volume_type(VolumeType::Gp2)
460 .build(),
461 )
462 .build();
463 let instance_type = match role {
464 InstanceRole::Node => &self.settings.node_specs,
465 InstanceRole::Metrics => &self.settings.metrics_specs,
466 InstanceRole::Client => &self.settings.client_specs,
467 };
468
469 let mut base_request = client
470 .run_instances()
471 .image_id(image_id)
472 .instance_type(instance_type.as_str().into())
473 .key_name(testbed_id)
474 .security_groups(&self.settings.testbed_id)
475 .tag_specifications(tags);
476
477 if role == InstanceRole::Metrics {
479 base_request = base_request.block_device_mappings(storage);
480 }
481 let mut collected_instances = Vec::new();
482 if use_spot_instances && role == InstanceRole::Node {
483 let start = tokio::time::Instant::now();
484 let total_runtime = tokio::time::Duration::from_secs(300);
486 while start.elapsed() < total_runtime && collected_instances.len() < quantity {
487 display::status(format!(
488 "{}s/{}s: {}",
489 start.elapsed().as_secs(),
490 total_runtime.as_secs(),
491 collected_instances.len()
492 ));
493 let needed = (quantity - collected_instances.len()) as i32;
494 let request = base_request
495 .clone()
496 .min_count(1)
497 .max_count(needed)
498 .instance_market_options(Self::spot_options());
499 let result = request.send().await;
500 let instances = match result {
501 Ok(response) => response
502 .instances()
503 .iter()
504 .map(|i| self.make_instance(region.clone(), i))
505 .collect(),
506 Err(_) => Vec::new(),
507 };
508 collected_instances.extend(instances);
509 }
510 }
511 while collected_instances.len() < quantity {
512 let needed = (quantity - collected_instances.len()) as i32;
514 let request = base_request.clone().min_count(1).max_count(needed);
515 let response = request.send().await?;
516 let on_demand_instances = response
517 .instances()
518 .iter()
519 .map(|instance| self.make_instance(region.clone(), instance))
520 .collect::<Vec<_>>();
521 collected_instances.extend(on_demand_instances);
522 display::status(format!(
523 "collected instances: {}",
524 collected_instances.len()
525 ));
526 }
527 Ok(collected_instances)
528 }
529
530 async fn delete_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
531 where
532 I: Iterator<Item = &'a Instance> + Send,
533 {
534 let map_of_ids_by_region = instances.into_iter().fold(
535 HashMap::new(),
536 |mut acc: HashMap<String, Vec<String>>, i| {
537 acc.entry(i.region.clone()).or_default().push(i.id.clone());
538 acc
539 },
540 );
541 for (region, ids) in map_of_ids_by_region {
542 let client = self.clients.get(®ion).ok_or_else(|| {
543 CloudProviderError::Request(format!("Undefined region {:?}", region))
544 })?;
545 client
546 .terminate_instances()
547 .set_instance_ids(Some(ids))
548 .send()
549 .await?;
550 }
551 Ok(())
552 }
553
554 async fn register_ssh_public_key(&self, public_key: String) -> CloudProviderResult<()> {
555 for client in self.clients.values() {
556 let request = client
557 .import_key_pair()
558 .key_name(&self.settings.testbed_id)
559 .public_key_material(Blob::new::<String>(public_key.clone()));
560
561 let response = request.send().await;
562 Self::check_but_ignore_duplicates(response)?;
563 }
564 Ok(())
565 }
566
567 async fn instance_setup_commands(&self) -> CloudProviderResult<Vec<String>> {
568 if self.check_nvme_support().await? {
569 Ok(self.nvme_mount_command())
570 } else {
571 Ok(Vec::new())
572 }
573 }
574 #[cfg(test)]
575 fn instances(&self) -> Vec<Instance> {
576 unreachable!()
579 }
580}