iota_aws_orchestrator/
ssh.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    io::Write,
7    net::SocketAddr,
8    path::{Path, PathBuf},
9    sync::Arc,
10    time::Duration,
11};
12
13use async_trait::async_trait;
14use futures::future::try_join_all;
15use russh::{Channel, client, client::Msg};
16use russh_keys::key;
17use tokio::{task::JoinHandle, time::sleep};
18
19use crate::{
20    client::Instance,
21    ensure,
22    error::{SshError, SshResult},
23};
24
/// The status of a ssh command running in the background.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommandStatus {
    /// The command id was found in the inspected text.
    Running,
    /// The command id was not found in the inspected text.
    Terminated,
}

impl CommandStatus {
    /// Return whether a background command is still running. Returns
    /// `Terminated` if the command is not running in the background.
    ///
    /// `text` is the listing to inspect and `command_id` the identifier to
    /// look for. The check is a plain substring search: any occurrence of
    /// `command_id` inside `text` (including as part of a longer identifier)
    /// yields `Running`, and an empty `command_id` always yields `Running`.
    pub fn status(command_id: &str, text: &str) -> Self {
        if text.contains(command_id) {
            Self::Running
        } else {
            Self::Terminated
        }
    }
}
43
/// Execution options applied to a command sent to the remote machines.
#[derive(Clone, Default)]
pub struct CommandContext {
    /// When set, the command runs in the background (returning immediately)
    /// and is identified by this unique id.
    pub background: Option<String>,
    /// Directory to change into before executing the command.
    pub path: Option<PathBuf>,
    /// File into which stdout and stderr are redirected.
    pub log_file: Option<PathBuf>,
    /// How many times a failed execution is retried before giving up.
    pub retries: usize,
}

impl CommandContext {
    /// Build a context with no background id, no path, no log file and zero
    /// retries.
    pub fn new() -> Self {
        Self::default()
    }

    /// Run the command in the background under the given id.
    pub fn run_background(self, id: String) -> Self {
        Self {
            background: Some(id),
            ..self
        }
    }

    /// Execute the command from the given directory.
    pub fn with_execute_from_path(self, path: PathBuf) -> Self {
        Self {
            path: Some(path),
            ..self
        }
    }

    /// Redirect stdout and stderr of the command to the given log file.
    pub fn with_log_file(self, path: PathBuf) -> Self {
        Self {
            log_file: Some(path),
            ..self
        }
    }

    /// Retry a failed execution up to `retries` times before giving up.
    pub fn with_retries(self, retries: usize) -> Self {
        Self { retries, ..self }
    }

