1use std::{
6 io::Write,
7 net::SocketAddr,
8 path::{Path, PathBuf},
9 sync::Arc,
10 time::Duration,
11};
12
13use async_trait::async_trait;
14use futures::future::try_join_all;
15use russh::{Channel, client, client::Msg};
16use russh_keys::key;
17use tokio::{task::JoinHandle, time::sleep};
18
19use crate::{
20 client::Instance,
21 ensure,
22 error::{SshError, SshResult},
23};
24
/// Whether a backgrounded command (identified by its tmux session name) is alive on a host.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommandStatus {
    /// The command identifier occurs in the inspected text.
    Running,
    /// The command identifier does not occur in the inspected text.
    Terminated,
}

impl CommandStatus {
    /// Classifies `command_id` against `text` (for example the stdout of `tmux ls`).
    ///
    /// Returns [`CommandStatus::Running`] if `command_id` occurs anywhere in `text`, and
    /// [`CommandStatus::Terminated`] otherwise.
    ///
    /// NOTE(review): this is a plain substring test, so an identifier that is contained in
    /// another session name (e.g. `node1` inside `node10`) is reported as running.
    pub fn status(command_id: &str, text: &str) -> Self {
        if text.contains(command_id) {
            Self::Running
        } else {
            Self::Terminated
        }
    }
}
43
/// Options describing how a shell command is wrapped before it is sent over SSH.
///
/// Build one with [`CommandContext::new`] and the builder methods, then turn a base command
/// into the final command line with [`CommandContext::apply`].
#[derive(Clone, Debug, Default)]
pub struct CommandContext {
    /// tmux session name; when set, the command runs detached in a session with this name.
    pub background: Option<String>,
    /// Directory to `cd` into before running the command.
    pub path: Option<PathBuf>,
    /// File that receives the command's combined stdout/stderr through `|& tee`.
    pub log_file: Option<PathBuf>,
    /// How many times a failed execution is retried by
    /// `SshConnectionManager::execute_per_instance`.
    pub retries: usize,
}

impl CommandContext {
    /// Creates a context that runs the command unchanged: foreground, no `cd`, no log file,
    /// no retries. Identical to [`CommandContext::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Runs the command detached in a tmux session named `id`.
    #[must_use]
    pub fn run_background(mut self, id: String) -> Self {
        self.background = Some(id);
        self
    }

    /// Runs the command from inside the directory `path`.
    #[must_use]
    pub fn with_execute_from_path(mut self, path: PathBuf) -> Self {
        self.path = Some(path);
        self
    }

    /// Copies the command's stdout and stderr into the file at `path`.
    #[must_use]
    pub fn with_log_file(mut self, path: PathBuf) -> Self {
        self.log_file = Some(path);
        self
    }

    /// Sets how many times a failed execution is retried.
    #[must_use]
    pub fn with_retries(mut self, retries: usize) -> Self {
        self.retries = retries;
        self
    }

