iota_node/
admin.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{net::SocketAddr, str::FromStr, sync::Arc};
6
7use axum::{
8    Router,
9    extract::{Query, State},
10    http::StatusCode,
11    response::{IntoResponse as _, Response},
12    routing::{get, post},
13};
14use base64::Engine;
15use humantime::parse_duration;
16use iota_types::{
17    base_types::AuthorityName,
18    crypto::{RandomnessPartialSignature, RandomnessRound, RandomnessSignature},
19    error::IotaError,
20};
21use serde::Deserialize;
22use telemetry_subscribers::{TelemetryError, TracingHandle};
23use tokio::sync::oneshot;
24use tracing::info;
25
26use crate::IotaNode;
27
28// Example commands:
29//
30// Set buffer stake for current epoch 2 to 1500 basis points:
31//
32//   $ curl -X POST 'http://127.0.0.1:1337/set-override-buffer-stake?buffer_bps=1500&epoch=2'
33//
34// Clear buffer stake override for current epoch 2, use
35// ProtocolConfig::buffer_stake_for_protocol_upgrade_bps:
36//
37//   $ curl -X POST 'http://127.0.0.1:1337/clear-override-buffer-stake?epoch=2'
38//
39// Vote to close epoch 2 early
40//
41//   $ curl -X POST 'http://127.0.0.1:1337/force-close-epoch?epoch=2'
42//
// View all capabilities that this node has currently received from all
// authorities:
45//
46//   $ curl 'http://127.0.0.1:1337/capabilities'
47//
48// View the node config (private keys will be masked):
49//
50//   $ curl 'http://127.0.0.1:1337/node-config'
51//
52// Set a time-limited tracing config. After the duration expires, tracing will
53// be disabled automatically.
54//
55//   $ curl -X POST 'http://127.0.0.1:1337/enable-tracing?filter=info&duration=10s'
56//
57// Reset tracing to the TRACE_FILTER env var.
58//
59//   $ curl -X POST 'http://127.0.0.1:1337/reset-tracing'
60//
61// Get the node's randomness partial signatures for round 123.
62//
63//  $ curl 'http://127.0.0.1:1337/randomness-partial-sigs?round=123'
64//
// Inject a randomness partial signature from another node, bypassing validity
// checks.
//
//  $ curl -X POST 'http://127.0.0.1:1337/randomness-inject-partial-sigs?hex_authority_name=hexencodedname&round=123&base64_sigs=base64encodedsigs'
69//
// Inject a full signature from another node, bypassing validity checks.
//
//  $ curl -X POST 'http://127.0.0.1:1337/randomness-inject-full-sig?round=123&base64_sig=base64encodedsig'
73
// Paths served by the admin HTTP server, grouped by concern.

/// Read (GET) or replace (POST) the log filter.
const LOGGING_ROUTE: &str = "/logging";

/// Adjust tracing settings (filter, duration, trace file, sample rate).
const TRACING_ROUTE: &str = "/enable-tracing";
/// Reset tracing.
const TRACING_RESET_ROUTE: &str = "/reset-tracing";
/// Flamegraph export (SVG or nested set model).
const FLAMEGRAPH_ROUTE: &str = "/flamegraph";

/// Override the protocol upgrade buffer stake.
const SET_BUFFER_STAKE_ROUTE: &str = "/set-override-buffer-stake";
/// Clear the protocol upgrade buffer stake override.
const CLEAR_BUFFER_STAKE_ROUTE: &str = "/clear-override-buffer-stake";
/// Ask the node to close the given epoch.
const FORCE_CLOSE_EPOCH: &str = "/force-close-epoch";

/// Dump the authority capabilities known to this node.
const CAPABILITIES: &str = "/capabilities";
/// Dump the node configuration.
const NODE_CONFIG: &str = "/node-config";

