1use std::{
6 io::Write,
7 net::SocketAddr,
8 path::{Path, PathBuf},
9 sync::Arc,
10 time::Duration,
11};
12
13use async_trait::async_trait;
14use futures::future::try_join_all;
15use russh::{Channel, client, client::Msg, keys::PrivateKeyWithHashAlg};
16use tokio::{task::JoinHandle, time::sleep};
17
18use crate::{
19 client::Instance,
20 ensure,
21 error::{SshError, SshResult},
22};
23
/// Liveness of a background command, as classified by [`CommandStatus::status`].
///
/// `Debug`, `Clone` and `Copy` are derived so the status can be logged,
/// compared with `assert_eq!`, and passed around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommandStatus {
    /// The command identifier was found in the inspected text.
    Running,
    /// The command identifier was not found in the inspected text.
    Terminated,
}
30
31impl CommandStatus {
32 pub fn status(command_id: &str, text: &str) -> Self {
35 if text.contains(command_id) {
36 Self::Running
37 } else {
38 Self::Terminated
39 }
40 }
41}
42
/// Options describing how a shell command should be wrapped before it is sent
/// to a remote host (see [`CommandContext::apply`]).
#[derive(Debug, Clone, Default)]
pub struct CommandContext {
    /// When set, the command is started detached in a tmux session with this
    /// name.
    pub background: Option<String>,
    /// When set, the command is executed after `cd`-ing into this directory.
    pub path: Option<PathBuf>,
    /// When set, the command's output is piped through `tee` into this file.
    pub log_file: Option<PathBuf>,
    /// Number of additional attempts allowed when executing the command fails.
    pub retries: usize,
}
56
57impl CommandContext {
58 pub fn new() -> Self {
60 Self {
61 background: None,
62 path: None,
63 log_file: None,
64 retries: 0,
65 }
66 }
67
68 pub fn run_background(mut self, id: String) -> Self {
70 self.background = Some(id);
71 self
72 }
73
74 pub fn with_execute_from_path(mut self, path: PathBuf) -> Self {
76 self.path = Some(path);
77 self
78 }
79
80 pub fn with_log_file(mut self, path: PathBuf) -> Self {
82 self.log_file = Some(path);
83 self
84 }
85
86 pub fn with_retries(mut self, retries: usize) -> Self {
88 self.retries = retries;
89 self
90 }
91
92 pub fn apply<S: Into<String>>(&self, base_command: S) -> String {
94 let mut str = base_command.into();
95 if let Some(log_file) = &self.log_file {
96 str = format!("{str} |& tee {}", log_file.as_path().display());
97 }
98 if let Some(id) = &self.background {
99 str = format!("tmux new -d -s \"{id}\" \"{str}\"");
100 }
101 if let Some(exec_path) = &self.path {
102 str = format!("(cd {} && {str})", exec_path.as_path().display());
103 }
104 str
105 }
106}
107
/// Factory for SSH connections that share the same credentials and retry
/// policy; it also fans commands out to many instances at once.
#[derive(Clone)]
pub struct SshConnectionManager {
    /// User name used to authenticate on the remote hosts.
    username: String,
    /// Location of the private key file used for public-key authentication.
    private_key_file: PathBuf,
    /// Inactivity timeout forwarded to every new connection; `None` lets the
    /// connection pick its own default.
    timeout: Option<Duration>,
    /// Number of additional connection attempts after a failure. The same
    /// value is forwarded to each connection as its own retry budget.
    retries: usize,
}
119
120impl SshConnectionManager {
121 const RETRY_DELAY: Duration = Duration::from_secs(5);
123
124 pub fn new(username: String, private_key_file: PathBuf) -> Self {
126 Self {
127 username,
128 private_key_file,
129 timeout: None,
130 retries: 0,
131 }
132 }
133
134 pub fn with_timeout(mut self, timeout: Duration) -> Self {
136 self.timeout = Some(timeout);
137 self
138 }
139
140 pub fn with_retries(mut self, retries: usize) -> Self {
143 self.retries = retries;
144 self
145 }
146
147 pub async fn connect(&self, address: SocketAddr) -> SshResult<SshConnection> {
149 let mut error = None;
150 for _ in 0..self.retries + 1 {
151 match SshConnection::new(
152 address,
153 &self.username,
154 self.private_key_file.clone(),
155 self.timeout,
156 Some(self.retries),
157 )
158 .await
159 {
160 Ok(x) => return Ok(x),
161 Err(e) => error = Some(e),
162 }
163 sleep(Self::RETRY_DELAY).await;
164 }
165 Err(error.unwrap())
166 }
167
168 pub async fn execute<I, S>(
170 &self,
171 instances: I,
172 command: S,
173 context: CommandContext,
174 ) -> SshResult<Vec<(String, String)>>
175 where
176 I: IntoIterator<Item = Instance>,
177 S: Into<String> + Clone + Send + 'static,
178 {
179 let targets = instances
180 .into_iter()
181 .map(|instance| (instance, command.clone()));
182 self.execute_per_instance(targets, context).await
183 }
184
185 pub async fn execute_per_instance<I, S>(
187 &self,
188 instances: I,
189 context: CommandContext,
190 ) -> SshResult<Vec<(String, String)>>
191 where
192 I: IntoIterator<Item = (Instance, S)>,
193 S: Into<String> + Send + 'static,
194 {
195 let handles = self.run_per_instance(instances, context).await;
196
197 try_join_all(handles)
198 .await
199 .unwrap()
200 .into_iter()
201 .collect::<SshResult<_>>()
202 }
203
204 pub async fn execute_per_instance_with_results<I, S>(
207 &self,
208 instances: I,
209 context: CommandContext,
210 ) -> Vec<SshResult<(String, String)>>
211 where
212 I: IntoIterator<Item = (Instance, S)>,
213 S: Into<String> + Send + 'static,
214 {
215 let handles = self.run_per_instance(instances, context).await;
216
217 try_join_all(handles)
218 .await
219 .unwrap()
220 .into_iter()
221 .collect::<Vec<_>>()
222 }
223
224 async fn run_per_instance<I, S>(
225 &self,
226 instances: I,
227 context: CommandContext,
228 ) -> Vec<JoinHandle<SshResult<(String, String)>>>
229 where
230 I: IntoIterator<Item = (Instance, S)>,
231 S: Into<String> + Send + 'static,
232 {
233 instances
234 .into_iter()
235 .map(|(instance, command)| {
236 let ssh_manager = self.clone();
237 let context = context.clone();
238
239 tokio::spawn(async move {
240 let connection = match ssh_manager.connect(instance.ssh_address()).await {
241 Ok(c) => c,
242 Err(e) => {
243 println!("Failed to connect to {}: error {e}", instance.ssh_address());
244 return Err(e);
245 }
246 };
247
248 let command_str = command.into();
249 let mut consecutive_errors = 0;
250 loop {
251 match connection.execute(context.apply(command_str.clone())).await {
252 Ok(output) => {
253 return Ok(output);
254 }
255 Err(err) => {
256 consecutive_errors += 1;
257
258 if consecutive_errors > context.retries {
259 println!(
260 "Failed to execute command {command_str} at {}: {err}",
261 instance.ssh_address()
262 );
263 return Err(err);
264 }
265 }
266 }
267
268 sleep(Self::RETRY_DELAY).await;
269 }
270 })
271 })
272 .collect::<Vec<_>>()
273 }
274
275 pub async fn wait_for_command<I>(
277 &self,
278 instances: I,
279 command_id: &str,
280 status: CommandStatus,
281 ) -> SshResult<()>
282 where
283 I: IntoIterator<Item = Instance> + Clone,
284 {
285 let mut consecutive_errors = 0;
286 loop {
287 sleep(Self::RETRY_DELAY).await;
288
289 match self
290 .execute(
291 instances.clone(),
292 "(tmux ls || true)",
293 CommandContext::default(),
294 )
295 .await
296 {
297 Ok(result) => {
298 consecutive_errors = 0;
299 if result
300 .iter()
301 .all(|(stdout, _)| CommandStatus::status(command_id, stdout) == status)
302 {
303 break;
304 }
305 }
306 Err(e) => {
307 consecutive_errors += 1;
308 if consecutive_errors >= 5 {
309 return Err(e);
310 }
311 }
312 }
313 }
314 Ok(())
315 }
316
317 pub async fn wait_for_success<I, S>(&self, instances: I)
318 where
319 I: IntoIterator<Item = (Instance, S)> + Clone,
320 S: Into<String> + Send + 'static + Clone,
321 {
322 match self
323 .execute_per_instance(instances.clone(), CommandContext::default().with_retries(5))
324 .await
325 {
326 Ok(_) => {}
327 Err(e) => {
328 panic!("Command execution failed on one or more instances: {e}");
330 }
331 }
332 }
333
334 pub async fn kill<I>(&self, instances: I, command_id: &str) -> SshResult<()>
336 where
337 I: IntoIterator<Item = Instance>,
338 {
339 let ssh_command = format!("(tmux kill-session -t {command_id} || true)");
340 let targets = instances.into_iter().map(|x| (x, ssh_command.clone()));
341 self.execute_per_instance(targets, CommandContext::default().with_retries(5))
342 .await?;
343 Ok(())
344 }
345}
346
/// Stateless `russh` client handler used for every outgoing connection.
struct Session {}
348
349#[async_trait]
350impl client::Handler for Session {
351 type Error = russh::Error;
352
353 #[allow(clippy::manual_async_fn)]
354 fn check_server_key(
355 &mut self,
356 _server_public_key: &russh::keys::PublicKey,
357 ) -> impl std::future::Future<Output = Result<bool, Self::Error>> + Send {
358 async { Ok(true) }
359 }
360}
361
362pub struct SshConnection {
364 session: client::Handle<Session>,
366 address: SocketAddr,
368 retries: usize,
370}
371
372impl SshConnection {
373 const DEFAULT_TIMEOUT: Duration = Duration::from_secs(300);
375
376 pub async fn new<P: AsRef<Path>>(
378 address: SocketAddr,
379 username: &str,
380 private_key_file: P,
381 inactivity_timeout: Option<Duration>,
382 retries: Option<usize>,
383 ) -> SshResult<Self> {
384 let key_bytes = std::fs::read(private_key_file).map_err(|e| SshError::PrivateKeyError {
385 address,
386 error: russh::keys::Error::IO(e),
387 })?;
388 let key = russh::keys::decode_secret_key(&String::from_utf8_lossy(&key_bytes), None)
389 .map_err(|_e| SshError::PrivateKeyError {
390 address,
391 error: russh::keys::Error::CouldNotReadKey,
392 })?;
393
394 let config = client::Config {
395 inactivity_timeout: inactivity_timeout.or(Some(Self::DEFAULT_TIMEOUT)),
396 ..<_>::default()
397 };
398
399 let mut session = client::connect(Arc::new(config), address, Session {})
400 .await
401 .map_err(|error| SshError::ConnectionError { address, error })?;
402
403 let key_with_hash = PrivateKeyWithHashAlg::new(Arc::new(key), None);
404 let _auth_res = session
405 .authenticate_publickey(username, key_with_hash)
406 .await
407 .map_err(|error| SshError::SessionError { address, error })?;
408
409 Ok(Self {
410 session,
411 address,
412 retries: retries.unwrap_or_default(),
413 })
414 }
415
416 fn make_session_error(&self, error: russh::Error) -> SshError {
418 SshError::SessionError {
419 address: self.address,
420 error,
421 }
422 }
423
424 pub async fn execute(&self, command: String) -> SshResult<(String, String)> {
426 let mut error = None;
427 for _ in 0..self.retries + 1 {
428 let channel = match self.session.channel_open_session().await {
429 Ok(x) => x,
430 Err(e) => {
431 error = Some(self.make_session_error(e));
432 continue;
433 }
434 };
435 match self.execute_impl(channel, command.clone()).await {
436 r @ Ok(..) => return r,
437 Err(e) => error = Some(e),
438 }
439 }
440 Err(error.unwrap())
441 }
442
443 async fn execute_impl(
446 &self,
447 mut channel: Channel<Msg>,
448 command: String,
449 ) -> SshResult<(String, String)> {
450 channel
451 .exec(true, command.clone())
452 .await
453 .map_err(|e| self.make_session_error(e))?;
454
455 let mut output = Vec::new();
456 let mut exit_code = None;
457
458 while let Some(msg) = channel.wait().await {
459 match msg {
460 russh::ChannelMsg::Data { ref data } => output.write_all(data).unwrap(),
461 russh::ChannelMsg::ExtendedData { ref data, .. } => output.write_all(data).unwrap(),
463 russh::ChannelMsg::ExitStatus { exit_status } => exit_code = Some(exit_status),
464 _ => {}
465 }
466 }
467
468 channel
469 .close()
470 .await
471 .map_err(|error| self.make_session_error(error))?;
472
473 let output_str: String = String::from_utf8_lossy(&output).into();
474
475 ensure!(
476 exit_code.is_some() && exit_code.unwrap() == 0,
477 SshError::NonZeroExitCode {
478 address: self.address,
479 code: exit_code.unwrap(),
480 message: output_str,
481 command,
482 }
483 );
484
485 Ok((output_str.clone(), output_str))
486 }
487
488 pub async fn download<P: AsRef<Path>>(&self, path: P) -> SshResult<String> {
492 let mut error = None;
493 for _ in 0..self.retries + 1 {
494 match self
495 .execute(format!("cat {}", path.as_ref().to_str().unwrap()))
496 .await
497 {
498 Ok((file_data, _)) => return Ok(file_data),
499 Err(err) => error = Some(err),
500 }
501 }
502 Err(error.unwrap())
503 }
504}