iota_aws_orchestrator/
ssh.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    io::Write,
7    net::SocketAddr,
8    path::{Path, PathBuf},
9    sync::Arc,
10    time::Duration,
11};
12
13use async_trait::async_trait;
14use futures::future::try_join_all;
15use russh::{Channel, client, client::Msg};
16use russh_keys::key;
17use tokio::{task::JoinHandle, time::sleep};
18
19use crate::{
20    client::Instance,
21    ensure,
22    error::{SshError, SshResult},
23};
24
/// Lifecycle state of an ssh command launched in the background.
#[derive(PartialEq, Eq)]
pub enum CommandStatus {
    Running,
    Terminated,
}

impl CommandStatus {
    /// Derive the state of the background command `command_id` from `text`
    /// (the captured listing of background sessions).
    ///
    /// The command counts as `Running` exactly when its id occurs anywhere in
    /// `text` (plain substring match); otherwise it is `Terminated`.
    pub fn status(command_id: &str, text: &str) -> Self {
        match text.contains(command_id) {
            true => Self::Running,
            false => Self::Terminated,
        }
    }
}
43
/// Execution options applied to a command before it is sent to the remote
/// machines.
#[derive(Clone, Default)]
pub struct CommandContext {
    /// Unique id of a background command. When set, the command is launched
    /// detached in a tmux session with this name and the call returns
    /// immediately.
    pub background: Option<String>,
    /// Directory to change into before running the command.
    pub path: Option<PathBuf>,
    /// File that receives a copy of the command's stdout and stderr.
    pub log_file: Option<PathBuf>,
}

impl CommandContext {
    /// Create a context that runs the command unchanged: in the foreground,
    /// without changing directory and without a log file.
    pub fn new() -> Self {
        Self::default()
    }

    /// Run the command in the background, identified by `id`.
    pub fn run_background(self, id: String) -> Self {
        Self {
            background: Some(id),
            ..self
        }
    }

    /// Run the command from inside the directory `path`.
    pub fn with_execute_from_path(self, path: PathBuf) -> Self {
        Self {
            path: Some(path),
            ..self
        }
    }

    /// Copy the command's stdout and stderr into the file at `path`.
    pub fn with_log_file(self, path: PathBuf) -> Self {
        Self {
            log_file: Some(path),
            ..self
        }
    }

