1use std::{net::SocketAddr, str::FromStr, sync::Arc};
6
7use axum::{
8 Router,
9 extract::{Query, State},
10 http::StatusCode,
11 response::{IntoResponse as _, Response},
12 routing::{get, post},
13};
14use base64::Engine;
15use humantime::parse_duration;
16use iota_types::{
17 base_types::AuthorityName,
18 crypto::{RandomnessPartialSignature, RandomnessRound, RandomnessSignature},
19 error::IotaError,
20};
21use serde::Deserialize;
22use telemetry_subscribers::{TelemetryError, TracingHandle};
23use tokio::sync::oneshot;
24use tracing::info;
25
26use crate::IotaNode;
27
// Paths served by `run_admin_server`; each comment names the HTTP method(s)
// and the handler the path is routed to.

// --- logging / tracing -------------------------------------------------------

/// `GET` → `get_filter` (read the log filter), `POST` → `set_filter` (replace it).
const LOGGING_ROUTE: &str = "/logging";
/// `POST` → `enable_tracing`.
const TRACING_ROUTE: &str = "/enable-tracing";
/// `POST` → `reset_tracing`.
const TRACING_RESET_ROUTE: &str = "/reset-tracing";

// --- protocol upgrade / epoch control ---------------------------------------

/// `POST` → `set_override_protocol_upgrade_buffer_stake`.
const SET_BUFFER_STAKE_ROUTE: &str = "/set-override-buffer-stake";
/// `POST` → `clear_override_protocol_upgrade_buffer_stake`.
const CLEAR_BUFFER_STAKE_ROUTE: &str = "/clear-override-buffer-stake";
/// `POST` → `force_close_epoch`.
const FORCE_CLOSE_EPOCH: &str = "/force-close-epoch";

// --- inspection ---------------------------------------------------------------

/// `GET` → `capabilities`.
const CAPABILITIES: &str = "/capabilities";
/// `GET` → `node_config`.
const NODE_CONFIG: &str = "/node-config";

// --- randomness -------------------------------------------------------------

/// `GET` → `randomness_partial_sigs`.
const RANDOMNESS_PARTIAL_SIGS_ROUTE: &str = "/randomness-partial-sigs";
/// `POST` → `randomness_inject_partial_sigs`.
const RANDOMNESS_INJECT_PARTIAL_SIGS_ROUTE: &str = "/randomness-inject-partial-sigs";
/// `POST` → `randomness_inject_full_sig`.
const RANDOMNESS_INJECT_FULL_SIG_ROUTE: &str = "/randomness-inject-full-sig";

// --- profiling --------------------------------------------------------------

