1use std::{net::SocketAddr, str::FromStr, sync::Arc};
6
7use axum::{
8 Router,
9 extract::{Query, State},
10 http::StatusCode,
11 routing::{get, post},
12};
13use base64::Engine;
14use humantime::parse_duration;
15use iota_types::{
16 base_types::AuthorityName,
17 crypto::{RandomnessPartialSignature, RandomnessRound, RandomnessSignature},
18 error::IotaError,
19};
20use serde::Deserialize;
21use telemetry_subscribers::{TelemetryError, TracingHandle};
22use tokio::sync::oneshot;
23use tracing::info;
24
25use crate::IotaNode;
26
// Paths served by the admin server. Each one is bound to its handler in
// `run_admin_server`.

/// GET returns the current log filter (`get_filter`); POST updates it
/// (`set_filter`).
const LOGGING_ROUTE: &str = "/logging";
/// POST: applies tracing settings given as query parameters (`enable_tracing`).
const TRACING_ROUTE: &str = "/enable-tracing";
/// POST: resets tracing (`reset_tracing`).
const TRACING_RESET_ROUTE: &str = "/reset-tracing";
/// POST: sets the protocol upgrade buffer stake override for an epoch.
const SET_BUFFER_STAKE_ROUTE: &str = "/set-override-buffer-stake";
/// POST: clears the protocol upgrade buffer stake override for an epoch.
const CLEAR_BUFFER_STAKE_ROUTE: &str = "/clear-override-buffer-stake";
/// POST: calls `close_epoch` on the node if the given epoch is the current one.
const FORCE_CLOSE_EPOCH: &str = "/force-close-epoch";
/// GET: lists the capabilities returned by the current epoch store.
const CAPABILITIES: &str = "/capabilities";
/// GET: prints the node configuration.
const NODE_CONFIG: &str = "/node-config";
/// GET: returns the partial signatures for a randomness round, base64-encoded.
const RANDOMNESS_PARTIAL_SIGS_ROUTE: &str = "/randomness-partial-sigs";
/// POST: injects partial signatures for a randomness round.
const RANDOMNESS_INJECT_PARTIAL_SIGS_ROUTE: &str = "/randomness-inject-partial-sigs";
/// POST: injects a full signature for a randomness round.
const RANDOMNESS_INJECT_FULL_SIG_ROUTE: &str = "/randomness-inject-full-sig";
84
/// State shared by all admin handlers. `run_admin_server` wraps it in an
/// `Arc` and attaches it to the router with `with_state`.
struct AppState {
    /// Node that the handlers query and issue commands to.
    node: Arc<IotaNode>,
    /// Handle used by the logging and tracing handlers.
    tracing_handle: TracingHandle,
}
89
90pub async fn run_admin_server(
91 node: Arc<IotaNode>,
92 socket_address: SocketAddr,
93 tracing_handle: TracingHandle,
94) {
95 let filter = tracing_handle.get_log().unwrap();
96
97 let app_state = AppState {
98 node,
99 tracing_handle,
100 };
101
102 let app = Router::new()
103 .route(LOGGING_ROUTE, get(get_filter))
104 .route(CAPABILITIES, get(capabilities))
105 .route(NODE_CONFIG, get(node_config))
106 .route(LOGGING_ROUTE, post(set_filter))
107 .route(
108 SET_BUFFER_STAKE_ROUTE,
109 post(set_override_protocol_upgrade_buffer_stake),
110 )
111 .route(
112 CLEAR_BUFFER_STAKE_ROUTE,
113 post(clear_override_protocol_upgrade_buffer_stake),
114 )
115 .route(FORCE_CLOSE_EPOCH, post(force_close_epoch))
116 .route(TRACING_ROUTE, post(enable_tracing))
117 .route(TRACING_RESET_ROUTE, post(reset_tracing))
118 .route(RANDOMNESS_PARTIAL_SIGS_ROUTE, get(randomness_partial_sigs))
119 .route(
120 RANDOMNESS_INJECT_PARTIAL_SIGS_ROUTE,
121 post(randomness_inject_partial_sigs),
122 )
123 .route(
124 RANDOMNESS_INJECT_FULL_SIG_ROUTE,
125 post(randomness_inject_full_sig),
126 )
127 .with_state(Arc::new(app_state));
128
129 info!(
130 filter =% filter,
131 address =% socket_address,
132 "starting admin server"
133 );
134
135 let listener = tokio::net::TcpListener::bind(&socket_address)
136 .await
137 .unwrap();
138 axum::serve(
139 listener,
140 app.into_make_service_with_connect_info::<SocketAddr>(),
141 )
142 .await
143 .unwrap();
144}
145
/// Query parameters accepted by `enable_tracing`. Every field is optional.
#[derive(Deserialize)]
struct EnableTracing {
    /// Trace filter passed to `update_trace_filter`. When present, `duration`
    /// must be supplied as well.
    filter: Option<String>,
    /// Duration string parsed with `humantime::parse_duration`; the handler
    /// reports that the filter is reset after this time.
    duration: Option<String>,

    /// Value passed to `update_trace_file`.
    trace_file: Option<String>,

