1use std::{net::SocketAddr, str::FromStr, sync::Arc};
6
7use axum::{
8 Router,
9 extract::{Query, State},
10 http::StatusCode,
11 response::{IntoResponse as _, Response},
12 routing::{get, post},
13};
14use base64::Engine;
15use humantime::parse_duration;
16use iota_types::{
17 base_types::AuthorityName,
18 crypto::{RandomnessPartialSignature, RandomnessRound, RandomnessSignature},
19 error::IotaError,
20};
21use serde::Deserialize;
22use telemetry_subscribers::{TelemetryError, TracingHandle};
23use tokio::sync::oneshot;
24use tracing::info;
25
26use crate::IotaNode;
27
// Paths served by the admin HTTP server. `run_admin_server` binds each one to its
// handler and HTTP method.

/// Log filter: `GET` returns it, `POST` replaces it with the request body.
const LOGGING_ROUTE: &str = "/logging";
/// Updates trace sampling rate, trace file and/or a time-limited trace filter (`POST`).
const TRACING_ROUTE: &str = "/enable-tracing";
/// Resets the trace filter back to the `TRACE_FILTER` env var (`POST`).
const TRACING_RESET_ROUTE: &str = "/reset-tracing";
/// Sets a protocol-upgrade buffer stake override for an epoch (`POST`).
const SET_BUFFER_STAKE_ROUTE: &str = "/set-override-buffer-stake";
/// Clears the protocol-upgrade buffer stake override for an epoch (`POST`).
const CLEAR_BUFFER_STAKE_ROUTE: &str = "/clear-override-buffer-stake";
/// Calls `close_epoch` on the node if the given epoch is the current one (`POST`).
const FORCE_CLOSE_EPOCH: &str = "/force-close-epoch";
/// Lists the authority capabilities stored in the current epoch store (`GET`).
const CAPABILITIES: &str = "/capabilities";
/// Dumps the node configuration (`GET`).
const NODE_CONFIG: &str = "/node-config";
/// Returns randomness partial signatures for a round, base64-encoded (`GET`).
const RANDOMNESS_PARTIAL_SIGS_ROUTE: &str = "/randomness-partial-sigs";
/// Injects randomness partial signatures attributed to an authority (`POST`).
const RANDOMNESS_INJECT_PARTIAL_SIGS_ROUTE: &str = "/randomness-inject-partial-sigs";
/// Injects a full randomness signature for a round (`POST`).
const RANDOMNESS_INJECT_FULL_SIG_ROUTE: &str = "/randomness-inject-full-sig";
/// Returns flamegraph data as JSON or, with `svg=true`, as an SVG image (`GET`).
const FLAMEGRAPH_ROUTE: &str = "/flamegraph";
86
87struct AppState {
88 node: Arc<IotaNode>,
89 tracing_handle: TracingHandle,
90}
91
92pub async fn run_admin_server(
93 node: Arc<IotaNode>,
94 socket_address: SocketAddr,
95 tracing_handle: TracingHandle,
96) {
97 let filter = tracing_handle.get_log().unwrap();
98
99 let app_state = AppState {
100 node,
101 tracing_handle,
102 };
103
104 let app = Router::new()
105 .route(LOGGING_ROUTE, get(get_filter))
106 .route(CAPABILITIES, get(capabilities))
107 .route(NODE_CONFIG, get(node_config))
108 .route(LOGGING_ROUTE, post(set_filter))
109 .route(
110 SET_BUFFER_STAKE_ROUTE,
111 post(set_override_protocol_upgrade_buffer_stake),
112 )
113 .route(
114 CLEAR_BUFFER_STAKE_ROUTE,
115 post(clear_override_protocol_upgrade_buffer_stake),
116 )
117 .route(FORCE_CLOSE_EPOCH, post(force_close_epoch))
118 .route(TRACING_ROUTE, post(enable_tracing))
119 .route(TRACING_RESET_ROUTE, post(reset_tracing))
120 .route(RANDOMNESS_PARTIAL_SIGS_ROUTE, get(randomness_partial_sigs))
121 .route(
122 RANDOMNESS_INJECT_PARTIAL_SIGS_ROUTE,
123 post(randomness_inject_partial_sigs),
124 )
125 .route(
126 RANDOMNESS_INJECT_FULL_SIG_ROUTE,
127 post(randomness_inject_full_sig),
128 )
129 .route(FLAMEGRAPH_ROUTE, get(flamegraph))
130 .with_state(Arc::new(app_state));
131
132 info!(
133 filter =% filter,
134 address =% socket_address,
135 "starting admin server"
136 );
137
138 let listener = tokio::net::TcpListener::bind(&socket_address)
139 .await
140 .unwrap();
141 axum::serve(
142 listener,
143 app.into_make_service_with_connect_info::<SocketAddr>(),
144 )
145 .await
146 .unwrap();
147}
148
149#[derive(Deserialize)]
150struct EnableTracing {
151 filter: Option<String>,
153 duration: Option<String>,
154
155 trace_file: Option<String>,
157
158 sample_rate: Option<f64>,
160}
161
162async fn enable_tracing(
163 State(state): State<Arc<AppState>>,
164 query: Query<EnableTracing>,
165) -> (StatusCode, String) {
166 let Query(EnableTracing {
167 filter,
168 duration,
169 trace_file,
170 sample_rate,
171 }) = query;
172
173 let mut response = Vec::new();
174
175 if let Some(sample_rate) = sample_rate {
176 state.tracing_handle.update_sampling_rate(sample_rate);
177 response.push(format!("sample rate set to {sample_rate:?}"));
178 }
179
180 if let Some(trace_file) = trace_file {
181 if let Err(err) = state.tracing_handle.update_trace_file(&trace_file) {
182 response.push(format!("can't update trace file: {err:?}"));
183 return (StatusCode::BAD_REQUEST, response.join("\n"));
184 } else {
185 response.push(format!("trace file set to {trace_file:?}"));
186 }
187 }
188
189 let Some(filter) = filter else {
190 return (StatusCode::OK, response.join("\n"));
191 };
192
193 let Some(duration) = duration else {
195 response.push("can't update filter: missing duration".into());
196 return (StatusCode::BAD_REQUEST, response.join("\n"));
197 };
198
199 let Ok(duration) = parse_duration(&duration) else {
200 response.push("can't update filter: invalid duration".into());
201 return (StatusCode::BAD_REQUEST, response.join("\n"));
202 };
203
204 match state.tracing_handle.update_trace_filter(&filter, duration) {
205 Ok(()) => {
206 response.push(format!("filter set to {filter:?}"));
207 response.push(format!("filter will be reset after {duration:?}"));
208 (StatusCode::OK, response.join("\n"))
209 }
210 Err(TelemetryError::TracingDisabled) => {
211 response.push("can't update filter: tracing is not enabled. to enable it, run the node with 'TRACE_FILTER' set.".into());
212 (StatusCode::NOT_IMPLEMENTED, response.join("\n"))
213 }
214 Err(err) => {
215 response.push(format!("can't update filter: {err:?}"));
216 (StatusCode::BAD_REQUEST, response.join("\n"))
217 }
218 }
219}
220
221async fn reset_tracing(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
222 match state.tracing_handle.reset_trace() {
223 Ok(()) => (
224 StatusCode::OK,
225 "tracing filter reset to TRACE_FILTER env var".into(),
226 ),
227 Err(TelemetryError::TracingDisabled) => (
228 StatusCode::NOT_IMPLEMENTED,
229 "tracing is not enabled. to enable it, run the node with 'TRACE_FILTER' set.".into(),
230 ),
231 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
232 }
233}
234
235async fn get_filter(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
236 match state.tracing_handle.get_log() {
237 Ok(filter) => (StatusCode::OK, filter),
238 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
239 }
240}
241
242async fn set_filter(
243 State(state): State<Arc<AppState>>,
244 new_filter: String,
245) -> (StatusCode, String) {
246 match state.tracing_handle.update_log(&new_filter) {
247 Ok(()) => {
248 info!(filter =% new_filter, "Log filter updated");
249 (StatusCode::OK, "".into())
250 }
251 Err(err) => (StatusCode::BAD_REQUEST, err.to_string()),
252 }
253}
254
255async fn capabilities(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
256 let epoch_store = state.node.state().load_epoch_store_one_call_per_task();
257
258 let mut output = String::new();
259 let capabilities = epoch_store.get_capabilities_v1();
260 for capability in capabilities.unwrap_or_default() {
261 output.push_str(&format!("{capability:?}\n"));
262 }
263
264 (StatusCode::OK, output)
265}
266
267async fn node_config(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
268 let node_config = &state.node.config;
269
270 (StatusCode::OK, format!("{node_config:#?}\n"))
272}
273
274#[derive(Deserialize)]
275struct Epoch {
276 epoch: u64,
277}
278
279async fn clear_override_protocol_upgrade_buffer_stake(
280 State(state): State<Arc<AppState>>,
281 epoch: Query<Epoch>,
282) -> (StatusCode, String) {
283 let Query(Epoch { epoch }) = epoch;
284
285 match state
286 .node
287 .clear_override_protocol_upgrade_buffer_stake(epoch)
288 {
289 Ok(()) => (
290 StatusCode::OK,
291 "protocol upgrade buffer stake cleared\n".to_string(),
292 ),
293 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
294 }
295}
296
297#[derive(Deserialize)]
298struct SetBufferStake {
299 buffer_bps: u64,
300 epoch: u64,
301}
302
303async fn set_override_protocol_upgrade_buffer_stake(
304 State(state): State<Arc<AppState>>,
305 buffer_state: Query<SetBufferStake>,
306) -> (StatusCode, String) {
307 let Query(SetBufferStake { buffer_bps, epoch }) = buffer_state;
308
309 match state
310 .node
311 .set_override_protocol_upgrade_buffer_stake(epoch, buffer_bps)
312 {
313 Ok(()) => (
314 StatusCode::OK,
315 format!("protocol upgrade buffer stake set to '{buffer_bps}'\n"),
316 ),
317 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
318 }
319}
320
321async fn force_close_epoch(
322 State(state): State<Arc<AppState>>,
323 epoch: Query<Epoch>,
324) -> (StatusCode, String) {
325 let Query(Epoch {
326 epoch: expected_epoch,
327 }) = epoch;
328 let epoch_store = state.node.state().load_epoch_store_one_call_per_task();
329 let actual_epoch = epoch_store.epoch();
330 if actual_epoch != expected_epoch {
331 let err = IotaError::WrongEpoch {
332 expected_epoch,
333 actual_epoch,
334 };
335 return (StatusCode::INTERNAL_SERVER_ERROR, err.to_string());
336 }
337
338 match state.node.close_epoch(&epoch_store).await {
339 Ok(()) => (
340 StatusCode::OK,
341 "close_epoch() called successfully\n".to_string(),
342 ),
343 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
344 }
345}
346
347#[derive(Deserialize)]
348struct Round {
349 round: u64,
350}
351
352async fn randomness_partial_sigs(
353 State(state): State<Arc<AppState>>,
354 round: Query<Round>,
355) -> (StatusCode, String) {
356 let Query(Round { round }) = round;
357
358 let (tx, rx) = oneshot::channel();
359 state
360 .node
361 .randomness_handle()
362 .admin_get_partial_signatures(RandomnessRound(round), tx);
363
364 let sigs = match rx.await {
365 Ok(sigs) => sigs,
366 Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
367 };
368
369 let output = format!(
370 "{}\n",
371 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(sigs)
372 );
373
374 (StatusCode::OK, output)
375}
376
377#[derive(Deserialize)]
378struct PartialSigsToInject {
379 hex_authority_name: String,
380 round: u64,
381 base64_sigs: String,
382}
383
384async fn randomness_inject_partial_sigs(
385 State(state): State<Arc<AppState>>,
386 args: Query<PartialSigsToInject>,
387) -> (StatusCode, String) {
388 let Query(PartialSigsToInject {
389 hex_authority_name,
390 round,
391 base64_sigs,
392 }) = args;
393
394 let authority_name = match AuthorityName::from_str(hex_authority_name.as_str()) {
395 Ok(authority_name) => authority_name,
396 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
397 };
398
399 let sigs: Vec<u8> = match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(base64_sigs) {
400 Ok(sigs) => sigs,
401 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
402 };
403
404 let sigs: Vec<RandomnessPartialSignature> = match bcs::from_bytes(&sigs) {
405 Ok(sigs) => sigs,
406 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
407 };
408
409 let (tx_result, rx_result) = oneshot::channel();
410 state
411 .node
412 .randomness_handle()
413 .admin_inject_partial_signatures(authority_name, RandomnessRound(round), sigs, tx_result);
414
415 match rx_result.await {
416 Ok(Ok(())) => (StatusCode::OK, "partial signatures injected\n".to_string()),
417 Ok(Err(e)) => (StatusCode::BAD_REQUEST, e.to_string()),
418 Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
419 }
420}
421
422#[derive(Deserialize)]
423struct FullSigToInject {
424 round: u64,
425 base64_sig: String,
426}
427
428async fn randomness_inject_full_sig(
429 State(state): State<Arc<AppState>>,
430 args: Query<FullSigToInject>,
431) -> (StatusCode, String) {
432 let Query(FullSigToInject { round, base64_sig }) = args;
433
434 let sig: Vec<u8> = match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(base64_sig) {
435 Ok(sig) => sig,
436 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
437 };
438
439 let sig: RandomnessSignature = match bcs::from_bytes(&sig) {
440 Ok(sig) => sig,
441 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
442 };
443
444 let (tx_result, rx_result) = oneshot::channel();
445 state.node.randomness_handle().admin_inject_full_signature(
446 RandomnessRound(round),
447 sig,
448 tx_result,
449 );
450
451 match rx_result.await {
452 Ok(Ok(())) => (StatusCode::OK, "full signature injected\n".to_string()),
453 Ok(Err(e)) => (StatusCode::BAD_REQUEST, e.to_string()),
454 Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
455 }
456}
457
458#[derive(Deserialize)]
459struct Flamegraph {
460 #[serde(default)]
462 svg: bool,
463 #[serde(default)]
465 width: usize,
466 #[serde(default)]
468 running: bool,
469 #[serde(default)]
471 completed: bool,
472 #[serde(default)]
474 graph_id: String,
475 #[serde(default)]
477 mem: bool,
478}
479
480async fn flamegraph(State(state): State<Arc<AppState>>, query: Query<Flamegraph>) -> Response {
481 if let Some(sub) = state.tracing_handle.get_flamegraph() {
482 let Query(Flamegraph {
483 svg,
484 width,
485 mut running,
486 mut completed,
487 graph_id,
488 mem,
489 }) = query;
490 if !running && !completed {
491 running = true;
492 completed = true;
493 }
494 if svg {
495 #[cfg(not(all(feature = "flamegraph-alloc", nightly)))]
496 {
497 if mem {
498 return (
499 StatusCode::BAD_REQUEST,
500 "memory flamegraphs are not supported (re-run iota-node with 'flamegraph-alloc' feature enabled and on nightly Rust toolchain)",
501 )
502 .into_response();
503 }
504 }
505
506 let width = if width == 0 { Some(1920) } else { Some(width) };
508 let config = telemetry_subscribers::flamegraph::SvgConfig {
509 width,
510 #[cfg(all(feature = "flamegraph-alloc", nightly))]
511 measure_mem: mem,
512 ..Default::default()
513 };
514 let svg = if !graph_id.is_empty() {
515 sub.get_svg(&graph_id, running, completed, &config)
516 } else {
517 sub.get_combined_svg("iota-node", running, completed, &config)
518 };
519 if let Some(svg) = svg {
520 (
521 [(
522 axum::http::header::CONTENT_TYPE,
523 axum::http::header::HeaderValue::from_static("image/svg+xml"),
524 )],
525 svg.into_string(),
526 )
527 .into_response()
528 } else {
529 (StatusCode::NOT_FOUND, "Flamegraphs not found\n").into_response()
530 }
531 } else {
532 let nested_frames = if !graph_id.is_empty() {
534 sub.get_nested_set(&graph_id, running, completed)
535 } else {
536 sub.get_nested_sets("iota-node", running, completed)
537 };
538 if !nested_frames.is_empty() {
539 axum::Json(nested_frames).into_response()
540 } else {
541 (StatusCode::NOT_FOUND, "Flamegraphs not found\n").into_response()
542 }
543 }
544 } else {
545 (
546 StatusCode::NOT_FOUND,
547 "Flamegraphs are not enabled (re-run iota-node with TRACE_FLAMEGRAPH=1)\n",
548 )
549 .into_response()
550 }
551}