iota_aws_orchestrator/
ssh.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    io::Write,
7    net::SocketAddr,
8    path::{Path, PathBuf},
9    sync::Arc,
10    time::Duration,
11};
12
13use async_trait::async_trait;
14use futures::future::try_join_all;
15use russh::{Channel, client, client::Msg, keys::PrivateKeyWithHashAlg};
16use tokio::{task::JoinHandle, time::sleep};
17
18use crate::{
19    client::Instance,
20    ensure,
21    error::{SshError, SshResult},
22};
23
/// The status of a ssh command running in the background.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommandStatus {
    /// A session named after the command id is listed.
    Running,
    /// No session named after the command id is listed.
    Terminated,
}

impl CommandStatus {
    /// Return whether the background command `command_id` is still running.
    ///
    /// `text` is expected to be the output of `tmux ls`, where every session
    /// is printed on its own line as `<session name>: <details>`. The command
    /// counts as [`CommandStatus::Running`] only when a line starts with
    /// exactly `command_id` followed by `:`; a plain substring search would
    /// also match longer session names (`node` vs `node-monitor`) or unrelated
    /// text such as the creation date.
    ///
    /// Returns [`CommandStatus::Terminated`] in every other case, including
    /// empty `text`.
    pub fn status(command_id: &str, text: &str) -> Self {
        let is_listed = text.lines().any(|line| {
            line.strip_prefix(command_id)
                .is_some_and(|rest| rest.starts_with(':'))
        });

        if is_listed {
            Self::Running
        } else {
            Self::Terminated
        }
    }
}
42
/// Execution settings shared by a command sent to remote machines.
#[derive(Clone, Default)]
pub struct CommandContext {
    /// When set, the command runs detached in the background and this value
    /// is the unique id identifying it.
    pub background: Option<String>,
    /// Directory to change into before running the command.
    pub path: Option<PathBuf>,
    /// File receiving a copy of the command's stdout and stderr.
    pub log_file: Option<PathBuf>,
    /// How many times execution is retried before giving up.
    pub retries: usize,
}

impl CommandContext {
    /// Build a context with no background id, no path, no log file and zero
    /// retries.
    pub fn new() -> Self {
        Self::default()
    }

    /// Run the command in the background, identified by `id`.
    pub fn run_background(self, id: String) -> Self {
        Self {
            background: Some(id),
            ..self
        }
    }

    /// Execute the command from the directory `path`.
    pub fn with_execute_from_path(self, path: PathBuf) -> Self {
        Self {
            path: Some(path),
            ..self
        }
    }

    /// Redirect stdout and stderr of the command to the file `path`.
    pub fn with_log_file(self, path: PathBuf) -> Self {
        Self {
            log_file: Some(path),
            ..self
        }
    }

    /// Retry the command up to `retries` times before giving up.
    pub fn with_retries(self, retries: usize) -> Self {
        Self { retries, ..self }
    }

