iota_aws_orchestrator/
settings.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    env,
7    fmt::Display,
8    fs::{self},
9    path::{Path, PathBuf},
10};
11
12use reqwest::Url;
13use serde::{Deserialize, Deserializer, de::Error};
14
15use crate::{
16    client::Instance,
17    error::{SettingsError, SettingsResult},
18};
19
20/// The git repository holding the codebase.
21#[derive(Deserialize, Clone)]
22pub struct Repository {
23    /// The url of the repository.
24    #[serde(deserialize_with = "parse_url")]
25    pub url: Url,
26    /// The commit (or branch name) to deploy.
27    pub commit: String,
28}
29
30fn parse_url<'de, D>(deserializer: D) -> Result<Url, D::Error>
31where
32    D: Deserializer<'de>,
33{
34    let s: &str = Deserialize::deserialize(deserializer)?;
35    let url = Url::parse(s).map_err(D::Error::custom)?;
36
37    match url.path_segments().map(|x| x.count() >= 2) {
38        None | Some(false) => Err(D::Error::custom(SettingsError::MalformedRepositoryUrl(url))),
39        _ => Ok(url),
40    }
41}
42
43/// The list of supported cloud providers.
44#[derive(Deserialize, Clone)]
45pub enum CloudProvider {
46    #[serde(alias = "aws")]
47    Aws,
48}
49
50/// The testbed settings. Those are topically specified in a file.
51#[derive(Deserialize, Clone)]
52pub struct Settings {
53    /// The testbed unique id. This allows multiple users to run concurrent
54    /// testbeds on the same cloud provider's account without interference
55    /// with each others.
56    pub testbed_id: String,
57    /// The cloud provider hosting the testbed.
58    pub cloud_provider: CloudProvider,
59    /// The path to the secret token for authentication with the cloud provider.
60    pub token_file: PathBuf,
61    /// The ssh private key to access the instances.
62    pub ssh_private_key_file: PathBuf,
63    /// The corresponding ssh public key registered on the instances. If not
64    /// specified. the public key defaults the same path as the private key
65    /// with an added extension 'pub'.
66    pub ssh_public_key_file: Option<PathBuf>,
67    /// The list of cloud provider regions to deploy the testbed.
68    pub regions: Vec<String>,
69    /// The specs of the instances to deploy. Those are dependent on the cloud
70    /// provider, e.g., specifying 't3.medium' creates instances with 2 vCPU
71    /// and 4GBo of ram on AWS.
72    pub specs: String,
73    /// The details of the git reposit to deploy.
74    pub repository: Repository,
75    /// The working directory on the remote instance (containing all
76    /// configuration files).
77    #[serde(default = "default_working_dir")]
78    pub working_dir: PathBuf,
79    /// The directory (on the local machine) where to save benchmarks
80    /// measurements.
81    #[serde(default = "default_results_dir")]
82    pub results_dir: PathBuf,
83    /// The directory (on the local machine) where to download logs files from
84    /// the instances.
85    #[serde(default = "default_logs_dir")]
86    pub logs_dir: PathBuf,
87}
88
/// Default remote working directory: `~/working_dir`.
///
/// The `~` is not expanded here; the path is returned literally.
fn default_working_dir() -> PathBuf {
    PathBuf::from("~/").join("working_dir")
}
92
/// Default local directory for benchmark results: `./results`.
fn default_results_dir() -> PathBuf {
    Path::new("./").join("results")
}
96
/// Default local directory for downloaded logs: `./logs`.
fn default_logs_dir() -> PathBuf {
    let mut dir = PathBuf::from("./");
    dir.push("logs");
    dir
}
100
101impl Settings {
102    /// Load the settings from a json file.
103    pub fn load<P>(path: P) -> SettingsResult<Self>
104    where
105        P: AsRef<Path> + Display + Clone,
106    {
107        let reader = || -> Result<Self, std::io::Error> {
108            let data = fs::read(path.clone())?;
109            let data = resolve_env(std::str::from_utf8(&data).unwrap());
110            let settings: Settings = serde_json::from_slice(data.as_bytes())?;
111
112            fs::create_dir_all(&settings.results_dir)?;
113            fs::create_dir_all(&settings.logs_dir)?;
114
115            Ok(settings)
116        };
117
118        reader().map_err(|e| SettingsError::InvalidSettings {
119            file: path.to_string(),
120            message: e.to_string(),
121        })
122    }
123
124    /// Get the name of the repository (from its url).
125    pub fn repository_name(&self) -> String {
126        self.repository
127            .url
128            .path_segments()
129            .expect("Url should already be checked when loading settings")
130            .collect::<Vec<_>>()[1]
131            .split('.')
132            .next()
133            .unwrap()
134            .to_string()
135    }
136
137    /// Load the secret token to authenticate with the cloud provider.
138    pub fn load_token(&self) -> SettingsResult<String> {
139        match fs::read_to_string(&self.token_file) {
140            Ok(token) => Ok(token.trim_end_matches('\n').to_string()),
141            Err(e) => Err(SettingsError::InvalidTokenFile {
142                file: self.token_file.display().to_string(),
143                message: e.to_string(),
144            }),
145        }
146    }
147
148    /// Load the ssh public key from file.
149    pub fn load_ssh_public_key(&self) -> SettingsResult<String> {
150        let ssh_public_key_file = self.ssh_public_key_file.clone().unwrap_or_else(|| {
151            let mut private = self.ssh_private_key_file.clone();
152            private.set_extension("pub");
153            private
154        });
155        match fs::read_to_string(&ssh_public_key_file) {
156            Ok(token) => Ok(token.trim_end_matches('\n').to_string()),
157            Err(e) => Err(SettingsError::InvalidSshPublicKeyFile {
158                file: ssh_public_key_file.display().to_string(),
159                message: e.to_string(),
160            }),
161        }
162    }
163
164    /// Check whether the input instance matches the criteria described in the
165    /// settings.
166    pub fn filter_instances(&self, instance: &Instance) -> bool {
167        self.regions.contains(&instance.region)
168            && instance.specs.to_lowercase().replace('.', "")
169                == self.specs.to_lowercase().replace('.', "")
170    }
171
172    /// The number of regions specified in the settings.
173    #[cfg(test)]
174    pub fn number_of_regions(&self) -> usize {
175        self.regions.len()
176    }
177
178    /// Test settings for unit tests.
179    #[cfg(test)]
180    pub fn new_for_test() -> Self {
181        // Create a temporary public key file.
182        let mut path = tempfile::tempdir().unwrap().into_path();
183        path.push("test_public_key.pub");
184        let public_key = "This is a fake public key for tests";
185        fs::write(&path, public_key).unwrap();
186
187        // Return set settings.
188        Self {
189            testbed_id: "testbed".into(),
190            cloud_provider: CloudProvider::Aws,
191            token_file: "/path/to/token/file".into(),
192            ssh_private_key_file: "/path/to/private/key/file".into(),
193            ssh_public_key_file: Some(path),
194            regions: vec!["London".into(), "New York".into()],
195            specs: "small".into(),
196            repository: Repository {
197                url: Url::parse("https://example.net/author/repo").unwrap(),
198                commit: "main".into(),
199            },
200            working_dir: "/path/to/working_dir".into(),
201            results_dir: "results".into(),
202            logs_dir: "logs".into(),
203        }
204    }
205}
206
/// Resolve every `${NAME}` placeholder in `s` into the value of the environment
/// variable `NAME`.
///
/// The input is scanned once from left to right. Substituted values are copied
/// verbatim and never re-scanned, so a value that itself contains `${` is neither
/// expanded further nor reported as unresolved. Only the referenced variables
/// are looked up, so unrelated environment variables that are not valid Unicode
/// cannot cause a failure.
///
/// # Panics
///
/// Panics in either of these cases:
/// - a placeholder names a variable that is not set (or whose value is not valid
///   Unicode);
/// - a `${` is never closed by a `}`.
///
/// The panic message lists the offending placeholders, but never the values of
/// resolved variables (which may be secrets).
fn resolve_env(s: &str) -> String {
    // Names the platform cannot represent (empty, or containing `=` / NUL) are
    // treated as unset rather than handed to `env::var`.
    let lookup = |name: &str| -> Option<String> {
        if name.is_empty() || name.contains(['=', '\0']) {
            None
        } else {
            env::var(name).ok()
        }
    };

    let mut resolved = String::with_capacity(s.len());
    let mut unresolved: Vec<&str> = Vec::new();
    let mut rest = s;

    while let Some(start) = rest.find("${") {
        resolved.push_str(&rest[..start]);
        // `placeholder` starts with `${`, so any `}` is found at index >= 2.
        let placeholder = &rest[start..];
        let Some(end) = placeholder.find('}') else {
            // Unterminated placeholder: keep the text and report it.
            unresolved.push("${ (unterminated)");
            resolved.push_str(placeholder);
            rest = "";
            break;
        };
        let whole = &placeholder[..=end];
        match lookup(&placeholder[2..end]) {
            Some(value) => resolved.push_str(&value),
            None => {
                unresolved.push(whole);
                resolved.push_str(whole);
            }
        }
        rest = &placeholder[end + 1..];
    }
    resolved.push_str(rest);

    if !unresolved.is_empty() {
        panic!(
            "Unresolved env variables in the settings.json: {}",
            unresolved.join(", ")
        );
    }
    resolved
}
219
#[cfg(test)]
mod test {
    use reqwest::Url;

    use crate::settings::Settings;

    /// The repository name is taken from the second path segment of the url.
    #[test]
    fn repository_name() {
        let url = Url::parse("https://example.com/author/name").unwrap();
        let mut settings = Settings::new_for_test();
        settings.repository.url = url;

        let name = settings.repository_name();
        assert_eq!(name, "name");
    }
}
232}