iota_aws_orchestrator/
settings.rs1use std::{
6 env,
7 fmt::Display,
8 fs::{self},
9 path::{Path, PathBuf},
10};
11
12use reqwest::Url;
13use serde::{Deserialize, Deserializer, de::Error};
14
15use crate::{
16 client::Instance,
17 error::{SettingsError, SettingsResult},
18};
19
20#[derive(Deserialize, Clone)]
22pub struct Repository {
23 #[serde(deserialize_with = "parse_url")]
25 pub url: Url,
26 pub commit: String,
28}
29
30fn parse_url<'de, D>(deserializer: D) -> Result<Url, D::Error>
31where
32 D: Deserializer<'de>,
33{
34 let s: &str = Deserialize::deserialize(deserializer)?;
35 let url = Url::parse(s).map_err(D::Error::custom)?;
36
37 match url.path_segments().map(|x| x.count() >= 2) {
38 None | Some(false) => Err(D::Error::custom(SettingsError::MalformedRepositoryUrl(url))),
39 _ => Ok(url),
40 }
41}
42
43#[derive(Deserialize, Clone)]
45pub enum CloudProvider {
46 #[serde(alias = "aws")]
47 Aws,
48}
49
50#[derive(Deserialize, Clone)]
52pub struct Settings {
53 pub testbed_id: String,
57 pub cloud_provider: CloudProvider,
59 pub token_file: PathBuf,
61 pub ssh_private_key_file: PathBuf,
63 pub ssh_public_key_file: Option<PathBuf>,
67 pub regions: Vec<String>,
69 pub specs: String,
73 pub repository: Repository,
75 #[serde(default = "default_working_dir")]
78 pub working_dir: PathBuf,
79 #[serde(default = "default_results_dir")]
82 pub results_dir: PathBuf,
83 #[serde(default = "default_logs_dir")]
86 pub logs_dir: PathBuf,
87}
88
/// Serde default for `Settings::working_dir`: `~/working_dir`.
fn default_working_dir() -> PathBuf {
    Path::new("~/").join("working_dir")
}
92
/// Serde default for `Settings::results_dir`: `./results`.
fn default_results_dir() -> PathBuf {
    Path::new("./").join("results")
}
96
/// Serde default for `Settings::logs_dir`: `./logs`.
fn default_logs_dir() -> PathBuf {
    Path::new("./").join("logs")
}
100
101impl Settings {
102 pub fn load<P>(path: P) -> SettingsResult<Self>
104 where
105 P: AsRef<Path> + Display + Clone,
106 {
107 let reader = || -> Result<Self, std::io::Error> {
108 let data = fs::read(path.clone())?;
109 let data = resolve_env(std::str::from_utf8(&data).unwrap());
110 let settings: Settings = serde_json::from_slice(data.as_bytes())?;
111
112 fs::create_dir_all(&settings.results_dir)?;
113 fs::create_dir_all(&settings.logs_dir)?;
114
115 Ok(settings)
116 };
117
118 reader().map_err(|e| SettingsError::InvalidSettings {
119 file: path.to_string(),
120 message: e.to_string(),
121 })
122 }
123
124 pub fn repository_name(&self) -> String {
126 self.repository
127 .url
128 .path_segments()
129 .expect("Url should already be checked when loading settings")
130 .collect::<Vec<_>>()[1]
131 .split('.')
132 .next()
133 .unwrap()
134 .to_string()
135 }
136
137 pub fn load_token(&self) -> SettingsResult<String> {
139 match fs::read_to_string(&self.token_file) {
140 Ok(token) => Ok(token.trim_end_matches('\n').to_string()),
141 Err(e) => Err(SettingsError::InvalidTokenFile {
142 file: self.token_file.display().to_string(),
143 message: e.to_string(),
144 }),
145 }
146 }
147
148 pub fn load_ssh_public_key(&self) -> SettingsResult<String> {
150 let ssh_public_key_file = self.ssh_public_key_file.clone().unwrap_or_else(|| {
151 let mut private = self.ssh_private_key_file.clone();
152 private.set_extension("pub");
153 private
154 });
155 match fs::read_to_string(&ssh_public_key_file) {
156 Ok(token) => Ok(token.trim_end_matches('\n').to_string()),
157 Err(e) => Err(SettingsError::InvalidSshPublicKeyFile {
158 file: ssh_public_key_file.display().to_string(),
159 message: e.to_string(),
160 }),
161 }
162 }
163
164 pub fn filter_instances(&self, instance: &Instance) -> bool {
167 self.regions.contains(&instance.region)
168 && instance.specs.to_lowercase().replace('.', "")
169 == self.specs.to_lowercase().replace('.', "")
170 }
171
172 #[cfg(test)]
174 pub fn number_of_regions(&self) -> usize {
175 self.regions.len()
176 }
177
178 #[cfg(test)]
180 pub fn new_for_test() -> Self {
181 let mut path = tempfile::tempdir().unwrap().into_path();
183 path.push("test_public_key.pub");
184 let public_key = "This is a fake public key for tests";
185 fs::write(&path, public_key).unwrap();
186
187 Self {
189 testbed_id: "testbed".into(),
190 cloud_provider: CloudProvider::Aws,
191 token_file: "/path/to/token/file".into(),
192 ssh_private_key_file: "/path/to/private/key/file".into(),
193 ssh_public_key_file: Some(path),
194 regions: vec!["London".into(), "New York".into()],
195 specs: "small".into(),
196 repository: Repository {
197 url: Url::parse("https://example.net/author/repo").unwrap(),
198 commit: "main".into(),
199 },
200 working_dir: "/path/to/working_dir".into(),
201 results_dir: "results".into(),
202 logs_dir: "logs".into(),
203 }
204 }
205}
206
/// Replaces every `${NAME}` placeholder in `s` with the value of the environment variable
/// `NAME` and returns the resulting text.
///
/// The input is scanned once from left to right and only the variables that are actually
/// referenced are looked up. Substituted values are inserted verbatim and are never scanned
/// for further placeholders, so the result does not depend on the order in which the
/// environment is enumerated, and unrelated variables (including ones that are not valid
/// Unicode) cannot affect it.
///
/// # Panics
///
/// Panics when a placeholder is unterminated or names a variable that is unset or not valid
/// Unicode. The raw input is printed to stderr first; the substituted text is deliberately
/// not printed so that environment values are not echoed.
fn resolve_env(s: &str) -> String {
    let mut resolved = String::with_capacity(s.len());
    let mut rest = s;

    while let Some(start) = rest.find("${") {
        resolved.push_str(&rest[..start]);
        // Everything after the `${` marker: `NAME}...`.
        let placeholder = &rest[start + 2..];

        let lookup = placeholder.find('}').and_then(|end| {
            env::var(&placeholder[..end])
                .ok()
                .map(|value| (end, value))
        });

        match lookup {
            Some((end, value)) => {
                resolved.push_str(&value);
                // Resume right after the closing `}`.
                rest = &placeholder[end + 1..];
            }
            None => {
                eprintln!("settings.json:\n{}\n", s);
                panic!("Unresolved env variables in the settings.json");
            }
        }
    }

    resolved.push_str(rest);
    resolved
}
219
#[cfg(test)]
mod test {
    use reqwest::Url;

    use crate::settings::Settings;

    /// The repository name is taken from the second path segment of the repository URL.
    #[test]
    fn repository_name() {
        let url = Url::parse("https://example.com/author/name").unwrap();

        let settings = {
            let mut fixture = Settings::new_for_test();
            fixture.repository.url = url;
            fixture
        };

        assert_eq!(settings.repository_name(), "name");
    }
}