    /// Wrap `base_command` with the options stored in this context and
    /// return the resulting shell command line.
    ///
    /// Wrappers are added from the inside out: log redirection first, then
    /// the tmux session, then the directory change. Neither the command nor
    /// the paths are escaped.
    pub fn apply<S: Into<String>>(&self, base_command: S) -> String {
        let command = base_command.into();

        // Innermost layer: pipe the output through `tee` into the log file.
        let command = match &self.log_file {
            Some(log_file) => format!("{command} |& tee {}", log_file.display()),
            None => command,
        };

        // Middle layer: start the command in a detached tmux session named
        // after the background id.
        let command = match &self.background {
            Some(id) => format!("tmux new -d -s \"{id}\" \"{command}\""),
            None => command,
        };

        // Outermost layer: change directory before running everything else.
        match &self.path {
            Some(exec_path) => format!("(cd {} && {command})", exec_path.display()),
            None => command,
        }
    }
}
108
/// Settings used to open ssh connections and run commands on remote
/// instances.
#[derive(Clone)]
pub struct SshConnectionManager {
    /// The ssh username.
    username: String,
    /// The ssh private key file used to connect to the instances.
    private_key_file: PathBuf,
    /// The timeout value of the connection; handed to each new connection as
    /// its inactivity timeout (`None` lets the connection pick its default).
    timeout: Option<Duration>,
    /// The number of retries before giving up on establishing a connection;
    /// also handed to each connection as its own retry count for commands.
    retries: usize,
}
120
121impl SshConnectionManager {
122    /// Delay before re-attempting an ssh execution.
123    const RETRY_DELAY: Duration = Duration::from_secs(5);
124
125    /// Create a new ssh manager from the instances username and private keys.
126    pub fn new(username: String, private_key_file: PathBuf) -> Self {
127        Self {
128            username,
129            private_key_file,
130            timeout: None,
131            retries: 0,
132        }
133    }
134
135    /// Set a timeout duration for the connections.
136    pub fn with_timeout(mut self, timeout: Duration) -> Self {
137        self.timeout = Some(timeout);
138        self
139    }
140
141    /// Set the maximum number of times to retries to establish a connection and
142    /// execute commands.
143    pub fn with_retries(mut self, retries: usize) -> Self {
144        self.retries = retries;
145        self
146    }
147
148    /// Create a new ssh connection with the provided host.
149    pub async fn connect(&self, address: SocketAddr) -> SshResult<SshConnection> {
150        let mut error = None;
151        for _ in 0..self.retries + 1 {
152            match SshConnection::new(
153                address,
154                &self.username,
155                self.private_key_file.clone(),
156                self.timeout,
157                Some(self.retries),
158            )
159            .await
160            {
161                Ok(x) => return Ok(x),
162                Err(e) => error = Some(e),
163            }
164            sleep(Self::RETRY_DELAY).await;
165        }
166        Err(error.unwrap())
167    }
168
169    /// Execute the specified ssh command on all provided instances.
170    pub async fn execute<I, S>(
171        &self,
172        instances: I,
173        command: S,
174        context: CommandContext,
175    ) -> SshResult<Vec<(String, String)>>
176    where
177        I: IntoIterator<Item = Instance>,
178        S: Into<String> + Clone + Send + 'static,
179    {
180        let targets = instances
181            .into_iter()
182            .map(|instance| (instance, command.clone()));
183        self.execute_per_instance(targets, context).await
184    }
185
186    /// Execute the ssh command associated with each instance.
187    pub async fn execute_per_instance<I, S>(
188        &self,
189        instances: I,
190        context: CommandContext,
191    ) -> SshResult<Vec<(String, String)>>
192    where
193        I: IntoIterator<Item = (Instance, S)>,
194        S: Into<String> + Send + 'static,
195    {
196        let handles = self.run_per_instance(instances, context).await;
197
198        try_join_all(handles)
199            .await
200            .unwrap()
201            .into_iter()
202            .collect::<SshResult<_>>()
203    }
204
205    async fn run_per_instance<I, S>(
206        &self,
207        instances: I,
208        context: CommandContext,
209    ) -> Vec<JoinHandle<SshResult<(String, String)>>>
210    where
211        I: IntoIterator<Item = (Instance, S)>,
212        S: Into<String> + Send + 'static,
213    {
214        instances
215            .into_iter()
216            .map(|(instance, command)| {
217                let ssh_manager = self.clone();
218                let context = context.clone();
219
220                tokio::spawn(async move {
221                    let connection = ssh_manager.connect(instance.ssh_address()).await?;
222
223                    let command_str = command.into();
224                    let mut consecutive_errors = 0;
225                    loop {
226                        match connection.execute(context.apply(command_str.clone())).await {
227                            Ok(output) => {
228                                return Ok(output);
229                            }
230                            Err(err) => {
231                                consecutive_errors += 1;
232                                if consecutive_errors > context.retries {
233                                    return Err(err);
234                                }
235                            }
236                        }
237
238                        sleep(Self::RETRY_DELAY).await;
239                    }
240                })
241            })
242            .collect::<Vec<_>>()
243    }
244
245    /// Wait until a command running in the background returns or started.
246    pub async fn wait_for_command<I>(
247        &self,
248        instances: I,
249        command_id: &str,
250        status: CommandStatus,
251    ) -> SshResult<()>
252    where
253        I: IntoIterator<Item = Instance> + Clone,
254    {
255        let mut consecutive_errors = 0;
256        loop {
257            sleep(Self::RETRY_DELAY).await;
258
259            match self
260                .execute(
261                    instances.clone(),
262                    "(tmux ls || true)",
263                    CommandContext::default(),
264                )
265                .await
266            {
267                Ok(result) => {
268                    consecutive_errors = 0;
269                    if result
270                        .iter()
271                        .all(|(stdout, _)| CommandStatus::status(command_id, stdout) == status)
272                    {
273                        break;
274                    }
275                }
276                Err(e) => {
277                    consecutive_errors += 1;
278                    if consecutive_errors >= 5 {
279                        return Err(e);
280                    }
281                }
282            }
283        }
284        Ok(())
285    }
286
287    pub async fn wait_for_success<I, S>(&self, instances: I)
288    where
289        I: IntoIterator<Item = (Instance, S)> + Clone,
290        S: Into<String> + Send + 'static + Clone,
291    {
292        match self
293            .execute_per_instance(instances.clone(), CommandContext::default().with_retries(5))
294            .await
295        {
296            Ok(_) => {}
297            Err(e) => {
298                // Handle failure case
299                panic!("Command execution failed on one or more instances: {e}");
300            }
301        }
302    }
303
304    /// Kill a command running in the background of the specified instances.
305    pub async fn kill<I>(&self, instances: I, command_id: &str) -> SshResult<()>
306    where
307        I: IntoIterator<Item = Instance>,
308    {
309        let ssh_command = format!("(tmux kill-session -t {command_id} || true)");
310        let targets = instances.into_iter().map(|x| (x, ssh_command.clone()));
311        self.execute_per_instance(targets, CommandContext::default().with_retries(5))
312            .await?;
313        Ok(())
314    }
315}
316
/// Stateless client-side handler passed to `client::connect` when opening an
/// ssh session.
struct Session {}