    /// Value passed to `update_sampling_rate`.
    sample_rate: Option<f64>,
}
158
159async fn enable_tracing(
160 State(state): State<Arc<AppState>>,
161 query: Query<EnableTracing>,
162) -> (StatusCode, String) {
163 let Query(EnableTracing {
164 filter,
165 duration,
166 trace_file,
167 sample_rate,
168 }) = query;
169
170 let mut response = Vec::new();
171
172 if let Some(sample_rate) = sample_rate {
173 state.tracing_handle.update_sampling_rate(sample_rate);
174 response.push(format!("sample rate set to {:?}", sample_rate));
175 }
176
177 if let Some(trace_file) = trace_file {
178 if let Err(err) = state.tracing_handle.update_trace_file(&trace_file) {
179 response.push(format!("can't update trace file: {:?}", err));
180 return (StatusCode::BAD_REQUEST, response.join("\n"));
181 } else {
182 response.push(format!("trace file set to {:?}", trace_file));
183 }
184 }
185
186 let Some(filter) = filter else {
187 return (StatusCode::OK, response.join("\n"));
188 };
189
190 let Some(duration) = duration else {
192 response.push("can't update filter: missing duration".into());
193 return (StatusCode::BAD_REQUEST, response.join("\n"));
194 };
195
196 let Ok(duration) = parse_duration(&duration) else {
197 response.push("can't update filter: invalid duration".into());
198 return (StatusCode::BAD_REQUEST, response.join("\n"));
199 };
200
201 match state.tracing_handle.update_trace_filter(&filter, duration) {
202 Ok(()) => {
203 response.push(format!("filter set to {:?}", filter));
204 response.push(format!("filter will be reset after {:?}", duration));
205 (StatusCode::OK, response.join("\n"))
206 }
207 Err(TelemetryError::TracingDisabled) => {
208 response.push("can't update filter: tracing is not enabled. to enable it, run the node with 'TRACE_FILTER' set.".into());
209 (StatusCode::NOT_IMPLEMENTED, response.join("\n"))
210 }
211 Err(err) => {
212 response.push(format!("can't update filter: {:?}", err));
213 (StatusCode::BAD_REQUEST, response.join("\n"))
214 }
215 }
216}
217
218async fn reset_tracing(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
219 match state.tracing_handle.reset_trace() {
220 Ok(()) => (
221 StatusCode::OK,
222 "tracing filter reset to TRACE_FILTER env var".into(),
223 ),
224 Err(TelemetryError::TracingDisabled) => (
225 StatusCode::NOT_IMPLEMENTED,
226 "tracing is not enabled. to enable it, run the node with 'TRACE_FILTER' set.".into(),
227 ),
228 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
229 }
230}
231
232async fn get_filter(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
233 match state.tracing_handle.get_log() {
234 Ok(filter) => (StatusCode::OK, filter),
235 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
236 }
237}
238
239async fn set_filter(
240 State(state): State<Arc<AppState>>,
241 new_filter: String,
242) -> (StatusCode, String) {
243 match state.tracing_handle.update_log(&new_filter) {
244 Ok(()) => {
245 info!(filter =% new_filter, "Log filter updated");
246 (StatusCode::OK, "".into())
247 }
248 Err(err) => (StatusCode::BAD_REQUEST, err.to_string()),
249 }
250}
251
252async fn capabilities(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
253 let epoch_store = state.node.state().load_epoch_store_one_call_per_task();
254
255 let mut output = String::new();
256 let capabilities = epoch_store.get_capabilities_v1();
257 for capability in capabilities.unwrap_or_default() {
258 output.push_str(&format!("{:?}\n", capability));
259 }
260
261 (StatusCode::OK, output)
262}
263
264async fn node_config(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
265 let node_config = &state.node.config;
266
267 (StatusCode::OK, format!("{:#?}\n", node_config))
269}
270
/// Query parameters for handlers that act on a single epoch
/// (`clear_override_protocol_upgrade_buffer_stake`, `force_close_epoch`).
#[derive(Deserialize)]
struct Epoch {
    /// Epoch number the request refers to.
    epoch: u64,
}
275
276async fn clear_override_protocol_upgrade_buffer_stake(
277 State(state): State<Arc<AppState>>,
278 epoch: Query<Epoch>,
279) -> (StatusCode, String) {
280 let Query(Epoch { epoch }) = epoch;
281
282 match state
283 .node
284 .clear_override_protocol_upgrade_buffer_stake(epoch)
285 {
286 Ok(()) => (
287 StatusCode::OK,
288 "protocol upgrade buffer stake cleared\n".to_string(),
289 ),
290 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
291 }
292}
293
/// Query parameters for `set_override_protocol_upgrade_buffer_stake`.
#[derive(Deserialize)]
struct SetBufferStake {
    /// Buffer stake value to set (presumably in basis points, per the `bps`
    /// suffix — confirm against the node API).
    buffer_bps: u64,
    /// Epoch number passed to the node along with the override.
    epoch: u64,
}
299
300async fn set_override_protocol_upgrade_buffer_stake(
301 State(state): State<Arc<AppState>>,
302 buffer_state: Query<SetBufferStake>,
303) -> (StatusCode, String) {
304 let Query(SetBufferStake { buffer_bps, epoch }) = buffer_state;
305
306 match state
307 .node
308 .set_override_protocol_upgrade_buffer_stake(epoch, buffer_bps)
309 {
310 Ok(()) => (
311 StatusCode::OK,
312 format!("protocol upgrade buffer stake set to '{}'\n", buffer_bps),
313 ),
314 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
315 }
316}
317
318async fn force_close_epoch(
319 State(state): State<Arc<AppState>>,
320 epoch: Query<Epoch>,
321) -> (StatusCode, String) {
322 let Query(Epoch {
323 epoch: expected_epoch,
324 }) = epoch;
325 let epoch_store = state.node.state().load_epoch_store_one_call_per_task();
326 let actual_epoch = epoch_store.epoch();
327 if actual_epoch != expected_epoch {
328 let err = IotaError::WrongEpoch {
329 expected_epoch,
330 actual_epoch,
331 };
332 return (StatusCode::INTERNAL_SERVER_ERROR, err.to_string());
333 }
334
335 match state.node.close_epoch(&epoch_store).await {
336 Ok(()) => (
337 StatusCode::OK,
338 "close_epoch() called successfully\n".to_string(),
339 ),
340 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
341 }
342}
343
/// Query parameters for `randomness_partial_sigs`.
#[derive(Deserialize)]
struct Round {
    /// Round number, wrapped in `RandomnessRound` by the handler.
    round: u64,
}
348
349async fn randomness_partial_sigs(
350 State(state): State<Arc<AppState>>,
351 round: Query<Round>,
352) -> (StatusCode, String) {
353 let Query(Round { round }) = round;
354
355 let (tx, rx) = oneshot::channel();
356 state
357 .node
358 .randomness_handle()
359 .admin_get_partial_signatures(RandomnessRound(round), tx);
360
361 let sigs = match rx.await {
362 Ok(sigs) => sigs,
363 Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
364 };
365
366 let output = format!(
367 "{}\n",
368 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(sigs)
369 );
370
371 (StatusCode::OK, output)
372}
373
/// Query parameters for `randomness_inject_partial_sigs`.
#[derive(Deserialize)]
struct PartialSigsToInject {
    /// Authority name, parsed with `AuthorityName::from_str`.
    hex_authority_name: String,
    /// Round number, wrapped in `RandomnessRound` by the handler.
    round: u64,
    /// URL-safe, unpadded base64 of the BCS-encoded
    /// `Vec<RandomnessPartialSignature>`.
    base64_sigs: String,
}
380
381async fn randomness_inject_partial_sigs(
382 State(state): State<Arc<AppState>>,
383 args: Query<PartialSigsToInject>,
384) -> (StatusCode, String) {
385 let Query(PartialSigsToInject {
386 hex_authority_name,
387 round,
388 base64_sigs,
389 }) = args;
390
391 let authority_name = match AuthorityName::from_str(hex_authority_name.as_str()) {
392 Ok(authority_name) => authority_name,
393 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
394 };
395
396 let sigs: Vec<u8> = match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(base64_sigs) {
397 Ok(sigs) => sigs,
398 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
399 };
400
401 let sigs: Vec<RandomnessPartialSignature> = match bcs::from_bytes(&sigs) {
402 Ok(sigs) => sigs,
403 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
404 };
405
406 let (tx_result, rx_result) = oneshot::channel();
407 state
408 .node
409 .randomness_handle()
410 .admin_inject_partial_signatures(authority_name, RandomnessRound(round), sigs, tx_result);
411
412 match rx_result.await {
413 Ok(Ok(())) => (StatusCode::OK, "partial signatures injected\n".to_string()),
414 Ok(Err(e)) => (StatusCode::BAD_REQUEST, e.to_string()),
415 Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
416 }
417}
418
/// Query parameters for `randomness_inject_full_sig`.
#[derive(Deserialize)]
struct FullSigToInject {
    /// Round number, wrapped in `RandomnessRound` by the handler.
    round: u64,
    /// URL-safe, unpadded base64 of the BCS-encoded `RandomnessSignature`.
    base64_sig: String,
}
424
425async fn randomness_inject_full_sig(
426 State(state): State<Arc<AppState>>,
427 args: Query<FullSigToInject>,
428) -> (StatusCode, String) {
429 let Query(FullSigToInject { round, base64_sig }) = args;
430
431 let sig: Vec<u8> = match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(base64_sig) {
432 Ok(sig) => sig,
433 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
434 };
435
436 let sig: RandomnessSignature = match bcs::from_bytes(&sig) {
437 Ok(sig) => sig,
438 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
439 };
440
441 let (tx_result, rx_result) = oneshot::channel();
442 state.node.randomness_handle().admin_inject_full_signature(
443 RandomnessRound(round),
444 sig,
445 tx_result,
446 );
447
448 match rx_result.await {
449 Ok(Ok(())) => (StatusCode::OK, "full signature injected\n".to_string()),
450 Ok(Err(e)) => (StatusCode::BAD_REQUEST, e.to_string()),
451 Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
452 }
453}