/// Fetch randomness partial signatures for a round.
const RANDOMNESS_PARTIAL_SIGS_ROUTE: &str = "/randomness-partial-sigs";
/// Inject randomness partial signatures for a round.
const RANDOMNESS_INJECT_PARTIAL_SIGS_ROUTE: &str = "/randomness-inject-partial-sigs";
/// Inject a full randomness signature for a round.
const RANDOMNESS_INJECT_FULL_SIG_ROUTE: &str = "/randomness-inject-full-sig";
86
/// State shared by all admin route handlers (installed on the router behind
/// an `Arc`).
struct AppState {
    /// The node that the admin endpoints inspect and control.
    node: Arc<IotaNode>,
    /// Handle used by the logging, tracing and flamegraph endpoints.
    tracing_handle: TracingHandle,
}
91
92pub async fn run_admin_server(
93    node: Arc<IotaNode>,
94    socket_address: SocketAddr,
95    tracing_handle: TracingHandle,
96) {
97    let filter = tracing_handle.get_log().unwrap();
98
99    let app_state = AppState {
100        node,
101        tracing_handle,
102    };
103
104    let app = Router::new()
105        .route(LOGGING_ROUTE, get(get_filter))
106        .route(CAPABILITIES, get(capabilities))
107        .route(NODE_CONFIG, get(node_config))
108        .route(LOGGING_ROUTE, post(set_filter))
109        .route(
110            SET_BUFFER_STAKE_ROUTE,
111            post(set_override_protocol_upgrade_buffer_stake),
112        )
113        .route(
114            CLEAR_BUFFER_STAKE_ROUTE,
115            post(clear_override_protocol_upgrade_buffer_stake),
116        )
117        .route(FORCE_CLOSE_EPOCH, post(force_close_epoch))
118        .route(TRACING_ROUTE, post(enable_tracing))
119        .route(TRACING_RESET_ROUTE, post(reset_tracing))
120        .route(RANDOMNESS_PARTIAL_SIGS_ROUTE, get(randomness_partial_sigs))
121        .route(
122            RANDOMNESS_INJECT_PARTIAL_SIGS_ROUTE,
123            post(randomness_inject_partial_sigs),
124        )
125        .route(
126            RANDOMNESS_INJECT_FULL_SIG_ROUTE,
127            post(randomness_inject_full_sig),
128        )
129        .route(FLAMEGRAPH_ROUTE, get(flamegraph))
130        .with_state(Arc::new(app_state));
131
132    info!(
133        filter =% filter,
134        address =% socket_address,
135        "starting admin server"
136    );
137
138    let listener = tokio::net::TcpListener::bind(&socket_address)
139        .await
140        .unwrap();
141    axum::serve(
142        listener,
143        app.into_make_service_with_connect_info::<SocketAddr>(),
144    )
145    .await
146    .unwrap();
147}
148
/// Query parameters accepted by the enable-tracing endpoint. Every parameter
/// is optional.
#[derive(Deserialize)]
struct EnableTracing {
    // These params change the filter, and reset it after the duration expires.
    // `duration` is mandatory whenever `filter` is given and is parsed with
    // `humantime::parse_duration` (e.g. `10s`).
    filter: Option<String>,
    duration: Option<String>,

    // Change the trace output file (if file output was enabled at program start)
    trace_file: Option<String>,

