iota_open_rpc/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5extern crate core;
6
7use std::collections::{BTreeMap, HashMap, btree_map::Entry::Occupied};
8
9use schemars::{
10    JsonSchema,
11    gen::{SchemaGenerator, SchemaSettings},
12    schema::SchemaObject,
13};
14use serde::{Deserialize, Serialize};
15use serde_json::Value;
16use versions::Versioning;
17
18/// OPEN-RPC documentation following the OpenRPC specification <https://spec.open-rpc.org>
19/// The implementation is partial, only required fields and subset of optional
20/// fields in the specification are implemented catered to IOTA's need.
21#[derive(Serialize, Deserialize, Clone)]
22pub struct Project {
23    openrpc: String,
24    info: Info,
25    methods: Vec<Method>,
26    components: Components,
27    // Method routing for backward compatibility, not part of the open rpc spec.
28    #[serde(skip)]
29    pub method_routing: HashMap<String, MethodRouting>,
30}
31
32impl Project {
33    pub fn new(
34        version: &str,
35        title: &str,
36        description: &str,
37        contact_name: &str,
38        url: &str,
39        email: &str,
40        license: &str,
41        license_url: &str,
42    ) -> Self {
43        let openrpc = "1.2.6".to_string();
44        Self {
45            openrpc,
46            info: Info {
47                title: title.to_string(),
48                description: Some(description.to_string()),
49                contact: Some(Contact {
50                    name: contact_name.to_string(),
51                    url: Some(url.to_string()),
52                    email: Some(email.to_string()),
53                }),
54                license: Some(License {
55                    name: license.to_string(),
56                    url: Some(license_url.to_string()),
57                }),
58                version: version.to_owned(),
59                ..Default::default()
60            },
61            methods: vec![],
62            components: Components {
63                content_descriptors: Default::default(),
64                schemas: Default::default(),
65            },
66            method_routing: Default::default(),
67        }
68    }
69
70    pub fn add_module(&mut self, module: Module) {
71        self.methods.extend(module.methods);
72
73        self.methods.sort_by(|m, n| m.name.cmp(&n.name));
74
75        self.components.schemas.extend(module.components.schemas);
76        self.components
77            .content_descriptors
78            .extend(module.components.content_descriptors);
79        self.method_routing.extend(module.method_routing);
80    }
81
82    pub fn add_examples(&mut self, mut example_provider: BTreeMap<String, Vec<ExamplePairing>>) {
83        for method in &mut self.methods {
84            if let Occupied(entry) = example_provider.entry(method.name.clone()) {
85                let examples = entry.remove();
86                let param_names = method
87                    .params
88                    .iter()
89                    .map(|p| p.name.clone())
90                    .collect::<Vec<_>>();
91
92                // Make sure example's parameters are correct.
93                for example in examples.iter() {
94                    let example_param_names = example
95                        .params
96                        .iter()
97                        .map(|param| param.name.clone())
98                        .collect::<Vec<_>>();
99                    assert_eq!(
100                        param_names, example_param_names,
101                        "Provided example parameters doesn't match the function parameters."
102                    );
103                }
104
105                method.examples = examples
106            } else {
107                println!("No example found for method: {}", method.name);
108            }
109        }
110    }
111}
112
113pub struct Module {
114    methods: Vec<Method>,
115    components: Components,
116    method_routing: BTreeMap<String, MethodRouting>,
117}
118
119pub struct RpcModuleDocBuilder {
120    schema_generator: SchemaGenerator,
121    methods: BTreeMap<String, Method>,
122    method_routing: BTreeMap<String, MethodRouting>,
123    content_descriptors: BTreeMap<String, ContentDescriptor>,
124}
125
126#[derive(Serialize, Deserialize, Default, Clone)]
127pub struct ContentDescriptor {
128    name: String,
129    #[serde(skip_serializing_if = "Option::is_none")]
130    summary: Option<String>,
131    #[serde(skip_serializing_if = "Option::is_none")]
132    description: Option<String>,
133    #[serde(skip_serializing_if = "default")]
134    required: bool,
135    schema: SchemaObject,
136    #[serde(skip_serializing_if = "default")]
137    deprecated: bool,
138}
139
140#[derive(Serialize, Deserialize, Default, Clone)]
141struct Method {
142    name: String,
143    #[serde(skip_serializing_if = "Vec::is_empty")]
144    tags: Vec<Tag>,
145    #[serde(skip_serializing_if = "Option::is_none")]
146    description: Option<String>,
147    params: Vec<ContentDescriptor>,
148    #[serde(skip_serializing_if = "Option::is_none")]
149    result: Option<ContentDescriptor>,
150    #[serde(skip_serializing_if = "Vec::is_empty")]
151    examples: Vec<ExamplePairing>,
152    #[serde(skip_serializing_if = "std::ops::Not::not")]
153    deprecated: bool,
154}
155#[derive(Clone, Debug)]
156pub struct MethodRouting {
157    min: Option<Versioning>,
158    max: Option<Versioning>,
159    pub route_to: String,
160}
161
162impl MethodRouting {
163    pub fn le(version: &str, route_to: &str) -> Self {
164        Self {
165            min: None,
166            max: Some(Versioning::new(version).unwrap()),
167            route_to: route_to.to_string(),
168        }
169    }
170
171    pub fn eq(version: &str, route_to: &str) -> Self {
172        Self {
173            min: Some(Versioning::new(version).unwrap()),
174            max: Some(Versioning::new(version).unwrap()),
175            route_to: route_to.to_string(),
176        }
177    }
178
179    pub fn matches(&self, version: &str) -> bool {
180        let version = Versioning::new(version);
181        match (&version, &self.min, &self.max) {
182            (Some(version), None, Some(max)) => version <= max,
183            (Some(version), Some(min), None) => version >= min,
184            (Some(version), Some(min), Some(max)) => version >= min && version <= max,
185            (_, _, _) => false,
186        }
187    }
188}
189
190#[derive(Serialize, Deserialize, Default, Clone)]
191pub struct ExamplePairing {
192    name: String,
193    #[serde(skip_serializing_if = "Option::is_none")]
194    description: Option<String>,
195    #[serde(skip_serializing_if = "Option::is_none")]
196    summary: Option<String>,
197    params: Vec<Example>,
198    result: Example,
199}
200
201impl ExamplePairing {
202    pub fn new(name: &str, params: Vec<(&str, Value)>, result: Value) -> Self {
203        Self {
204            name: name.to_string(),
205            description: None,
206            summary: None,
207            params: params
208                .into_iter()
209                .map(|(name, value)| Example {
210                    name: name.to_string(),
211                    summary: None,
212                    description: None,
213                    value,
214                })
215                .collect(),
216            result: Example {
217                name: "Result".to_string(),
218                summary: None,
219                description: None,
220                value: result,
221            },
222        }
223    }
224}
225
226#[derive(Serialize, Deserialize, Default, Clone)]
227pub struct Example {
228    name: String,
229    #[serde(skip_serializing_if = "Option::is_none")]
230    summary: Option<String>,
231    #[serde(skip_serializing_if = "Option::is_none")]
232    description: Option<String>,
233    value: Value,
234}
235
236#[derive(Serialize, Deserialize, Default, Clone)]
237struct Tag {
238    name: String,
239    #[serde(skip_serializing_if = "Option::is_none")]
240    summary: Option<String>,
241    #[serde(skip_serializing_if = "Option::is_none")]
242    description: Option<String>,
243}
244
245impl Tag {
246    pub fn new(name: &str) -> Self {
247        Self {
248            name: name.to_string(),
249            summary: None,
250            description: None,
251        }
252    }
253}
254
255#[derive(Serialize, Deserialize, Default, Clone)]
256#[serde(rename_all = "camelCase")]
257struct Info {
258    title: String,
259    #[serde(skip_serializing_if = "Option::is_none")]
260    description: Option<String>,
261    #[serde(skip_serializing_if = "Option::is_none")]
262    terms_of_service: Option<String>,
263    #[serde(skip_serializing_if = "Option::is_none")]
264    contact: Option<Contact>,
265    #[serde(skip_serializing_if = "Option::is_none")]
266    license: Option<License>,
267    version: String,
268}
269
/// Returns `true` when `value` equals its type's `Default` value.
///
/// Used as a serde `skip_serializing_if` predicate so that default-valued fields
/// (such as `false` flags) are left out of the serialized output.
fn default<T: Default + PartialEq>(value: &T) -> bool {
    PartialEq::eq(value, &T::default())
}
276
277#[derive(Serialize, Deserialize, Default, Clone)]
278struct Contact {
279    name: String,
280    #[serde(skip_serializing_if = "Option::is_none")]
281    url: Option<String>,
282    #[serde(skip_serializing_if = "Option::is_none")]
283    email: Option<String>,
284}
285#[derive(Serialize, Deserialize, Default, Clone)]
286struct License {
287    name: String,
288    #[serde(skip_serializing_if = "Option::is_none")]
289    url: Option<String>,
290}
291
292impl Default for RpcModuleDocBuilder {
293    fn default() -> Self {
294        let schema_generator = SchemaSettings::default()
295            .with(|s| {
296                s.definitions_path = "#/components/schemas/".to_string();
297            })
298            .into_generator();
299
300        Self {
301            schema_generator,
302            methods: BTreeMap::new(),
303            method_routing: Default::default(),
304            content_descriptors: BTreeMap::new(),
305        }
306    }
307}
308
309impl RpcModuleDocBuilder {
310    pub fn build(mut self) -> Module {
311        Module {
312            methods: self.methods.into_values().collect(),
313            components: Components {
314                content_descriptors: self.content_descriptors,
315                schemas: self
316                    .schema_generator
317                    .root_schema_for::<u8>()
318                    .definitions
319                    .into_iter()
320                    .map(|(name, schema)| (name, schema.into_object()))
321                    .collect::<BTreeMap<_, _>>(),
322            },
323            method_routing: self.method_routing,
324        }
325    }
326
327    pub fn add_method_routing(
328        &mut self,
329        namespace: &str,
330        name: &str,
331        route_to: &str,
332        comparator: &str,
333        version: &str,
334    ) {
335        let name = format!("{namespace}_{name}");
336        let route_to = format!("{namespace}_{route_to}");
337        let routing = match comparator {
338            "<=" => MethodRouting::le(version, &route_to),
339            "=" => MethodRouting::eq(version, &route_to),
340            _ => panic!("Unsupported version comparator {comparator}"),
341        };
342        if self.method_routing.insert(name.clone(), routing).is_some() {
343            panic!("Routing for method [{name}] already exists.")
344        }
345    }
346
347    pub fn add_method(
348        &mut self,
349        namespace: &str,
350        name: &str,
351        params: Vec<ContentDescriptor>,
352        result: Option<ContentDescriptor>,
353        doc: &str,
354        tag: Option<String>,
355        deprecated: bool,
356    ) {
357        let tags = tag.map(|t| Tag::new(&t)).into_iter().collect::<Vec<_>>();
358        self.add_method_internal(namespace, name, params, result, doc, tags, deprecated)
359    }
360
361    pub fn add_subscription(
362        &mut self,
363        namespace: &str,
364        name: &str,
365        params: Vec<ContentDescriptor>,
366        result: Option<ContentDescriptor>,
367        doc: &str,
368        tag: Option<String>,
369        deprecated: bool,
370    ) {
371        let mut tags = tag.map(|t| Tag::new(&t)).into_iter().collect::<Vec<_>>();
372        tags.push(Tag::new("Websocket"));
373        tags.push(Tag::new("PubSub"));
374        self.add_method_internal(namespace, name, params, result, doc, tags, deprecated)
375    }
376
377    fn add_method_internal(
378        &mut self,
379        namespace: &str,
380        name: &str,
381        params: Vec<ContentDescriptor>,
382        result: Option<ContentDescriptor>,
383        doc: &str,
384        tags: Vec<Tag>,
385        deprecated: bool,
386    ) {
387        let description = if doc.trim().is_empty() {
388            None
389        } else {
390            Some(doc.trim().to_string())
391        };
392        let name = format!("{}_{}", namespace, name);
393        self.methods.insert(
394            name.clone(),
395            Method {
396                name,
397                description,
398                params,
399                result,
400                tags,
401                examples: Vec::new(),
402                deprecated,
403            },
404        );
405    }
406
407    pub fn create_content_descriptor<T: JsonSchema>(
408        &mut self,
409        name: &str,
410        summary: Option<String>,
411        description: Option<String>,
412        required: bool,
413    ) -> ContentDescriptor {
414        let schema = self.schema_generator.subschema_for::<T>().into_object();
415        ContentDescriptor {
416            name: name.replace(' ', ""),
417            summary,
418            description,
419            required,
420            schema,
421            deprecated: false,
422        }
423    }
424}
425
426#[derive(Serialize, Deserialize, Clone)]
427#[serde(rename_all = "camelCase")]
428struct Components {
429    #[serde(skip_serializing_if = "BTreeMap::is_empty")]
430    content_descriptors: BTreeMap<String, ContentDescriptor>,
431    #[serde(skip_serializing_if = "BTreeMap::is_empty")]
432    schemas: BTreeMap<String, SchemaObject>,
433}
434
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `routing.matches(version)` equals the expected flag for every case.
    fn check(routing: &MethodRouting, cases: &[(&str, bool)]) {
        for &(version, expected) in cases {
            assert_eq!(
                routing.matches(version),
                expected,
                "unexpected match result for version {version} (rule {routing:?})"
            );
        }
    }

    #[test]
    fn test_version_matching() {
        // `=` accepts only the exact version.
        check(
            &MethodRouting::eq("1.5", "test"),
            &[("1.5", true), ("1.6", false), ("1.4", false)],
        );

        // `<=` accepts the bound and anything older, but nothing newer.
        check(
            &MethodRouting::le("1.5", "test"),
            &[
                ("1.5", true),
                ("1.4.5", true),
                ("1.4", true),
                ("1.3", true),
                ("1.6", false),
                ("1.5.1", false),
            ],
        );
    }
}