    /// Wrap `base_command` according to this context and return the final
    /// command line.
    ///
    /// The wrappers are applied from the inside out: first the `|& tee`
    /// redirection to the log file, then the detached `tmux` session named
    /// after the background id, and finally the `cd` into the execution path.
    /// Unset options leave the command untouched.
    pub fn apply<S: Into<String>>(&self, base_command: S) -> String {
        let base: String = base_command.into();

        let logged = match self.log_file.as_deref() {
            Some(file) => format!("{base} |& tee {}", file.display()),
            None => base,
        };

        let detached = match self.background.as_deref() {
            Some(session) => format!("tmux new -d -s \"{session}\" \"{logged}\""),
            None => logged,
        };

        match self.path.as_deref() {
            Some(dir) => format!("(cd {} && {detached})", dir.display()),
            None => detached,
        }
    }
}
107
/// Settings used to open ssh connections to instances and run commands on
/// them.
#[derive(Clone)]
pub struct SshConnectionManager {
    /// User name to log in with.
    username: String,
    /// Path of the private key file used to authenticate on the instances.
    private_key_file: PathBuf,
    /// Optional timeout applied to the connections.
    timeout: Option<Duration>,
    /// How many times an operation is retried before giving up.
    retries: usize,
}
119
120impl SshConnectionManager {
121    /// Delay before re-attempting an ssh execution.
122    const RETRY_DELAY: Duration = Duration::from_secs(5);
123
124    /// Create a new ssh manager from the instances username and private keys.
125    pub fn new(username: String, private_key_file: PathBuf) -> Self {
126        Self {
127            username,
128            private_key_file,
129            timeout: None,
130            retries: 0,
131        }
132    }
133
134    /// Set a timeout duration for the connections.
135    pub fn with_timeout(mut self, timeout: Duration) -> Self {
136        self.timeout = Some(timeout);
137        self
138    }
139
140    /// Set the maximum number of times to retries to establish a connection and
141    /// execute commands.
142    pub fn with_retries(mut self, retries: usize) -> Self {
143        self.retries = retries;
144        self
145    }
146
147    /// Create a new ssh connection with the provided host.
148    pub async fn connect(&self, address: SocketAddr) -> SshResult<SshConnection> {
149        let mut error = None;
150        for _ in 0..self.retries + 1 {
151            match SshConnection::new(
152                address,
153                &self.username,
154                self.private_key_file.clone(),
155                self.timeout,
156                Some(self.retries),
157            )
158            .await
159            {
160                Ok(x) => return Ok(x),
161                Err(e) => error = Some(e),
162            }
163            sleep(Self::RETRY_DELAY).await;
164        }
165        Err(error.unwrap())
166    }
167
168    /// Execute the specified ssh command on all provided instances.
169    pub async fn execute<I, S>(
170        &self,
171        instances: I,
172        command: S,
173        context: CommandContext,
174    ) -> SshResult<Vec<(String, String)>>
175    where
176        I: IntoIterator<Item = Instance>,
177        S: Into<String> + Clone + Send + 'static,
178    {
179        let targets = instances
180            .into_iter()
181            .map(|instance| (instance, command.clone()));
182        self.execute_per_instance(targets, context).await
183    }
184
185    /// Execute the ssh command associated with each instance.
186    pub async fn execute_per_instance<I, S>(
187        &self,
188        instances: I,
189        context: CommandContext,
190    ) -> SshResult<Vec<(String, String)>>
191    where
192        I: IntoIterator<Item = (Instance, S)>,
193        S: Into<String> + Send + 'static,
194    {
195        let handles = self.run_per_instance(instances, context).await;
196
197        try_join_all(handles)
198            .await
199            .unwrap()
200            .into_iter()
201            .collect::<SshResult<_>>()
202    }
203
204    /// Execute the ssh command associated with each instance and return every
205    /// individual result.
206    pub async fn execute_per_instance_with_results<I, S>(
207        &self,
208        instances: I,
209        context: CommandContext,
210    ) -> Vec<SshResult<(String, String)>>
211    where
212        I: IntoIterator<Item = (Instance, S)>,
213        S: Into<String> + Send + 'static,
214    {
215        let handles = self.run_per_instance(instances, context).await;
216
217        try_join_all(handles)
218            .await
219            .unwrap()
220            .into_iter()
221            .collect::<Vec<_>>()
222    }
223
224    async fn run_per_instance<I, S>(
225        &self,
226        instances: I,
227        context: CommandContext,
228    ) -> Vec<JoinHandle<SshResult<(String, String)>>>
229    where
230        I: IntoIterator<Item = (Instance, S)>,
231        S: Into<String> + Send + 'static,
232    {
233        instances
234            .into_iter()
235            .map(|(instance, command)| {
236                let ssh_manager = self.clone();
237                let context = context.clone();
238
239                tokio::spawn(async move {
240                    let connection = match ssh_manager.connect(instance.ssh_address()).await {
241                        Ok(c) => c,
242                        Err(e) => {
243                            println!("Failed to connect to {}: error {e}", instance.ssh_address());
244                            return Err(e);
245                        }
246                    };
247
248                    let command_str = command.into();
249                    let mut consecutive_errors = 0;
250                    loop {
251                        match connection.execute(context.apply(command_str.clone())).await {
252                            Ok(output) => {
253                                return Ok(output);
254                            }
255                            Err(err) => {
256                                consecutive_errors += 1;
257
258                                if consecutive_errors > context.retries {
259                                    println!(
260                                        "Failed to execute command {command_str} at {}: {err}",
261                                        instance.ssh_address()
262                                    );
263                                    return Err(err);
264                                }
265                            }
266                        }
267
268                        sleep(Self::RETRY_DELAY).await;
269                    }
270                })
271            })
272            .collect::<Vec<_>>()
273    }
274
275    /// Wait until a command running in the background returns or started.
276    pub async fn wait_for_command<I>(
277        &self,
278        instances: I,
279        command_id: &str,
280        status: CommandStatus,
281    ) -> SshResult<()>
282    where
283        I: IntoIterator<Item = Instance> + Clone,
284    {
285        let mut consecutive_errors = 0;
286        loop {
287            sleep(Self::RETRY_DELAY).await;
288
289            match self
290                .execute(
291                    instances.clone(),
292                    "(tmux ls || true)",
293                    CommandContext::default(),
294                )
295                .await
296            {
297                Ok(result) => {
298                    consecutive_errors = 0;
299                    if result
300                        .iter()
301                        .all(|(stdout, _)| CommandStatus::status(command_id, stdout) == status)
302                    {
303                        break;
304                    }
305                }
306                Err(e) => {
307                    consecutive_errors += 1;
308                    if consecutive_errors >= 5 {
309                        return Err(e);
310                    }
311                }
312            }
313        }
314        Ok(())
315    }
316
317    pub async fn wait_for_success<I, S>(&self, instances: I)
318    where
319        I: IntoIterator<Item = (Instance, S)> + Clone,
320        S: Into<String> + Send + 'static + Clone,
321    {
322        match self
323            .execute_per_instance(instances.clone(), CommandContext::default().with_retries(5))
324            .await
325        {
326            Ok(_) => {}
327            Err(e) => {
328                // Handle failure case
329                panic!("Command execution failed on one or more instances: {e}");
330            }
331        }
332    }
333
334    /// Kill a command running in the background of the specified instances.
335    pub async fn kill<I>(&self, instances: I, command_id: &str) -> SshResult<()>
336    where
337        I: IntoIterator<Item = Instance>,
338    {
339        let ssh_command = format!("(tmux kill-session -t {command_id} || true)");
340        let targets = instances.into_iter().map(|x| (x, ssh_command.clone()));
341        self.execute_per_instance(targets, CommandContext::default().with_retries(5))
342            .await?;
343        Ok(())
344    }
345}
346
347struct Session {}
348
349#[async_trait]
350impl client::Handler for Session {
351    type Error = russh::Error;
352
353    #[allow(clippy::manual_async_fn)]
354    fn check_server_key(
355        &mut self,
356        _server_public_key: &russh::keys::PublicKey,
357    ) -> impl std::future::Future<Output = Result<bool, Self::Error>> + Send {
358        async { Ok(true) }
359    }
360}
361
362/// Representation of an ssh connection.
363pub struct SshConnection {
364    /// The ssh session.
365    session: client::Handle<Session>,
366    /// The host address.
367    address: SocketAddr,
368    /// The number of retries before giving up to execute the command.
369    retries: usize,
370}
371
372impl SshConnection {
373    /// Default duration before timing out the ssh connection.
374    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(300);
375
376    /// Create a new ssh connection with a specific host.
377    pub async fn new<P: AsRef<Path>>(
378        address: SocketAddr,
379        username: &str,
380        private_key_file: P,
381        inactivity_timeout: Option<Duration>,
382        retries: Option<usize>,
383    ) -> SshResult<Self> {
384        let key_bytes = std::fs::read(private_key_file).map_err(|e| SshError::PrivateKeyError {
385            address,
386            error: russh::keys::Error::IO(e),
387        })?;
388        let key = russh::keys::decode_secret_key(&String::from_utf8_lossy(&key_bytes), None)
389            .map_err(|_e| SshError::PrivateKeyError {
390                address,
391                error: russh::keys::Error::CouldNotReadKey,
392            })?;
393
394        let config = client::Config {
395            inactivity_timeout: inactivity_timeout.or(Some(Self::DEFAULT_TIMEOUT)),
396            ..<_>::default()
397        };
398
399        let mut session = client::connect(Arc::new(config), address, Session {})
400            .await
401            .map_err(|error| SshError::ConnectionError { address, error })?;
402
403        let key_with_hash = PrivateKeyWithHashAlg::new(Arc::new(key), None);
404        let _auth_res = session
405            .authenticate_publickey(username, key_with_hash)
406            .await
407            .map_err(|error| SshError::SessionError { address, error })?;
408
409        Ok(Self {
410            session,
411            address,
412            retries: retries.unwrap_or_default(),
413        })
414    }
415
416    /// Make a useful session error from the lower level error message.
417    fn make_session_error(&self, error: russh::Error) -> SshError {
418        SshError::SessionError {
419            address: self.address,
420            error,
421        }
422    }
423
424    /// Execute a ssh command on the remote machine.
425    pub async fn execute(&self, command: String) -> SshResult<(String, String)> {
426        let mut error = None;
427        for _ in 0..self.retries + 1 {
428            let channel = match self.session.channel_open_session().await {
429                Ok(x) => x,
430                Err(e) => {
431                    error = Some(self.make_session_error(e));
432                    continue;
433                }
434            };
435            match self.execute_impl(channel, command.clone()).await {
436                r @ Ok(..) => return r,
437                Err(e) => error = Some(e),
438            }
439        }
440        Err(error.unwrap())
441    }
442
443    /// Execute an ssh command on the remote machine and return both stdout and
444    /// stderr.
445    async fn execute_impl(
446        &self,
447        mut channel: Channel<Msg>,
448        command: String,
449    ) -> SshResult<(String, String)> {
450        channel
451            .exec(true, command.clone())
452            .await
453            .map_err(|e| self.make_session_error(e))?;
454
455        let mut output = Vec::new();
456        let mut exit_code = None;
457
458        while let Some(msg) = channel.wait().await {
459            match msg {
460                russh::ChannelMsg::Data { ref data } => output.write_all(data).unwrap(),
461                // For better error reporting, we capture stderr as well.
462                russh::ChannelMsg::ExtendedData { ref data, .. } => output.write_all(data).unwrap(),
463                russh::ChannelMsg::ExitStatus { exit_status } => exit_code = Some(exit_status),
464                _ => {}
465            }
466        }
467
468        channel
469            .close()
470            .await
471            .map_err(|error| self.make_session_error(error))?;
472
473        let output_str: String = String::from_utf8_lossy(&output).into();
474
475        ensure!(
476            exit_code.is_some() && exit_code.unwrap() == 0,
477            SshError::NonZeroExitCode {
478                address: self.address,
479                code: exit_code.unwrap(),
480                message: output_str,
481                command,
482            }
483        );
484
485        Ok((output_str.clone(), output_str))
486    }
487
488    /// Download a file from the remote machines by doing a silly cat on the
489    /// file. TODO: if the files get too big then we should leverage a
490    /// simple S3 bucket instead.
491    pub async fn download<P: AsRef<Path>>(&self, path: P) -> SshResult<String> {
492        let mut error = None;
493        for _ in 0..self.retries + 1 {
494            match self
495                .execute(format!("cat {}", path.as_ref().to_str().unwrap()))
496                .await
497            {
498                Ok((file_data, _)) => return Ok(file_data),
499                Err(err) => error = Some(err),
500            }
501        }
502        Err(error.unwrap())
503    }
504}