iota_aws_orchestrator/client/
aws.rs1use std::{
6 collections::HashMap,
7 fmt::{Debug, Display},
8};
9
10use aws_runtime::env_config::file::{EnvConfigFileKind, EnvConfigFiles};
11use aws_sdk_ec2::{
12 config::Region,
13 primitives::Blob,
14 types::{
15 BlockDeviceMapping, EbsBlockDevice, EphemeralNvmeSupport, Filter, ResourceType, Tag,
16 TagSpecification, VolumeType, builders::FilterBuilder,
17 },
18};
19use aws_smithy_runtime_api::client::{behavior_version::BehaviorVersion, result::SdkError};
20use serde::Serialize;
21
22use super::{Instance, InstanceRole, ServerProviderClient};
23use crate::{
24 error::{CloudProviderError, CloudProviderResult},
25 settings::Settings,
26};
27
28impl<T> From<SdkError<T, aws_smithy_runtime_api::client::orchestrator::HttpResponse>>
30 for CloudProviderError
31where
32 T: Debug + std::error::Error + Send + Sync + 'static,
33{
34 fn from(e: SdkError<T, aws_smithy_runtime_api::client::orchestrator::HttpResponse>) -> Self {
35 Self::Request(format!("{:?}", e.into_source()))
36 }
37}
38
39pub struct AwsClient {
41 settings: Settings,
43 clients: HashMap<String, aws_sdk_ec2::Client>,
45}
46
47impl Display for AwsClient {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 write!(f, "AWS EC2 client v{}", aws_sdk_ec2::meta::PKG_VERSION)
50 }
51}
52
53impl AwsClient {
54 const UBUNTU_NAME_PATTERN: &'static str =
55 "ubuntu/images/hvm-ssd-gp3/ubuntu-noble-24.04-amd64-server-*";
56 const CANONICAL_OWNER_ID: &'static str = "099720109477";
57
58 pub async fn new(settings: Settings) -> Self {
60 let profile_files = EnvConfigFiles::builder()
61 .with_file(EnvConfigFileKind::Credentials, &settings.token_file)
62 .with_contents(EnvConfigFileKind::Config, "[default]\noutput=json")
63 .build();
64
65 let mut clients = HashMap::new();
66 for region in settings.regions.clone() {
67 let sdk_config = aws_config::defaults(BehaviorVersion::latest())
68 .region(Region::new(region.clone()))
69 .profile_files(profile_files.clone())
70 .load()
71 .await;
72 let client = aws_sdk_ec2::Client::new(&sdk_config);
73 clients.insert(region, client);
74 }
75
76 Self { settings, clients }
77 }
78
79 fn check_but_ignore_duplicates<T, E>(
82 response: Result<
83 T,
84 SdkError<E, aws_smithy_runtime_api::client::orchestrator::HttpResponse>,
85 >,
86 ) -> CloudProviderResult<()>
87 where
88 E: Debug + std::error::Error + Send + Sync + 'static,
89 {
90 if let Err(e) = response {
91 let error_message = format!("{e:?}");
92 if !error_message.to_lowercase().contains("duplicate") {
93 return Err(e.into());
94 }
95 }
96 Ok(())
97 }
98 fn get_tag_value(instance: &aws_sdk_ec2::types::Instance, key: &str) -> Option<String> {
99 instance
100 .tags()
101 .iter()
102 .find(|tag| tag.key().is_some_and(|k| k == key))
103 .and_then(|tag| tag.value().map(|v| v.to_string()))
104 }
105 fn make_instance(
108 &self,
109 region: String,
110 aws_instance: &aws_sdk_ec2::types::Instance,
111 ) -> Instance {
112 let role: InstanceRole = Self::get_tag_value(aws_instance, "Role")
113 .expect("AWS instance should have a role")
114 .as_str()
115 .into();
116 Instance {
117 id: aws_instance
118 .instance_id()
119 .expect("AWS instance should have an id")
120 .into(),
121 region,
122 main_ip: aws_instance
123 .public_ip_address()
124 .unwrap_or("0.0.0.0") .parse()
126 .expect("AWS instance should have a valid ip"),
127 private_ip: aws_instance
128 .private_ip_address()
129 .unwrap_or("0.0.0.0") .parse()
131 .expect("AWS instance should have a valid ip"),
132 tags: vec![self.settings.testbed_id.clone()],
133 specs: format!(
134 "{:?}",
135 aws_instance
136 .instance_type()
137 .expect("AWS instance should have a type")
138 ),
139 status: format!(
140 "{:?}",
141 aws_instance
142 .state()
143 .expect("AWS instance should have a state")
144 .name()
145 .expect("AWS status should have a name")
146 ),
147 role,
148 }
149 }
150
151 async fn find_image_id(&self, client: &aws_sdk_ec2::Client) -> CloudProviderResult<String> {
154 let filters = [
156 FilterBuilder::default()
158 .name("name")
159 .values(Self::UBUNTU_NAME_PATTERN)
160 .build(),
161 FilterBuilder::default()
163 .name("owner-id")
164 .values(Self::CANONICAL_OWNER_ID)
165 .build(),
166 FilterBuilder::default()
168 .name("state")
169 .values("available")
170 .build(),
171 ];
172
173 let request = client.describe_images().set_filters(Some(filters.to_vec()));
175 let response = request.send().await?;
176
177 let mut images = response.images().to_vec();
179 images.sort_by(|a, b| {
180 let a_date = a.creation_date().unwrap_or("");
181 let b_date = b.creation_date().unwrap_or("");
182 b_date.cmp(a_date) });
184
185 let image = images
187 .first()
188 .ok_or_else(|| CloudProviderError::Request("Cannot find Ubuntu 24.04 image".into()))?;
189
190 image
191 .image_id
192 .clone()
193 .ok_or_else(|| CloudProviderError::UnexpectedResponse("Image without ID".into()))
194 }
195
196 async fn create_security_group(&self, client: &aws_sdk_ec2::Client) -> CloudProviderResult<()> {
199 let request = client
201 .create_security_group()
202 .group_name(&self.settings.testbed_id)
203 .description("Allow all traffic (used for benchmarks).");
204
205 let response = request.send().await;
206 Self::check_but_ignore_duplicates(response)?;
207
208 for protocol in ["tcp", "udp", "icmp", "icmpv6"] {
210 let mut request = client
211 .authorize_security_group_ingress()
212 .group_name(&self.settings.testbed_id)
213 .ip_protocol(protocol)
214 .cidr_ip("0.0.0.0/0"); if protocol == "icmp" || protocol == "icmpv6" {
216 request = request.from_port(-1).to_port(-1);
217 } else {
218 request = request.from_port(0).to_port(65535);
219 }
220
221 let response = request.send().await;
222 Self::check_but_ignore_duplicates(response)?;
223 }
224 Ok(())
225 }
226
227 fn nvme_mount_command(&self) -> Vec<String> {
229 const DRIVE: &str = "nvme1n1";
230 let directory = self.settings.working_dir.display();
231 vec![
232 format!("(sudo mkfs.ext4 -E nodiscard /dev/{DRIVE} || true)"),
233 format!("(sudo mount /dev/{DRIVE} {directory} || true)"),
234 format!("sudo chmod 777 -R {directory}"),
235 ]
236 }
237
238 async fn check_nvme_support(&self) -> CloudProviderResult<bool> {
241 let client = match self
244 .settings
245 .regions
246 .first()
247 .and_then(|x| self.clients.get(x))
248 {
249 Some(client) => client,
250 None => return Ok(false),
251 };
252
253 let request = client
255 .describe_instance_types()
256 .instance_types(self.settings.node_specs.as_str().into());
257
258 let response = request.send().await?;
260
261 if let Some(info) = response.instance_types().first() {
263 if let Some(info) = info.instance_storage_info() {
264 if info.nvme_support() == Some(&EphemeralNvmeSupport::Required) {
265 return Ok(true);
266 }
267 }
268 }
269 Ok(false)
270 }
271}
272
273#[async_trait::async_trait]
274impl ServerProviderClient for AwsClient {
275 const USERNAME: &'static str = "ubuntu";
276
277 async fn list_instances_by_role(
278 &self,
279 role: InstanceRole,
280 ) -> CloudProviderResult<Vec<Instance>> {
281 let filter_name = Filter::builder()
282 .name("tag:Name")
283 .values(self.settings.testbed_id.clone())
284 .build();
285 let filter_role = Filter::builder()
286 .name("tag:Role")
287 .values(role.to_string())
288 .build();
289 let filter_state = Filter::builder()
290 .name("instance-state-name")
291 .values("pending")
292 .values("running")
293 .values("stopping")
294 .values("stopped")
295 .build();
296
297 let mut instances = Vec::new();
298 for (region, client) in &self.clients {
299 let request = client.describe_instances().set_filters(Some(vec![
300 filter_name.clone(),
301 filter_role.clone(),
302 filter_state.clone(),
303 ]));
304 let response = request.send().await?;
305 for reservation in response.reservations() {
306 for instance in reservation.instances() {
307 instances.push(self.make_instance(region.clone(), instance));
308 }
309 }
310 }
311
312 Ok(instances)
313 }
314
315 async fn list_instances_by_region_and_ids(
316 &self,
317 regions_and_ids: Vec<(String, String)>,
318 ) -> CloudProviderResult<Vec<Instance>> {
319 let mut instances = Vec::new();
320 let ids_by_region: HashMap<String, Vec<String>> =
321 regions_and_ids
322 .into_iter()
323 .fold(HashMap::new(), |mut acc, (k, v)| {
324 acc.entry(k).or_default().push(v);
325 acc
326 });
327 for (region, client) in &self.clients {
328 let request = client
329 .describe_instances()
330 .set_instance_ids(ids_by_region.get(region).cloned());
331 let response = request.send().await?;
332 for reservation in response.reservations() {
333 for instance in reservation.instances() {
334 instances.push(self.make_instance(region.clone(), instance));
335 }
336 }
337 }
338
339 Ok(instances)
340 }
341
342 async fn start_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
343 where
344 I: Iterator<Item = &'a Instance> + Send,
345 {
346 let mut instance_ids = HashMap::new();
347 for instance in instances {
348 instance_ids
349 .entry(&instance.region)
350 .or_insert_with(Vec::new)
351 .push(instance.id.clone());
352 }
353
354 for (region, client) in &self.clients {
355 let ids = instance_ids.remove(®ion.to_string());
356 if ids.is_some() {
357 client
358 .start_instances()
359 .set_instance_ids(ids)
360 .send()
361 .await?;
362 }
363 }
364 Ok(())
365 }
366
367 async fn stop_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
368 where
369 I: Iterator<Item = &'a Instance> + Send,
370 {
371 let mut instance_ids = HashMap::new();
372 for instance in instances {
373 instance_ids
374 .entry(&instance.region)
375 .or_insert_with(Vec::new)
376 .push(instance.id.clone());
377 }
378
379 for (region, client) in &self.clients {
380 let ids = instance_ids.remove(®ion.to_string());
381 if ids.is_some() {
382 client.stop_instances().set_instance_ids(ids).send().await?;
383 }
384 }
385 Ok(())
386 }
387
388 async fn create_instance<S>(
389 &self,
390 region: S,
391 role: InstanceRole,
392 ) -> CloudProviderResult<Instance>
393 where
394 S: Into<String> + Serialize + Send,
395 {
396 let region = region.into();
397 let testbed_id = &self.settings.testbed_id;
398
399 let client = self
400 .clients
401 .get(®ion)
402 .ok_or_else(|| CloudProviderError::Request(format!("Undefined region {region:?}")))?;
403
404 self.create_security_group(client).await?;
406
407 let image_id = self.find_image_id(client).await?;
409
410 let tags = TagSpecification::builder()
412 .resource_type(ResourceType::Instance)
413 .tags(Tag::builder().key("Name").value(testbed_id).build())
414 .tags(Tag::builder().key("Role").value(role.to_string()).build())
415 .build();
416
417 let storage = BlockDeviceMapping::builder()
418 .device_name("/dev/sda1")
419 .ebs(
420 EbsBlockDevice::builder()
421 .delete_on_termination(true)
422 .volume_size(500)
423 .volume_type(VolumeType::Gp2)
424 .build(),
425 )
426 .build();
427 let instance_type = match role {
428 InstanceRole::Node => &self.settings.node_specs,
429 InstanceRole::Metrics => &self.settings.metrics_specs,
430 InstanceRole::Client => &self.settings.client_specs,
431 };
432 let request = client
433 .run_instances()
434 .image_id(image_id)
435 .instance_type(instance_type.as_str().into())
436 .key_name(testbed_id)
437 .min_count(1)
438 .max_count(1)
439 .security_groups(&self.settings.testbed_id)
440 .block_device_mappings(storage)
441 .tag_specifications(tags);
442
443 let response = request.send().await?;
444 let instance = &response
445 .instances()
446 .first()
447 .expect("AWS instances list should contain instances");
448
449 Ok(self.make_instance(region, instance))
450 }
451
452 async fn delete_instance(&self, instance: Instance) -> CloudProviderResult<()> {
453 let client = self.clients.get(&instance.region).ok_or_else(|| {
454 CloudProviderError::Request(format!("Undefined region {:?}", instance.region))
455 })?;
456
457 client
458 .terminate_instances()
459 .set_instance_ids(Some(vec![instance.id.clone()]))
460 .send()
461 .await?;
462
463 Ok(())
464 }
465
466 async fn register_ssh_public_key(&self, public_key: String) -> CloudProviderResult<()> {
467 for client in self.clients.values() {
468 let request = client
469 .import_key_pair()
470 .key_name(&self.settings.testbed_id)
471 .public_key_material(Blob::new::<String>(public_key.clone()));
472
473 let response = request.send().await;
474 Self::check_but_ignore_duplicates(response)?;
475 }
476 Ok(())
477 }
478
479 async fn instance_setup_commands(&self) -> CloudProviderResult<Vec<String>> {
480 if self.check_nvme_support().await? {
481 Ok(self.nvme_mount_command())
482 } else {
483 Ok(Vec::new())
484 }
485 }
486 #[cfg(test)]
487 fn instances(&self) -> Vec<Instance> {
488 unreachable!()
491 }
492}