1use std::{
6 io::Write,
7 net::SocketAddr,
8 path::{Path, PathBuf},
9 sync::Arc,
10 time::Duration,
11};
12
13use async_trait::async_trait;
14use futures::future::try_join_all;
15use russh::{Channel, client, client::Msg};
16use russh_keys::key;
17use tokio::{task::JoinHandle, time::sleep};
18
19use crate::{
20 client::Instance,
21 ensure,
22 error::{SshError, SshResult},
23};
24
/// The status of a background command, inferred from the output of `tmux ls`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommandStatus {
    /// A tmux session named exactly like the command id is listed.
    Running,
    /// No tmux session named like the command id is listed.
    Terminated,
}

impl CommandStatus {
    /// Infer the status of the command `command_id` from `text`, the output of `tmux ls`.
    ///
    /// Each line printed by `tmux ls` has the form `<session name>: <details>`. The
    /// command is reported as [`CommandStatus::Running`] only when the session name of
    /// some line is exactly `command_id`. A plain substring search would also match
    /// unrelated sessions whose name or details merely contain the id (for example
    /// `node` inside `node-2`), which makes waiting for termination hang.
    pub fn status(command_id: &str, text: &str) -> Self {
        let is_listed = text
            .lines()
            // Lines without a `:` separator do not describe a session.
            .filter_map(|line| line.split_once(':'))
            .any(|(session_name, _details)| session_name.trim() == command_id);

        if is_listed {
            Self::Running
        } else {
            Self::Terminated
        }
    }
}
43
/// Options describing how a shell command should be wrapped before it is executed.
#[derive(Clone, Default)]
pub struct CommandContext {
    /// When set, the command runs detached inside a tmux session with this name.
    pub background: Option<String>,
    /// When set, the command runs from this directory.
    pub path: Option<PathBuf>,
    /// When set, the command's output is also piped through `tee` into this file.
    pub log_file: Option<PathBuf>,
}

impl CommandContext {
    /// Create a context with no option set; `apply` then returns commands unchanged.
    pub fn new() -> Self {
        Self::default()
    }

    /// Run the command in the background, inside a tmux session named `id`.
    pub fn run_background(self, id: String) -> Self {
        Self {
            background: Some(id),
            ..self
        }
    }

    /// Run the command from the directory `path`.
    pub fn with_execute_from_path(self, path: PathBuf) -> Self {
        Self {
            path: Some(path),
            ..self
        }
    }

    /// Additionally write the command's output to the file `path`.
    pub fn with_log_file(self, path: PathBuf) -> Self {
        Self {
            log_file: Some(path),
            ..self
        }
    }