    /// Wraps `base_command` according to this context.
    ///
    /// With every option set the result has the shape
    /// `(cd PATH && tmux new -d -s "ID" "CMD |& tee LOG")`: the log redirection lives inside
    /// the tmux session, while the `cd` happens in the outer shell before tmux starts.
    /// Options that are unset are simply omitted.
    ///
    /// NOTE(review): neither the paths nor the command are shell-escaped; a command containing
    /// `"` breaks the tmux quoting, and paths with spaces break `cd`/`tee`.
    pub fn apply<S: Into<String>>(&self, base_command: S) -> String {
        let mut command = base_command.into();
        if let Some(log_file) = &self.log_file {
            command = format!("{command} |& tee {}", log_file.display());
        }
        if let Some(id) = &self.background {
            command = format!("tmux new -d -s \"{id}\" \"{command}\"");
        }
        if let Some(exec_path) = &self.path {
            command = format!("(cd {} && {command})", exec_path.display());
        }
        command
    }
}
108
/// Opens SSH connections to instances and runs commands on them concurrently.
///
/// A copy of the manager is moved into every spawned per-instance task, hence `Clone`.
#[derive(Clone, Debug)]
pub struct SshConnectionManager {
    /// Remote login name used for public-key authentication.
    username: String,
    /// Path of the private key loaded for every new connection.
    private_key_file: PathBuf,
    /// Inactivity timeout for new connections; `None` defers to `SshConnection`'s default.
    timeout: Option<Duration>,
    /// Extra connection attempts after a failure; also forwarded to each `SshConnection`.
    retries: usize,
}
120
121impl SshConnectionManager {
122 const RETRY_DELAY: Duration = Duration::from_secs(5);
124
125 pub fn new(username: String, private_key_file: PathBuf) -> Self {
127 Self {
128 username,
129 private_key_file,
130 timeout: None,
131 retries: 0,
132 }
133 }
134
135 pub fn with_timeout(mut self, timeout: Duration) -> Self {
137 self.timeout = Some(timeout);
138 self
139 }
140
141 pub fn with_retries(mut self, retries: usize) -> Self {
144 self.retries = retries;
145 self
146 }
147
148 pub async fn connect(&self, address: SocketAddr) -> SshResult<SshConnection> {
150 let mut error = None;
151 for _ in 0..self.retries + 1 {
152 match SshConnection::new(
153 address,
154 &self.username,
155 self.private_key_file.clone(),
156 self.timeout,
157 Some(self.retries),
158 )
159 .await
160 {
161 Ok(x) => return Ok(x),
162 Err(e) => error = Some(e),
163 }
164 sleep(Self::RETRY_DELAY).await;
165 }
166 Err(error.unwrap())
167 }
168
169 pub async fn execute<I, S>(
171 &self,
172 instances: I,
173 command: S,
174 context: CommandContext,
175 ) -> SshResult<Vec<(String, String)>>
176 where
177 I: IntoIterator<Item = Instance>,
178 S: Into<String> + Clone + Send + 'static,
179 {
180 let targets = instances
181 .into_iter()
182 .map(|instance| (instance, command.clone()));
183 self.execute_per_instance(targets, context).await
184 }
185
186 pub async fn execute_per_instance<I, S>(
188 &self,
189 instances: I,
190 context: CommandContext,
191 ) -> SshResult<Vec<(String, String)>>
192 where
193 I: IntoIterator<Item = (Instance, S)>,
194 S: Into<String> + Send + 'static,
195 {
196 let handles = self.run_per_instance(instances, context).await;
197
198 try_join_all(handles)
199 .await
200 .unwrap()
201 .into_iter()
202 .collect::<SshResult<_>>()
203 }
204
205 async fn run_per_instance<I, S>(
206 &self,
207 instances: I,
208 context: CommandContext,
209 ) -> Vec<JoinHandle<SshResult<(String, String)>>>
210 where
211 I: IntoIterator<Item = (Instance, S)>,
212 S: Into<String> + Send + 'static,
213 {
214 instances
215 .into_iter()
216 .map(|(instance, command)| {
217 let ssh_manager = self.clone();
218 let context = context.clone();
219
220 tokio::spawn(async move {
221 let connection = ssh_manager.connect(instance.ssh_address()).await?;
222
223 let command_str = command.into();
224 let mut consecutive_errors = 0;
225 loop {
226 match connection.execute(context.apply(command_str.clone())).await {
227 Ok(output) => {
228 return Ok(output);
229 }
230 Err(err) => {
231 consecutive_errors += 1;
232 if consecutive_errors > context.retries {
233 return Err(err);
234 }
235 }
236 }
237
238 sleep(Self::RETRY_DELAY).await;
239 }
240 })
241 })
242 .collect::<Vec<_>>()
243 }
244
245 pub async fn wait_for_command<I>(
247 &self,
248 instances: I,
249 command_id: &str,
250 status: CommandStatus,
251 ) -> SshResult<()>
252 where
253 I: IntoIterator<Item = Instance> + Clone,
254 {
255 let mut consecutive_errors = 0;
256 loop {
257 sleep(Self::RETRY_DELAY).await;
258
259 match self
260 .execute(
261 instances.clone(),
262 "(tmux ls || true)",
263 CommandContext::default(),
264 )
265 .await
266 {
267 Ok(result) => {
268 consecutive_errors = 0;
269 if result
270 .iter()
271 .all(|(stdout, _)| CommandStatus::status(command_id, stdout) == status)
272 {
273 break;
274 }
275 }
276 Err(e) => {
277 consecutive_errors += 1;
278 if consecutive_errors >= 5 {
279 return Err(e);
280 }
281 }
282 }
283 }
284 Ok(())
285 }
286
287 pub async fn wait_for_success<I, S>(&self, instances: I)
288 where
289 I: IntoIterator<Item = (Instance, S)> + Clone,
290 S: Into<String> + Send + 'static + Clone,
291 {
292 match self
293 .execute_per_instance(instances.clone(), CommandContext::default().with_retries(5))
294 .await
295 {
296 Ok(_) => {}
297 Err(e) => {
298 panic!("Command execution failed on one or more instances: {e}");
300 }
301 }
302 }
303
304 pub async fn kill<I>(&self, instances: I, command_id: &str) -> SshResult<()>
306 where
307 I: IntoIterator<Item = Instance>,
308 {
309 let ssh_command = format!("(tmux kill-session -t {command_id} || true)");
310 let targets = instances.into_iter().map(|x| (x, ssh_command.clone()));
311 self.execute_per_instance(targets, CommandContext::default().with_retries(5))
312 .await?;
313 Ok(())
314 }
315}
316
/// Stateless russh client handler used for every [`SshConnection`].
struct Session {}