#[async_trait]
impl client::Handler for Session {
    type Error = russh::Error;

    /// Called with the public key presented by the server; always answers
    /// `Ok(true)` without inspecting the key.
    ///
    /// NOTE(review): this presumably means the server's host key is never
    /// verified, so connections are not protected against impersonation of
    /// the remote host — confirm this is acceptable for the target machines.
    async fn check_server_key(
        &mut self,
        _server_public_key: &key::PublicKey,
    ) -> Result<bool, Self::Error> {
        Ok(true)
    }
}
330
/// Representation of an ssh connection to a single host.
pub struct SshConnection {
    /// The ssh session.
    session: client::Handle<Session>,
    /// The host address; also attached to the errors produced by this
    /// connection.
    address: SocketAddr,
    /// The number of retries before giving up to execute the command.
    retries: usize,
}
340
341impl SshConnection {
342    /// Default duration before timing out the ssh connection.
343    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(300);
344
345    /// Create a new ssh connection with a specific host.
346    pub async fn new<P: AsRef<Path>>(
347        address: SocketAddr,
348        username: &str,
349        private_key_file: P,
350        inactivity_timeout: Option<Duration>,
351        retries: Option<usize>,
352    ) -> SshResult<Self> {
353        let key = russh_keys::load_secret_key(private_key_file, None)
354            .map_err(|error| SshError::PrivateKeyError { address, error })?;
355
356        let config = client::Config {
357            inactivity_timeout: inactivity_timeout.or(Some(Self::DEFAULT_TIMEOUT)),
358            ..<_>::default()
359        };
360
361        let mut session = client::connect(Arc::new(config), address, Session {})
362            .await
363            .map_err(|error| SshError::ConnectionError { address, error })?;
364
365        let _auth_res = session
366            .authenticate_publickey(username, Arc::new(key))
367            .await
368            .map_err(|error| SshError::SessionError { address, error })?;
369
370        Ok(Self {
371            session,
372            address,
373            retries: retries.unwrap_or_default(),
374        })
375    }
376
377    /// Make a useful session error from the lower level error message.
378    fn make_session_error(&self, error: russh::Error) -> SshError {
379        SshError::SessionError {
380            address: self.address,
381            error,
382        }
383    }
384
385    /// Execute a ssh command on the remote machine.
386    pub async fn execute(&self, command: String) -> SshResult<(String, String)> {
387        let mut error = None;
388        for _ in 0..self.retries + 1 {
389            let channel = match self.session.channel_open_session().await {
390                Ok(x) => x,
391                Err(e) => {
392                    error = Some(self.make_session_error(e));
393                    continue;
394                }
395            };
396            match self.execute_impl(channel, command.clone()).await {
397                r @ Ok(..) => return r,
398                Err(e) => error = Some(e),
399            }
400        }
401        Err(error.unwrap())
402    }
403
404    /// Execute an ssh command on the remote machine and return both stdout and
405    /// stderr.
406    async fn execute_impl(
407        &self,
408        mut channel: Channel<Msg>,
409        command: String,
410    ) -> SshResult<(String, String)> {
411        channel
412            .exec(true, command.clone())
413            .await
414            .map_err(|e| self.make_session_error(e))?;
415
416        let mut output = Vec::new();
417        let mut exit_code = None;
418
419        while let Some(msg) = channel.wait().await {
420            match msg {
421                russh::ChannelMsg::Data { ref data } => output.write_all(data).unwrap(),
422                russh::ChannelMsg::ExitStatus { exit_status } => exit_code = Some(exit_status),
423                _ => {}
424            }
425        }
426
427        channel
428            .close()
429            .await
430            .map_err(|error| self.make_session_error(error))?;
431
432        let output_str: String = String::from_utf8_lossy(&output).into();
433
434        ensure!(
435            exit_code.is_some() && exit_code.unwrap() == 0,
436            SshError::NonZeroExitCode {
437                address: self.address,
438                code: exit_code.unwrap(),
439                message: output_str,
440                command,
441            }
442        );
443
444        Ok((output_str.clone(), output_str))
445    }
446
447    /// Download a file from the remote machines by doing a silly cat on the
448    /// file. TODO: if the files get too big then we should leverage a
449    /// simple S3 bucket instead.
450    pub async fn download<P: AsRef<Path>>(&self, path: P) -> SshResult<String> {
451        let mut error = None;
452        for _ in 0..self.retries + 1 {
453            match self
454                .execute(format!("cat {}", path.as_ref().to_str().unwrap()))
455                .await
456            {
457                Ok((file_data, _)) => return Ok(file_data),
458                Err(err) => error = Some(err),
459            }
460        }
461        Err(error.unwrap())
462    }
463}