/// `GET` → `flamegraph`.
const FLAMEGRAPH_ROUTE: &str = "/flamegraph";
86
87struct AppState {
88 node: Arc<IotaNode>,
89 tracing_handle: TracingHandle,
90}
91
92pub async fn run_admin_server(
93 node: Arc<IotaNode>,
94 socket_address: SocketAddr,
95 tracing_handle: TracingHandle,
96) {
97 let filter = tracing_handle.get_log().unwrap();
98
99 let app_state = AppState {
100 node,
101 tracing_handle,
102 };
103
104 let app = Router::new()
105 .route(LOGGING_ROUTE, get(get_filter))
106 .route(CAPABILITIES, get(capabilities))
107 .route(NODE_CONFIG, get(node_config))
108 .route(LOGGING_ROUTE, post(set_filter))
109 .route(
110 SET_BUFFER_STAKE_ROUTE,
111 post(set_override_protocol_upgrade_buffer_stake),
112 )
113 .route(
114 CLEAR_BUFFER_STAKE_ROUTE,
115 post(clear_override_protocol_upgrade_buffer_stake),
116 )
117 .route(FORCE_CLOSE_EPOCH, post(force_close_epoch))
118 .route(TRACING_ROUTE, post(enable_tracing))
119 .route(TRACING_RESET_ROUTE, post(reset_tracing))
120 .route(RANDOMNESS_PARTIAL_SIGS_ROUTE, get(randomness_partial_sigs))
121 .route(
122 RANDOMNESS_INJECT_PARTIAL_SIGS_ROUTE,
123 post(randomness_inject_partial_sigs),
124 )
125 .route(
126 RANDOMNESS_INJECT_FULL_SIG_ROUTE,
127 post(randomness_inject_full_sig),
128 )
129 .route(FLAMEGRAPH_ROUTE, get(flamegraph))
130 .with_state(Arc::new(app_state));
131
132 info!(
133 filter =% filter,
134 address =% socket_address,
135 "starting admin server"
136 );
137
138 let listener = tokio::net::TcpListener::bind(&socket_address)
139 .await
140 .unwrap();
141 axum::serve(
142 listener,
143 app.into_make_service_with_connect_info::<SocketAddr>(),
144 )
145 .await
146 .unwrap();
147}
148
149#[derive(Deserialize)]
150struct EnableTracing {
151 filter: Option<String>,
153 duration: Option<String>,
154
155 trace_file: Option<String>,
157
158 sample_rate: Option<f64>,
160}
161
162async fn enable_tracing(
163 State(state): State<Arc<AppState>>,
164 query: Query<EnableTracing>,
165) -> (StatusCode, String) {
166 let Query(EnableTracing {
167 filter,
168 duration,
169 trace_file,
170 sample_rate,
171 }) = query;
172
173 let mut response = Vec::new();
174
175 if let Some(sample_rate) = sample_rate {
176 state.tracing_handle.update_sampling_rate(sample_rate);
177 response.push(format!("sample rate set to {sample_rate:?}"));
178 }
179
180 if let Some(trace_file) = trace_file {
181 if let Err(err) = state.tracing_handle.update_trace_file(&trace_file) {
182 response.push(format!("can't update trace file: {err:?}"));
183 return (StatusCode::BAD_REQUEST, response.join("\n"));
184 } else {
185 response.push(format!("trace file set to {trace_file:?}"));
186 }
187 }
188
189 let Some(filter) = filter else {
190 return (StatusCode::OK, response.join("\n"));
191 };
192
193 let Some(duration) = duration else {
195 response.push("can't update filter: missing duration".into());
196 return (StatusCode::BAD_REQUEST, response.join("\n"));
197 };
198
199 let Ok(duration) = parse_duration(&duration) else {
200 response.push("can't update filter: invalid duration".into());
201 return (StatusCode::BAD_REQUEST, response.join("\n"));
202 };
203
204 match state.tracing_handle.update_trace_filter(&filter, duration) {
205 Ok(()) => {
206 response.push(format!("filter set to {filter:?}"));
207 response.push(format!("filter will be reset after {duration:?}"));
208 (StatusCode::OK, response.join("\n"))
209 }
210 Err(TelemetryError::TracingDisabled) => {
211 response.push("can't update filter: tracing is not enabled. to enable it, run the node with 'TRACE_FILTER' set.".into());
212 (StatusCode::NOT_IMPLEMENTED, response.join("\n"))
213 }
214 Err(err) => {
215 response.push(format!("can't update filter: {err:?}"));
216 (StatusCode::BAD_REQUEST, response.join("\n"))
217 }
218 }
219}
220
221async fn reset_tracing(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
222 match state.tracing_handle.reset_trace() {
223 Ok(()) => (
224 StatusCode::OK,
225 "tracing filter reset to TRACE_FILTER env var".into(),
226 ),
227 Err(TelemetryError::TracingDisabled) => (
228 StatusCode::NOT_IMPLEMENTED,
229 "tracing is not enabled. to enable it, run the node with 'TRACE_FILTER' set.".into(),
230 ),
231 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
232 }
233}
234
235async fn get_filter(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
236 match state.tracing_handle.get_log() {
237 Ok(filter) => (StatusCode::OK, filter),
238 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
239 }
240}
241
242async fn set_filter(
243 State(state): State<Arc<AppState>>,
244 new_filter: String,
245) -> (StatusCode, String) {
246 match state.tracing_handle.update_log(&new_filter) {
247 Ok(()) => {
248 info!(filter =% new_filter, "Log filter updated");
249 (StatusCode::OK, "".into())
250 }
251 Err(err) => (StatusCode::BAD_REQUEST, err.to_string()),
252 }
253}
254
255async fn capabilities(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
256 let epoch_store = state.node.state().load_epoch_store_one_call_per_task();
257
258 let mut output = String::new();
259 let capabilities = epoch_store.get_capabilities_v1();
260 for capability in capabilities.unwrap_or_default() {
261 output.push_str(&format!("{capability:?}\n"));
262 }
263
264 (StatusCode::OK, output)
265}
266
267async fn node_config(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
268 let node_config = &state.node.config;
269
270 (StatusCode::OK, format!("{node_config:#?}\n"))
272}
273
274#[derive(Deserialize)]
275struct Epoch {
276 epoch: u64,
277}
278
279async fn clear_override_protocol_upgrade_buffer_stake(
280 State(state): State<Arc<AppState>>,
281 epoch: Query<Epoch>,
282) -> (StatusCode, String) {
283 let Query(Epoch { epoch }) = epoch;
284
285 match state
286 .node
287 .clear_override_protocol_upgrade_buffer_stake(epoch)
288 {
289 Ok(()) => (
290 StatusCode::OK,
291 "protocol upgrade buffer stake cleared\n".to_string(),
292 ),
293 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
294 }
295}
296
297#[derive(Deserialize)]
298struct SetBufferStake {
299 buffer_bps: u64,
300 epoch: u64,
301}
302
303async fn set_override_protocol_upgrade_buffer_stake(
304 State(state): State<Arc<AppState>>,
305 buffer_state: Query<SetBufferStake>,
306) -> (StatusCode, String) {
307 let Query(SetBufferStake { buffer_bps, epoch }) = buffer_state;
308
309 match state
310 .node
311 .set_override_protocol_upgrade_buffer_stake(epoch, buffer_bps)
312 {
313 Ok(()) => (
314 StatusCode::OK,
315 format!("protocol upgrade buffer stake set to '{buffer_bps}'\n"),
316 ),
317 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
318 }
319}
320
321async fn force_close_epoch(
322 State(state): State<Arc<AppState>>,
323 epoch: Query<Epoch>,
324) -> (StatusCode, String) {
325 let Query(Epoch {
326 epoch: expected_epoch,
327 }) = epoch;
328 let epoch_store = state.node.state().load_epoch_store_one_call_per_task();
329 let actual_epoch = epoch_store.epoch();
330 if actual_epoch != expected_epoch {
331 let err = IotaError::WrongEpoch {
332 expected_epoch,
333 actual_epoch,
334 };
335 return (StatusCode::INTERNAL_SERVER_ERROR, err.to_string());
336 }
337
338 match state.node.close_epoch(&epoch_store).await {
339 Ok(()) => (
340 StatusCode::OK,
341 "close_epoch() called successfully\n".to_string(),
342 ),
343 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
344 }
345}
346
347#[derive(Deserialize)]
348struct Round {
349 round: u64,
350}
351
352async fn randomness_partial_sigs(
353 State(state): State<Arc<AppState>>,
354 round: Query<Round>,
355) -> (StatusCode, String) {
356 let Query(Round { round }) = round;
357
358 let (tx, rx) = oneshot::channel();
359 state
360 .node
361 .randomness_handle()
362 .admin_get_partial_signatures(RandomnessRound::new(round), tx);
363
364 let sigs = match rx.await {
365 Ok(sigs) => sigs,
366 Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
367 };
368
369 let output = format!(
370 "{}\n",
371 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(sigs)
372 );
373
374 (StatusCode::OK, output)
375}
376
377#[derive(Deserialize)]
378struct PartialSigsToInject {
379 hex_authority_name: String,
380 round: u64,
381 base64_sigs: String,
382}
383
384async fn randomness_inject_partial_sigs(
385 State(state): State<Arc<AppState>>,
386 args: Query<PartialSigsToInject>,
387) -> (StatusCode, String) {
388 let Query(PartialSigsToInject {
389 hex_authority_name,
390 round,
391 base64_sigs,
392 }) = args;
393
394 let authority_name = match AuthorityName::from_str(hex_authority_name.as_str()) {
395 Ok(authority_name) => authority_name,
396 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
397 };
398
399 let sigs: Vec<u8> = match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(base64_sigs) {
400 Ok(sigs) => sigs,
401 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
402 };
403
404 let sigs: Vec<RandomnessPartialSignature> = match bcs::from_bytes(&sigs) {
405 Ok(sigs) => sigs,
406 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
407 };
408
409 let (tx_result, rx_result) = oneshot::channel();
410 state
411 .node
412 .randomness_handle()
413 .admin_inject_partial_signatures(
414 authority_name,
415 RandomnessRound::new(round),
416 sigs,
417 tx_result,
418 );
419
420 match rx_result.await {
421 Ok(Ok(())) => (StatusCode::OK, "partial signatures injected\n".to_string()),
422 Ok(Err(e)) => (StatusCode::BAD_REQUEST, e.to_string()),
423 Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
424 }
425}
426
427#[derive(Deserialize)]
428struct FullSigToInject {
429 round: u64,
430 base64_sig: String,
431}
432
433async fn randomness_inject_full_sig(
434 State(state): State<Arc<AppState>>,
435 args: Query<FullSigToInject>,
436) -> (StatusCode, String) {
437 let Query(FullSigToInject { round, base64_sig }) = args;
438
439 let sig: Vec<u8> = match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(base64_sig) {
440 Ok(sig) => sig,
441 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
442 };
443
444 let sig: RandomnessSignature = match bcs::from_bytes(&sig) {
445 Ok(sig) => sig,
446 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
447 };
448
449 let (tx_result, rx_result) = oneshot::channel();
450 state.node.randomness_handle().admin_inject_full_signature(
451 RandomnessRound::new(round),
452 sig,
453 tx_result,
454 );
455
456 match rx_result.await {
457 Ok(Ok(())) => (StatusCode::OK, "full signature injected\n".to_string()),
458 Ok(Err(e)) => (StatusCode::BAD_REQUEST, e.to_string()),
459 Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
460 }
461}
462
463#[derive(Deserialize)]
464struct Flamegraph {
465 #[serde(default)]
467 svg: bool,
468 #[serde(default)]
470 width: usize,
471 #[serde(default)]
473 running: bool,
474 #[serde(default)]
476 completed: bool,
477 #[serde(default)]
479 graph_id: String,
480 #[serde(default)]
482 mem: bool,
483}
484
485async fn flamegraph(State(state): State<Arc<AppState>>, query: Query<Flamegraph>) -> Response {
486 if let Some(sub) = state.tracing_handle.get_flamegraph() {
487 let Query(Flamegraph {
488 svg,
489 width,
490 mut running,
491 mut completed,
492 graph_id,
493 mem,
494 }) = query;
495 if !running && !completed {
496 running = true;
497 completed = true;
498 }
499 if svg {
500 #[cfg(not(all(feature = "flamegraph-alloc", nightly)))]
501 {
502 if mem {
503 return (
504 StatusCode::BAD_REQUEST,
505 "memory flamegraphs are not supported (re-run iota-node with 'flamegraph-alloc' feature enabled and on nightly Rust toolchain)",
506 )
507 .into_response();
508 }
509 }
510
511 let width = if width == 0 { Some(1920) } else { Some(width) };
513 let config = telemetry_subscribers::flamegraph::SvgConfig {
514 width,
515 #[cfg(all(feature = "flamegraph-alloc", nightly))]
516 measure_mem: mem,
517 ..Default::default()
518 };
519 let svg = if !graph_id.is_empty() {
520 sub.get_svg(&graph_id, running, completed, &config)
521 } else {
522 sub.get_combined_svg("iota-node", running, completed, &config)
523 };
524 if let Some(svg) = svg {
525 (
526 [(
527 axum::http::header::CONTENT_TYPE,
528 axum::http::header::HeaderValue::from_static("image/svg+xml"),
529 )],
530 svg.into_string(),
531 )
532 .into_response()
533 } else {
534 (StatusCode::NOT_FOUND, "Flamegraphs not found\n").into_response()
535 }
536 } else {
537 let nested_frames = if !graph_id.is_empty() {
539 sub.get_nested_set(&graph_id, running, completed)
540 } else {
541 sub.get_nested_sets("iota-node", running, completed)
542 };
543 if !nested_frames.is_empty() {
544 axum::Json(nested_frames).into_response()
545 } else {
546 (StatusCode::NOT_FOUND, "Flamegraphs not found\n").into_response()
547 }
548 }
549 } else {
550 (
551 StatusCode::NOT_FOUND,
552 "Flamegraphs are not enabled (re-run iota-node with TRACE_FLAMEGRAPH=1)\n",
553 )
554 .into_response()
555 }
556}