    // Change the tracing sample rate
    sample_rate: Option<f64>,
}
161
162async fn enable_tracing(
163    State(state): State<Arc<AppState>>,
164    query: Query<EnableTracing>,
165) -> (StatusCode, String) {
166    let Query(EnableTracing {
167        filter,
168        duration,
169        trace_file,
170        sample_rate,
171    }) = query;
172
173    let mut response = Vec::new();
174
175    if let Some(sample_rate) = sample_rate {
176        state.tracing_handle.update_sampling_rate(sample_rate);
177        response.push(format!("sample rate set to {sample_rate:?}"));
178    }
179
180    if let Some(trace_file) = trace_file {
181        if let Err(err) = state.tracing_handle.update_trace_file(&trace_file) {
182            response.push(format!("can't update trace file: {err:?}"));
183            return (StatusCode::BAD_REQUEST, response.join("\n"));
184        } else {
185            response.push(format!("trace file set to {trace_file:?}"));
186        }
187    }
188
189    let Some(filter) = filter else {
190        return (StatusCode::OK, response.join("\n"));
191    };
192
193    // Duration is required if filter is set
194    let Some(duration) = duration else {
195        response.push("can't update filter: missing duration".into());
196        return (StatusCode::BAD_REQUEST, response.join("\n"));
197    };
198
199    let Ok(duration) = parse_duration(&duration) else {
200        response.push("can't update filter: invalid duration".into());
201        return (StatusCode::BAD_REQUEST, response.join("\n"));
202    };
203
204    match state.tracing_handle.update_trace_filter(&filter, duration) {
205        Ok(()) => {
206            response.push(format!("filter set to {filter:?}"));
207            response.push(format!("filter will be reset after {duration:?}"));
208            (StatusCode::OK, response.join("\n"))
209        }
210        Err(TelemetryError::TracingDisabled) => {
211            response.push("can't update filter: tracing is not enabled. to enable it, run the node with 'TRACE_FILTER' set.".into());
212            (StatusCode::NOT_IMPLEMENTED, response.join("\n"))
213        }
214        Err(err) => {
215            response.push(format!("can't update filter: {err:?}"));
216            (StatusCode::BAD_REQUEST, response.join("\n"))
217        }
218    }
219}
220
221async fn reset_tracing(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
222    match state.tracing_handle.reset_trace() {
223        Ok(()) => (
224            StatusCode::OK,
225            "tracing filter reset to TRACE_FILTER env var".into(),
226        ),
227        Err(TelemetryError::TracingDisabled) => (
228            StatusCode::NOT_IMPLEMENTED,
229            "tracing is not enabled. to enable it, run the node with 'TRACE_FILTER' set.".into(),
230        ),
231        Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
232    }
233}
234
235async fn get_filter(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
236    match state.tracing_handle.get_log() {
237        Ok(filter) => (StatusCode::OK, filter),
238        Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
239    }
240}
241
242async fn set_filter(
243    State(state): State<Arc<AppState>>,
244    new_filter: String,
245) -> (StatusCode, String) {
246    match state.tracing_handle.update_log(&new_filter) {
247        Ok(()) => {
248            info!(filter =% new_filter, "Log filter updated");
249            (StatusCode::OK, "".into())
250        }
251        Err(err) => (StatusCode::BAD_REQUEST, err.to_string()),
252    }
253}
254
255async fn capabilities(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
256    let epoch_store = state.node.state().load_epoch_store_one_call_per_task();
257
258    let mut output = String::new();
259    let capabilities = epoch_store.get_capabilities_v1();
260    for capability in capabilities.unwrap_or_default() {
261        output.push_str(&format!("{capability:?}\n"));
262    }
263
264    (StatusCode::OK, output)
265}
266
267async fn node_config(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
268    let node_config = &state.node.config;
269
270    // Note private keys will be masked
271    (StatusCode::OK, format!("{node_config:#?}\n"))
272}
273
/// Query parameter naming the epoch that an admin request refers to
/// (`?epoch=<u64>`).
#[derive(Deserialize)]
struct Epoch {
    epoch: u64,
}
278
279async fn clear_override_protocol_upgrade_buffer_stake(
280    State(state): State<Arc<AppState>>,
281    epoch: Query<Epoch>,
282) -> (StatusCode, String) {
283    let Query(Epoch { epoch }) = epoch;
284
285    match state
286        .node
287        .clear_override_protocol_upgrade_buffer_stake(epoch)
288    {
289        Ok(()) => (
290            StatusCode::OK,
291            "protocol upgrade buffer stake cleared\n".to_string(),
292        ),
293        Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
294    }
295}
296
/// Query parameters for overriding the protocol upgrade buffer stake.
#[derive(Deserialize)]
struct SetBufferStake {
    // New buffer stake, in basis points (e.g. `buffer_bps=1500`).
    buffer_bps: u64,
    // Epoch the override is requested for.
    epoch: u64,
}
302
303async fn set_override_protocol_upgrade_buffer_stake(
304    State(state): State<Arc<AppState>>,
305    buffer_state: Query<SetBufferStake>,
306) -> (StatusCode, String) {
307    let Query(SetBufferStake { buffer_bps, epoch }) = buffer_state;
308
309    match state
310        .node
311        .set_override_protocol_upgrade_buffer_stake(epoch, buffer_bps)
312    {
313        Ok(()) => (
314            StatusCode::OK,
315            format!("protocol upgrade buffer stake set to '{buffer_bps}'\n"),
316        ),
317        Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
318    }
319}
320
321async fn force_close_epoch(
322    State(state): State<Arc<AppState>>,
323    epoch: Query<Epoch>,
324) -> (StatusCode, String) {
325    let Query(Epoch {
326        epoch: expected_epoch,
327    }) = epoch;
328    let epoch_store = state.node.state().load_epoch_store_one_call_per_task();
329    let actual_epoch = epoch_store.epoch();
330    if actual_epoch != expected_epoch {
331        let err = IotaError::WrongEpoch {
332            expected_epoch,
333            actual_epoch,
334        };
335        return (StatusCode::INTERNAL_SERVER_ERROR, err.to_string());
336    }
337
338    match state.node.close_epoch(&epoch_store).await {
339        Ok(()) => (
340            StatusCode::OK,
341            "close_epoch() called successfully\n".to_string(),
342        ),
343        Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
344    }
345}
346
/// Query parameter naming a randomness round (`?round=<u64>`).
#[derive(Deserialize)]
struct Round {
    round: u64,
}
351
352async fn randomness_partial_sigs(
353    State(state): State<Arc<AppState>>,
354    round: Query<Round>,
355) -> (StatusCode, String) {
356    let Query(Round { round }) = round;
357
358    let (tx, rx) = oneshot::channel();
359    state
360        .node
361        .randomness_handle()
362        .admin_get_partial_signatures(RandomnessRound(round), tx);
363
364    let sigs = match rx.await {
365        Ok(sigs) => sigs,
366        Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
367    };
368
369    let output = format!(
370        "{}\n",
371        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(sigs)
372    );
373
374    (StatusCode::OK, output)
375}
376
/// Query parameters for injecting randomness partial signatures.
#[derive(Deserialize)]
struct PartialSigsToInject {
    // Authority the signatures are attributed to; parsed with
    // `AuthorityName::from_str`.
    hex_authority_name: String,
    // Randomness round the signatures belong to.
    round: u64,
    // URL-safe, unpadded base64 of the BCS-encoded
    // `Vec<RandomnessPartialSignature>`.
    base64_sigs: String,
}
383
384async fn randomness_inject_partial_sigs(
385    State(state): State<Arc<AppState>>,
386    args: Query<PartialSigsToInject>,
387) -> (StatusCode, String) {
388    let Query(PartialSigsToInject {
389        hex_authority_name,
390        round,
391        base64_sigs,
392    }) = args;
393
394    let authority_name = match AuthorityName::from_str(hex_authority_name.as_str()) {
395        Ok(authority_name) => authority_name,
396        Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
397    };
398
399    let sigs: Vec<u8> = match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(base64_sigs) {
400        Ok(sigs) => sigs,
401        Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
402    };
403
404    let sigs: Vec<RandomnessPartialSignature> = match bcs::from_bytes(&sigs) {
405        Ok(sigs) => sigs,
406        Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
407    };
408
409    let (tx_result, rx_result) = oneshot::channel();
410    state
411        .node
412        .randomness_handle()
413        .admin_inject_partial_signatures(authority_name, RandomnessRound(round), sigs, tx_result);
414
415    match rx_result.await {
416        Ok(Ok(())) => (StatusCode::OK, "partial signatures injected\n".to_string()),
417        Ok(Err(e)) => (StatusCode::BAD_REQUEST, e.to_string()),
418        Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
419    }
420}
421
/// Query parameters for injecting a full randomness signature.
#[derive(Deserialize)]
struct FullSigToInject {
    // Randomness round the signature belongs to.
    round: u64,
    // URL-safe, unpadded base64 of the BCS-encoded `RandomnessSignature`.
    base64_sig: String,
}
427
428async fn randomness_inject_full_sig(
429    State(state): State<Arc<AppState>>,
430    args: Query<FullSigToInject>,
431) -> (StatusCode, String) {
432    let Query(FullSigToInject { round, base64_sig }) = args;
433
434    let sig: Vec<u8> = match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(base64_sig) {
435        Ok(sig) => sig,
436        Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
437    };
438
439    let sig: RandomnessSignature = match bcs::from_bytes(&sig) {
440        Ok(sig) => sig,
441        Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
442    };
443
444    let (tx_result, rx_result) = oneshot::channel();
445    state.node.randomness_handle().admin_inject_full_signature(
446        RandomnessRound(round),
447        sig,
448        tx_result,
449    );
450
451    match rx_result.await {
452        Ok(Ok(())) => (StatusCode::OK, "full signature injected\n".to_string()),
453        Ok(Err(e)) => (StatusCode::BAD_REQUEST, e.to_string()),
454        Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
455    }
456}
457
/// Query parameters accepted by the flamegraph endpoint. Every parameter is
/// optional and falls back to its type's default when omitted.
#[derive(Deserialize)]
struct Flamegraph {
    /// Toggle SVG response, otherwise return nested set model for Grafana.
    #[serde(default)]
    svg: bool,
    /// SVG width in pixels (when missing or set to 0 will default to 1920).
    #[serde(default)]
    width: usize,
    /// Select still running call graphs. When neither `running` nor
    /// `completed` is set, both are selected.
    #[serde(default)]
    running: bool,
    /// Select already completed call graphs. When neither `running` nor
    /// `completed` is set, both are selected.
    #[serde(default)]
    completed: bool,
    /// Select call graph with the given ID. An empty value selects the
    /// combined "iota-node" graphs instead.
    #[serde(default)]
    graph_id: String,
    /// Use memory allocations as span measure rather than duration.
    #[serde(default)]
    mem: bool,
}
479
480async fn flamegraph(State(state): State<Arc<AppState>>, query: Query<Flamegraph>) -> Response {
481    if let Some(sub) = state.tracing_handle.get_flamegraph() {
482        let Query(Flamegraph {
483            svg,
484            width,
485            mut running,
486            mut completed,
487            graph_id,
488            mem,
489        }) = query;
490        if !running && !completed {
491            running = true;
492            completed = true;
493        }
494        if svg {
495            #[cfg(not(all(feature = "flamegraph-alloc", nightly)))]
496            {
497                if mem {
498                    return (
499                        StatusCode::BAD_REQUEST,
500                        "memory flamegraphs are not supported (re-run iota-node with 'flamegraph-alloc' feature enabled and on nightly Rust toolchain)",
501                    )
502                        .into_response();
503                }
504            }
505
506            // draw an svg
507            let width = if width == 0 { Some(1920) } else { Some(width) };
508            let config = telemetry_subscribers::flamegraph::SvgConfig {
509                width,
510                #[cfg(all(feature = "flamegraph-alloc", nightly))]
511                measure_mem: mem,
512                ..Default::default()
513            };
514            let svg = if !graph_id.is_empty() {
515                sub.get_svg(&graph_id, running, completed, &config)
516            } else {
517                sub.get_combined_svg("iota-node", running, completed, &config)
518            };
519            if let Some(svg) = svg {
520                (
521                    [(
522                        axum::http::header::CONTENT_TYPE,
523                        axum::http::header::HeaderValue::from_static("image/svg+xml"),
524                    )],
525                    svg.into_string(),
526                )
527                    .into_response()
528            } else {
529                (StatusCode::NOT_FOUND, "Flamegraphs not found\n").into_response()
530            }
531        } else {
532            // default nested set model for grafana
533            let nested_frames = if !graph_id.is_empty() {
534                sub.get_nested_set(&graph_id, running, completed)
535            } else {
536                sub.get_nested_sets("iota-node", running, completed)
537            };
538            if !nested_frames.is_empty() {
539                axum::Json(nested_frames).into_response()
540            } else {
541                (StatusCode::NOT_FOUND, "Flamegraphs not found\n").into_response()
542            }
543        }
544    } else {
545        (
546            StatusCode::NOT_FOUND,
547            "Flamegraphs are not enabled (re-run iota-node with TRACE_FLAMEGRAPH=1)\n",
548        )
549            .into_response()
550    }
551}