    /// Wrap `base_command` according to the options of this context.
    ///
    /// The wrappers are applied from the inside out: logging first, then the tmux
    /// session, then the change of directory.
    pub fn apply<S: Into<String>>(&self, base_command: S) -> String {
        let command = base_command.into();

        let logged = match &self.log_file {
            Some(file) => format!("{command} |& tee {}", file.display()),
            None => command,
        };

        let detached = match &self.background {
            Some(session) => format!("tmux new -d -s \"{session}\" \"{logged}\""),
            None => logged,
        };

        match &self.path {
            Some(directory) => format!("(cd {} && {detached})", directory.display()),
            None => detached,
        }
    }
}
99
100#[derive(Clone)]
101pub struct SshConnectionManager {
102 username: String,
104 private_key_file: PathBuf,
106 timeout: Option<Duration>,
108 retries: usize,
110}
111
112impl SshConnectionManager {
113 const RETRY_DELAY: Duration = Duration::from_secs(5);
115
116 pub fn new(username: String, private_key_file: PathBuf) -> Self {
118 Self {
119 username,
120 private_key_file,
121 timeout: None,
122 retries: 0,
123 }
124 }
125
126 pub fn with_timeout(mut self, timeout: Duration) -> Self {
128 self.timeout = Some(timeout);
129 self
130 }
131
132 pub fn with_retries(mut self, retries: usize) -> Self {
135 self.retries = retries;
136 self
137 }
138
139 pub async fn connect(&self, address: SocketAddr) -> SshResult<SshConnection> {
141 let mut error = None;
142 for _ in 0..self.retries + 1 {
143 match SshConnection::new(
144 address,
145 &self.username,
146 self.private_key_file.clone(),
147 self.timeout,
148 Some(self.retries),
149 )
150 .await
151 {
152 Ok(x) => return Ok(x),
153 Err(e) => error = Some(e),
154 }
155 sleep(Self::RETRY_DELAY).await;
156 }
157 Err(error.unwrap())
158 }
159
160 pub async fn execute<I, S>(
162 &self,
163 instances: I,
164 command: S,
165 context: CommandContext,
166 ) -> SshResult<Vec<(String, String)>>
167 where
168 I: IntoIterator<Item = Instance>,
169 S: Into<String> + Clone + Send + 'static,
170 {
171 let targets = instances
172 .into_iter()
173 .map(|instance| (instance, command.clone()));
174 self.execute_per_instance(targets, context).await
175 }
176
177 pub async fn execute_per_instance<I, S>(
179 &self,
180 instances: I,
181 context: CommandContext,
182 ) -> SshResult<Vec<(String, String)>>
183 where
184 I: IntoIterator<Item = (Instance, S)>,
185 S: Into<String> + Send + 'static,
186 {
187 let handles = self.run_per_instance(instances, context).await;
188
189 try_join_all(handles)
190 .await
191 .unwrap()
192 .into_iter()
193 .collect::<SshResult<_>>()
194 }
195
196 async fn run_per_instance<I, S>(
197 &self,
198 instances: I,
199 context: CommandContext,
200 ) -> Vec<JoinHandle<SshResult<(String, String)>>>
201 where
202 I: IntoIterator<Item = (Instance, S)>,
203 S: Into<String> + Send + 'static,
204 {
205 instances
206 .into_iter()
207 .map(|(instance, command)| {
208 let ssh_manager = self.clone();
209 let context = context.clone();
210
211 tokio::spawn(async move {
212 let connection = ssh_manager.connect(instance.ssh_address()).await?;
213 connection.execute(context.apply(command)).await
215 })
216 })
217 .collect::<Vec<_>>()
218 }
219
220 pub async fn wait_for_command<I>(
222 &self,
223 instances: I,
224 command_id: &str,
225 status: CommandStatus,
226 ) -> SshResult<()>
227 where
228 I: IntoIterator<Item = Instance> + Clone,
229 {
230 loop {
231 sleep(Self::RETRY_DELAY).await;
232
233 let result = self
234 .execute(
235 instances.clone(),
236 "(tmux ls || true)",
237 CommandContext::default(),
238 )
239 .await?;
240 if result
241 .iter()
242 .all(|(stdout, _)| CommandStatus::status(command_id, stdout) == status)
243 {
244 break;
245 }
246 }
247 Ok(())
248 }
249
250 pub async fn wait_for_success<I, S>(&self, instances: I)
251 where
252 I: IntoIterator<Item = (Instance, S)> + Clone,
253 S: Into<String> + Send + 'static + Clone,
254 {
255 loop {
256 sleep(Self::RETRY_DELAY).await;
257
258 if self
259 .execute_per_instance(instances.clone(), CommandContext::default())
260 .await
261 .is_ok()
262 {
263 break;
264 }
265 }
266 }
267
268 pub async fn kill<I>(&self, instances: I, command_id: &str) -> SshResult<()>
270 where
271 I: IntoIterator<Item = Instance>,
272 {
273 let ssh_command = format!("(tmux kill-session -t {command_id} || true)");
274 let targets = instances.into_iter().map(|x| (x, ssh_command.clone()));
275 self.execute_per_instance(targets, CommandContext::default())
276 .await?;
277 Ok(())
278 }
279}
280
281struct Session {}
282
283#[async_trait]
284impl client::Handler for Session {
285 type Error = russh::Error;
286
287 async fn check_server_key(
288 &mut self,
289 _server_public_key: &key::PublicKey,
290 ) -> Result<bool, Self::Error> {
291 Ok(true)
292 }
293}
294
295pub struct SshConnection {
297 session: client::Handle<Session>,
299 address: SocketAddr,
301 retries: usize,
303}
304
305impl SshConnection {
306 const DEFAULT_TIMEOUT: Duration = Duration::from_secs(300);
308
309 pub async fn new<P: AsRef<Path>>(
311 address: SocketAddr,
312 username: &str,
313 private_key_file: P,
314 inactivity_timeout: Option<Duration>,
315 retries: Option<usize>,
316 ) -> SshResult<Self> {
317 let key = russh_keys::load_secret_key(private_key_file, None)
318 .map_err(|error| SshError::PrivateKeyError { address, error })?;
319
320 let config = client::Config {
321 inactivity_timeout: inactivity_timeout.or(Some(Self::DEFAULT_TIMEOUT)),
322 ..<_>::default()
323 };
324
325 let mut session = client::connect(Arc::new(config), address, Session {})
326 .await
327 .map_err(|error| SshError::ConnectionError { address, error })?;
328
329 let _auth_res = session
330 .authenticate_publickey(username, Arc::new(key))
331 .await
332 .map_err(|error| SshError::SessionError { address, error })?;
333
334 Ok(Self {
335 session,
336 address,
337 retries: retries.unwrap_or_default(),
338 })
339 }
340
341 fn make_session_error(&self, error: russh::Error) -> SshError {
343 SshError::SessionError {
344 address: self.address,
345 error,
346 }
347 }
348
349 pub async fn execute(&self, command: String) -> SshResult<(String, String)> {
351 let mut error = None;
352 for _ in 0..self.retries + 1 {
353 let channel = match self.session.channel_open_session().await {
354 Ok(x) => x,
355 Err(e) => {
356 error = Some(self.make_session_error(e));
357 continue;
358 }
359 };
360 match self.execute_impl(channel, command.clone()).await {
361 r @ Ok(..) => return r,
362 Err(e) => error = Some(e),
363 }
364 }
365 Err(error.unwrap())
366 }
367
368 async fn execute_impl(
371 &self,
372 mut channel: Channel<Msg>,
373 command: String,
374 ) -> SshResult<(String, String)> {
375 channel
376 .exec(true, command)
377 .await
378 .map_err(|e| self.make_session_error(e))?;
379
380 let mut output = Vec::new();
381 let mut exit_code = None;
382
383 while let Some(msg) = channel.wait().await {
384 match msg {
385 russh::ChannelMsg::Data { ref data } => output.write_all(data).unwrap(),
386 russh::ChannelMsg::ExitStatus { exit_status } => exit_code = Some(exit_status),
387 _ => {}
388 }
389 }
390
391 channel
392 .close()
393 .await
394 .map_err(|error| self.make_session_error(error))?;
395
396 let output_str: String = String::from_utf8_lossy(&output).into();
397
398 ensure!(
399 exit_code.is_some() && exit_code.unwrap() == 0,
400 SshError::NonZeroExitCode {
401 address: self.address,
402 code: exit_code.unwrap(),
403 message: output_str
404 }
405 );
406
407 Ok((output_str.clone(), output_str))
408 }
409
410 pub async fn download<P: AsRef<Path>>(&self, path: P) -> SshResult<String> {
414 let mut error = None;
415 for _ in 0..self.retries + 1 {
416 match self
417 .execute(format!("cat {}", path.as_ref().to_str().unwrap()))
418 .await
419 {
420 Ok((file_data, _)) => return Ok(file_data),
421 Err(err) => error = Some(err),
422 }
423 }
424 Err(error.unwrap())
425 }
426}