#[async_trait]
impl client::Handler for Session {
    type Error = russh::Error;

    /// Accepts every server host key without looking at it.
    ///
    /// NOTE(review): returning `Ok(true)` unconditionally disables host-key verification, so
    /// connections are exposed to man-in-the-middle attacks. Presumably tolerated because the
    /// target machines are freshly provisioned and their keys unknown in advance — confirm,
    /// or verify against a known-hosts list.
    async fn check_server_key(
        &mut self,
        _server_public_key: &key::PublicKey,
    ) -> Result<bool, Self::Error> {
        Ok(true)
    }
}
330
331pub struct SshConnection {
333 session: client::Handle<Session>,
335 address: SocketAddr,
337 retries: usize,
339}
340
341impl SshConnection {
342 const DEFAULT_TIMEOUT: Duration = Duration::from_secs(300);
344
345 pub async fn new<P: AsRef<Path>>(
347 address: SocketAddr,
348 username: &str,
349 private_key_file: P,
350 inactivity_timeout: Option<Duration>,
351 retries: Option<usize>,
352 ) -> SshResult<Self> {
353 let key = russh_keys::load_secret_key(private_key_file, None)
354 .map_err(|error| SshError::PrivateKeyError { address, error })?;
355
356 let config = client::Config {
357 inactivity_timeout: inactivity_timeout.or(Some(Self::DEFAULT_TIMEOUT)),
358 ..<_>::default()
359 };
360
361 let mut session = client::connect(Arc::new(config), address, Session {})
362 .await
363 .map_err(|error| SshError::ConnectionError { address, error })?;
364
365 let _auth_res = session
366 .authenticate_publickey(username, Arc::new(key))
367 .await
368 .map_err(|error| SshError::SessionError { address, error })?;
369
370 Ok(Self {
371 session,
372 address,
373 retries: retries.unwrap_or_default(),
374 })
375 }
376
377 fn make_session_error(&self, error: russh::Error) -> SshError {
379 SshError::SessionError {
380 address: self.address,
381 error,
382 }
383 }
384
385 pub async fn execute(&self, command: String) -> SshResult<(String, String)> {
387 let mut error = None;
388 for _ in 0..self.retries + 1 {
389 let channel = match self.session.channel_open_session().await {
390 Ok(x) => x,
391 Err(e) => {
392 error = Some(self.make_session_error(e));
393 continue;
394 }
395 };
396 match self.execute_impl(channel, command.clone()).await {
397 r @ Ok(..) => return r,
398 Err(e) => error = Some(e),
399 }
400 }
401 Err(error.unwrap())
402 }
403
404 async fn execute_impl(
407 &self,
408 mut channel: Channel<Msg>,
409 command: String,
410 ) -> SshResult<(String, String)> {
411 channel
412 .exec(true, command.clone())
413 .await
414 .map_err(|e| self.make_session_error(e))?;
415
416 let mut output = Vec::new();
417 let mut exit_code = None;
418
419 while let Some(msg) = channel.wait().await {
420 match msg {
421 russh::ChannelMsg::Data { ref data } => output.write_all(data).unwrap(),
422 russh::ChannelMsg::ExitStatus { exit_status } => exit_code = Some(exit_status),
423 _ => {}
424 }
425 }
426
427 channel
428 .close()
429 .await
430 .map_err(|error| self.make_session_error(error))?;
431
432 let output_str: String = String::from_utf8_lossy(&output).into();
433
434 ensure!(
435 exit_code.is_some() && exit_code.unwrap() == 0,
436 SshError::NonZeroExitCode {
437 address: self.address,
438 code: exit_code.unwrap(),
439 message: output_str,
440 command,
441 }
442 );
443
444 Ok((output_str.clone(), output_str))
445 }
446
447 pub async fn download<P: AsRef<Path>>(&self, path: P) -> SshResult<String> {
451 let mut error = None;
452 for _ in 0..self.retries + 1 {
453 match self
454 .execute(format!("cat {}", path.as_ref().to_str().unwrap()))
455 .await
456 {
457 Ok((file_data, _)) => return Ok(file_data),
458 Err(err) => error = Some(err),
459 }
460 }
461 Err(error.unwrap())
462 }
463}