use std::{
    collections::HashMap,
    fmt::{Debug, Display},
};

use aws_runtime::env_config::file::{EnvConfigFileKind, EnvConfigFiles};
use aws_sdk_ec2::{
    config::Region,
    primitives::Blob,
    types::{
        BlockDeviceMapping, EbsBlockDevice, EphemeralNvmeSupport, Filter, ResourceType, Tag,
        TagSpecification, VolumeType,
    },
};
use aws_smithy_runtime_api::client::{behavior_version::BehaviorVersion, result::SdkError};
use serde::Serialize;

use super::{Instance, ServerProviderClient};
use crate::{
    error::{CloudProviderError, CloudProviderResult},
    settings::Settings,
};
27
28impl<T> From<SdkError<T, aws_smithy_runtime_api::client::orchestrator::HttpResponse>>
30 for CloudProviderError
31where
32 T: Debug + std::error::Error + Send + Sync + 'static,
33{
34 fn from(e: SdkError<T, aws_smithy_runtime_api::client::orchestrator::HttpResponse>) -> Self {
35 Self::Request(format!("{:?}", e.into_source()))
36 }
37}
38
39pub struct AwsClient {
41 settings: Settings,
43 clients: HashMap<String, aws_sdk_ec2::Client>,
45}
46
47impl Display for AwsClient {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 write!(f, "AWS EC2 client v{}", aws_sdk_ec2::meta::PKG_VERSION)
50 }
51}
52
53impl AwsClient {
54 const OS_IMAGE: &'static str =
55 "Canonical, Ubuntu, 22.04 LTS, amd64 jammy image build on 2023-02-16";
56
57 pub async fn new(settings: Settings) -> Self {
59 let profile_files = EnvConfigFiles::builder()
60 .with_file(EnvConfigFileKind::Credentials, &settings.token_file)
61 .with_contents(EnvConfigFileKind::Config, "[default]\noutput=json")
62 .build();
63
64 let mut clients = HashMap::new();
65 for region in settings.regions.clone() {
66 let sdk_config = aws_config::defaults(BehaviorVersion::latest())
67 .region(Region::new(region.clone()))
68 .profile_files(profile_files.clone())
69 .load()
70 .await;
71 let client = aws_sdk_ec2::Client::new(&sdk_config);
72 clients.insert(region, client);
73 }
74
75 Self { settings, clients }
76 }
77
78 fn check_but_ignore_duplicates<T, E>(
81 response: Result<
82 T,
83 SdkError<E, aws_smithy_runtime_api::client::orchestrator::HttpResponse>,
84 >,
85 ) -> CloudProviderResult<()>
86 where
87 E: Debug + std::error::Error + Send + Sync + 'static,
88 {
89 if let Err(e) = response {
90 let error_message = format!("{e:?}");
91 if !error_message.to_lowercase().contains("duplicate") {
92 return Err(e.into());
93 }
94 }
95 Ok(())
96 }
97
98 fn make_instance(
101 &self,
102 region: String,
103 aws_instance: &aws_sdk_ec2::types::Instance,
104 ) -> Instance {
105 Instance {
106 id: aws_instance
107 .instance_id()
108 .expect("AWS instance should have an id")
109 .into(),
110 region,
111 main_ip: aws_instance
112 .public_ip_address()
113 .unwrap_or("0.0.0.0") .parse()
115 .expect("AWS instance should have a valid ip"),
116 tags: vec![self.settings.testbed_id.clone()],
117 specs: format!(
118 "{:?}",
119 aws_instance
120 .instance_type()
121 .expect("AWS instance should have a type")
122 ),
123 status: format!(
124 "{:?}",
125 aws_instance
126 .state()
127 .expect("AWS instance should have a state")
128 .name()
129 .expect("AWS status should have a name")
130 ),
131 }
132 }
133
134 async fn find_image_id(&self, client: &aws_sdk_ec2::Client) -> CloudProviderResult<String> {
137 let request = client.describe_images().filters(
139 Filter::builder()
140 .name("description")
141 .values(Self::OS_IMAGE)
142 .build(),
143 );
144 let response = request.send().await?;
145
146 response
148 .images()
149 .first()
150 .ok_or_else(|| CloudProviderError::Request("Cannot find image id".into()))?
151 .image_id
152 .clone()
153 .ok_or_else(|| {
154 CloudProviderError::UnexpectedResponse(
155 "Received image description without id".into(),
156 )
157 })
158 }
159
160 async fn create_security_group(&self, client: &aws_sdk_ec2::Client) -> CloudProviderResult<()> {
163 let request = client
165 .create_security_group()
166 .group_name(&self.settings.testbed_id)
167 .description("Allow all traffic (used for benchmarks).");
168
169 let response = request.send().await;
170 Self::check_but_ignore_duplicates(response)?;
171
172 for protocol in ["tcp", "udp", "icmp", "icmpv6"] {
174 let mut request = client
175 .authorize_security_group_ingress()
176 .group_name(&self.settings.testbed_id)
177 .ip_protocol(protocol)
178 .cidr_ip("0.0.0.0/0"); if protocol == "icmp" || protocol == "icmpv6" {
180 request = request.from_port(-1).to_port(-1);
181 } else {
182 request = request.from_port(0).to_port(65535);
183 }
184
185 let response = request.send().await;
186 Self::check_but_ignore_duplicates(response)?;
187 }
188 Ok(())
189 }
190
191 fn nvme_mount_command(&self) -> Vec<String> {
193 const DRIVE: &str = "nvme1n1";
194 let directory = self.settings.working_dir.display();
195 vec![
196 format!("(sudo mkfs.ext4 -E nodiscard /dev/{DRIVE} || true)"),
197 format!("(sudo mount /dev/{DRIVE} {directory} || true)"),
198 format!("sudo chmod 777 -R {directory}"),
199 ]
200 }
201
202 async fn check_nvme_support(&self) -> CloudProviderResult<bool> {
205 let client = match self
208 .settings
209 .regions
210 .first()
211 .and_then(|x| self.clients.get(x))
212 {
213 Some(client) => client,
214 None => return Ok(false),
215 };
216
217 let request = client
219 .describe_instance_types()
220 .instance_types(self.settings.specs.as_str().into());
221
222 let response = request.send().await?;
224
225 if let Some(info) = response.instance_types().first() {
227 if let Some(info) = info.instance_storage_info() {
228 if info.nvme_support() == Some(&EphemeralNvmeSupport::Required) {
229 return Ok(true);
230 }
231 }
232 }
233 Ok(false)
234 }
235}
236
237#[async_trait::async_trait]
238impl ServerProviderClient for AwsClient {
239 const USERNAME: &'static str = "ubuntu";
240
241 async fn list_instances(&self) -> CloudProviderResult<Vec<Instance>> {
242 let filter = Filter::builder()
243 .name("tag:Name")
244 .values(self.settings.testbed_id.clone())
245 .build();
246
247 let mut instances = Vec::new();
248 for (region, client) in &self.clients {
249 let request = client.describe_instances().filters(filter.clone());
250 for reservation in request.send().await?.reservations() {
251 for instance in reservation.instances() {
252 instances.push(self.make_instance(region.clone(), instance));
253 }
254 }
255 }
256
257 Ok(instances)
258 }
259
260 async fn start_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
261 where
262 I: Iterator<Item = &'a Instance> + Send,
263 {
264 let mut instance_ids = HashMap::new();
265 for instance in instances {
266 instance_ids
267 .entry(&instance.region)
268 .or_insert_with(Vec::new)
269 .push(instance.id.clone());
270 }
271
272 for (region, client) in &self.clients {
273 let ids = instance_ids.remove(®ion.to_string());
274 if ids.is_some() {
275 client
276 .start_instances()
277 .set_instance_ids(ids)
278 .send()
279 .await?;
280 }
281 }
282 Ok(())
283 }
284
285 async fn stop_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
286 where
287 I: Iterator<Item = &'a Instance> + Send,
288 {
289 let mut instance_ids = HashMap::new();
290 for instance in instances {
291 instance_ids
292 .entry(&instance.region)
293 .or_insert_with(Vec::new)
294 .push(instance.id.clone());
295 }
296
297 for (region, client) in &self.clients {
298 let ids = instance_ids.remove(®ion.to_string());
299 if ids.is_some() {
300 client.stop_instances().set_instance_ids(ids).send().await?;
301 }
302 }
303 Ok(())
304 }
305
306 async fn create_instance<S>(&self, region: S) -> CloudProviderResult<Instance>
307 where
308 S: Into<String> + Serialize + Send,
309 {
310 let region = region.into();
311 let testbed_id = &self.settings.testbed_id;
312
313 let client = self
314 .clients
315 .get(®ion)
316 .ok_or_else(|| CloudProviderError::Request(format!("Undefined region {region:?}")))?;
317
318 self.create_security_group(client).await?;
320
321 let image_id = self.find_image_id(client).await?;
323
324 let tags = TagSpecification::builder()
326 .resource_type(ResourceType::Instance)
327 .tags(Tag::builder().key("Name").value(testbed_id).build())
328 .build();
329
330 let storage = BlockDeviceMapping::builder()
331 .device_name("/dev/sda1")
332 .ebs(
333 EbsBlockDevice::builder()
334 .delete_on_termination(true)
335 .volume_size(500)
336 .volume_type(VolumeType::Gp2)
337 .build(),
338 )
339 .build();
340
341 let request = client
342 .run_instances()
343 .image_id(image_id)
344 .instance_type(self.settings.specs.as_str().into())
345 .key_name(testbed_id)
346 .min_count(1)
347 .max_count(1)
348 .security_groups(&self.settings.testbed_id)
349 .block_device_mappings(storage)
350 .tag_specifications(tags);
351
352 let response = request.send().await?;
353 let instance = &response
354 .instances()
355 .first()
356 .expect("AWS instances list should contain instances");
357
358 Ok(self.make_instance(region, instance))
359 }
360
361 async fn delete_instance(&self, instance: Instance) -> CloudProviderResult<()> {
362 let client = self.clients.get(&instance.region).ok_or_else(|| {
363 CloudProviderError::Request(format!("Undefined region {:?}", instance.region))
364 })?;
365
366 client
367 .terminate_instances()
368 .set_instance_ids(Some(vec![instance.id.clone()]))
369 .send()
370 .await?;
371
372 Ok(())
373 }
374
375 async fn register_ssh_public_key(&self, public_key: String) -> CloudProviderResult<()> {
376 for client in self.clients.values() {
377 let request = client
378 .import_key_pair()
379 .key_name(&self.settings.testbed_id)
380 .public_key_material(Blob::new::<String>(public_key.clone()));
381
382 let response = request.send().await;
383 Self::check_but_ignore_duplicates(response)?;
384 }
385 Ok(())
386 }
387
388 async fn instance_setup_commands(&self) -> CloudProviderResult<Vec<String>> {
389 if self.check_nvme_support().await? {
390 Ok(self.nvme_mount_command())
391 } else {
392 Ok(Vec::new())
393 }
394 }
395}