    /// Wrap `base_command` according to this context and return the final
    /// shell command line.
    ///
    /// The wrappers are applied inside-out: first `|& tee <log_file>`, then a
    /// detached `tmux new` session, and finally `(cd <path> && ...)`.
    /// NOTE(review): ids and paths are interpolated without shell escaping.
    pub fn apply<S: Into<String>>(&self, base_command: S) -> String {
        let base: String = base_command.into();

        let logged = match &self.log_file {
            Some(log_file) => format!("{base} |& tee {}", log_file.display()),
            None => base,
        };

        let detached = match &self.background {
            Some(id) => format!("tmux new -d -s \"{id}\" \"{logged}\""),
            None => logged,
        };

        match &self.path {
            Some(exec_path) => format!("(cd {} && {detached})", exec_path.display()),
            None => detached,
        }
    }
}
99
/// Opens ssh connections to instances and runs commands on them, retrying
/// failed attempts.
#[derive(Clone)]
pub struct SshConnectionManager {
    /// The ssh username.
    username: String,
    /// The ssh private key used to connect to the instances.
    private_key_file: PathBuf,
    /// The inactivity timeout of the connections; `None` lets each connection
    /// fall back to its own default.
    timeout: Option<Duration>,
    /// The number of extra attempts made when connecting; it is also handed to
    /// each connection as its per-command retry count.
    retries: usize,
}
111
112impl SshConnectionManager {
113    /// Delay before re-attempting an ssh execution.
114    const RETRY_DELAY: Duration = Duration::from_secs(5);
115
116    /// Create a new ssh manager from the instances username and private keys.
117    pub fn new(username: String, private_key_file: PathBuf) -> Self {
118        Self {
119            username,
120            private_key_file,
121            timeout: None,
122            retries: 0,
123        }
124    }
125
126    /// Set a timeout duration for the connections.
127    pub fn with_timeout(mut self, timeout: Duration) -> Self {
128        self.timeout = Some(timeout);
129        self
130    }
131
132    /// Set the maximum number of times to retries to establish a connection and
133    /// execute commands.
134    pub fn with_retries(mut self, retries: usize) -> Self {
135        self.retries = retries;
136        self
137    }
138
139    /// Create a new ssh connection with the provided host.
140    pub async fn connect(&self, address: SocketAddr) -> SshResult<SshConnection> {
141        let mut error = None;
142        for _ in 0..self.retries + 1 {
143            match SshConnection::new(
144                address,
145                &self.username,
146                self.private_key_file.clone(),
147                self.timeout,
148                Some(self.retries),
149            )
150            .await
151            {
152                Ok(x) => return Ok(x),
153                Err(e) => error = Some(e),
154            }
155            sleep(Self::RETRY_DELAY).await;
156        }
157        Err(error.unwrap())
158    }
159
160    /// Execute the specified ssh command on all provided instances.
161    pub async fn execute<I, S>(
162        &self,
163        instances: I,
164        command: S,
165        context: CommandContext,
166    ) -> SshResult<Vec<(String, String)>>
167    where
168        I: IntoIterator<Item = Instance>,
169        S: Into<String> + Clone + Send + 'static,
170    {
171        let targets = instances
172            .into_iter()
173            .map(|instance| (instance, command.clone()));
174        self.execute_per_instance(targets, context).await
175    }
176
177    /// Execute the ssh command associated with each instance.
178    pub async fn execute_per_instance<I, S>(
179        &self,
180        instances: I,
181        context: CommandContext,
182    ) -> SshResult<Vec<(String, String)>>
183    where
184        I: IntoIterator<Item = (Instance, S)>,
185        S: Into<String> + Send + 'static,
186    {
187        let handles = self.run_per_instance(instances, context).await;
188
189        try_join_all(handles)
190            .await
191            .unwrap()
192            .into_iter()
193            .collect::<SshResult<_>>()
194    }
195
196    async fn run_per_instance<I, S>(
197        &self,
198        instances: I,
199        context: CommandContext,
200    ) -> Vec<JoinHandle<SshResult<(String, String)>>>
201    where
202        I: IntoIterator<Item = (Instance, S)>,
203        S: Into<String> + Send + 'static,
204    {
205        instances
206            .into_iter()
207            .map(|(instance, command)| {
208                let ssh_manager = self.clone();
209                let context = context.clone();
210
211                tokio::spawn(async move {
212                    let connection = ssh_manager.connect(instance.ssh_address()).await?;
213                    // SshConnection::execute is a blocking call, needs to go to blocking pool
214                    connection.execute(context.apply(command)).await
215                })
216            })
217            .collect::<Vec<_>>()
218    }
219
220    /// Wait until a command running in the background returns or started.
221    pub async fn wait_for_command<I>(
222        &self,
223        instances: I,
224        command_id: &str,
225        status: CommandStatus,
226    ) -> SshResult<()>
227    where
228        I: IntoIterator<Item = Instance> + Clone,
229    {
230        loop {
231            sleep(Self::RETRY_DELAY).await;
232
233            let result = self
234                .execute(
235                    instances.clone(),
236                    "(tmux ls || true)",
237                    CommandContext::default(),
238                )
239                .await?;
240            if result
241                .iter()
242                .all(|(stdout, _)| CommandStatus::status(command_id, stdout) == status)
243            {
244                break;
245            }
246        }
247        Ok(())
248    }
249
250    pub async fn wait_for_success<I, S>(&self, instances: I)
251    where
252        I: IntoIterator<Item = (Instance, S)> + Clone,
253        S: Into<String> + Send + 'static + Clone,
254    {
255        loop {
256            sleep(Self::RETRY_DELAY).await;
257
258            if self
259                .execute_per_instance(instances.clone(), CommandContext::default())
260                .await
261                .is_ok()
262            {
263                break;
264            }
265        }
266    }
267
268    /// Kill a command running in the background of the specified instances.
269    pub async fn kill<I>(&self, instances: I, command_id: &str) -> SshResult<()>
270    where
271        I: IntoIterator<Item = Instance>,
272    {
273        let ssh_command = format!("(tmux kill-session -t {command_id} || true)");
274        let targets = instances.into_iter().map(|x| (x, ssh_command.clone()));
275        self.execute_per_instance(targets, CommandContext::default())
276            .await?;
277        Ok(())
278    }
279}
280
281struct Session {}
282
283#[async_trait]
284impl client::Handler for Session {
285    type Error = russh::Error;
286
287    async fn check_server_key(
288        &mut self,
289        _server_public_key: &key::PublicKey,
290    ) -> Result<bool, Self::Error> {
291        Ok(true)
292    }
293}
294
295/// Representation of an ssh connection.
296pub struct SshConnection {
297    /// The ssh session.
298    session: client::Handle<Session>,
299    /// The host address.
300    address: SocketAddr,
301    /// The number of retries before giving up to execute the command.
302    retries: usize,
303}
304
305impl SshConnection {
306    /// Default duration before timing out the ssh connection.
307    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(300);
308
309    /// Create a new ssh connection with a specific host.
310    pub async fn new<P: AsRef<Path>>(
311        address: SocketAddr,
312        username: &str,
313        private_key_file: P,
314        inactivity_timeout: Option<Duration>,
315        retries: Option<usize>,
316    ) -> SshResult<Self> {
317        let key = russh_keys::load_secret_key(private_key_file, None)
318            .map_err(|error| SshError::PrivateKeyError { address, error })?;
319
320        let config = client::Config {
321            inactivity_timeout: inactivity_timeout.or(Some(Self::DEFAULT_TIMEOUT)),
322            ..<_>::default()
323        };
324
325        let mut session = client::connect(Arc::new(config), address, Session {})
326            .await
327            .map_err(|error| SshError::ConnectionError { address, error })?;
328
329        let _auth_res = session
330            .authenticate_publickey(username, Arc::new(key))
331            .await
332            .map_err(|error| SshError::SessionError { address, error })?;
333
334        Ok(Self {
335            session,
336            address,
337            retries: retries.unwrap_or_default(),
338        })
339    }
340
341    /// Make a useful session error from the lower level error message.
342    fn make_session_error(&self, error: russh::Error) -> SshError {
343        SshError::SessionError {
344            address: self.address,
345            error,
346        }
347    }
348
349    /// Execute a ssh command on the remote machine.
350    pub async fn execute(&self, command: String) -> SshResult<(String, String)> {
351        let mut error = None;
352        for _ in 0..self.retries + 1 {
353            let channel = match self.session.channel_open_session().await {
354                Ok(x) => x,
355                Err(e) => {
356                    error = Some(self.make_session_error(e));
357                    continue;
358                }
359            };
360            match self.execute_impl(channel, command.clone()).await {
361                r @ Ok(..) => return r,
362                Err(e) => error = Some(e),
363            }
364        }
365        Err(error.unwrap())
366    }
367
368    /// Execute an ssh command on the remote machine and return both stdout and
369    /// stderr.
370    async fn execute_impl(
371        &self,
372        mut channel: Channel<Msg>,
373        command: String,
374    ) -> SshResult<(String, String)> {
375        channel
376            .exec(true, command)
377            .await
378            .map_err(|e| self.make_session_error(e))?;
379
380        let mut output = Vec::new();
381        let mut exit_code = None;
382
383        while let Some(msg) = channel.wait().await {
384            match msg {
385                russh::ChannelMsg::Data { ref data } => output.write_all(data).unwrap(),
386                russh::ChannelMsg::ExitStatus { exit_status } => exit_code = Some(exit_status),
387                _ => {}
388            }
389        }
390
391        channel
392            .close()
393            .await
394            .map_err(|error| self.make_session_error(error))?;
395
396        let output_str: String = String::from_utf8_lossy(&output).into();
397
398        ensure!(
399            exit_code.is_some() && exit_code.unwrap() == 0,
400            SshError::NonZeroExitCode {
401                address: self.address,
402                code: exit_code.unwrap(),
403                message: output_str
404            }
405        );
406
407        Ok((output_str.clone(), output_str))
408    }
409
410    /// Download a file from the remote machines by doing a silly cat on the
411    /// file. TODO: if the files get too big then we should leverage a
412    /// simple S3 bucket instead.
413    pub async fn download<P: AsRef<Path>>(&self, path: P) -> SshResult<String> {
414        let mut error = None;
415        for _ in 0..self.retries + 1 {
416            match self
417                .execute(format!("cat {}", path.as_ref().to_str().unwrap()))
418                .await
419            {
420                Ok((file_data, _)) => return Ok(file_data),
421                Err(err) => error = Some(err),
422            }
423        }
424        Err(error.unwrap())
425    }
426}