iota_graphql_config/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::collections::BTreeSet;
6
7use proc_macro::TokenStream;
8use quote::{format_ident, quote};
9use syn::{
10    Attribute, Data, DataStruct, DeriveInput, Fields, FieldsNamed, Ident, Meta, NestedMeta,
11    parse_macro_input,
12};
13
14/// Attribute macro to be applied to config-based structs. It ensures that the
15/// struct derives serde traits, and `Debug`, that all fields are renamed with
16/// "kebab case", and adds a `#[serde(default = ...)]` implementation for each
17/// field that ensures that if the field is not present during deserialization,
18/// it is replaced with its default value, from the `Default` implementation for
19/// the config struct.
20#[expect(non_snake_case)]
21#[proc_macro_attribute]
22pub fn GraphQLConfig(_attr: TokenStream, input: TokenStream) -> TokenStream {
23    let DeriveInput {
24        attrs,
25        vis,
26        ident,
27        generics,
28        data,
29    } = parse_macro_input!(input as DeriveInput);
30
31    let Data::Struct(DataStruct {
32        struct_token,
33        fields,
34        semi_token,
35    }) = data
36    else {
37        panic!("GraphQL configs must be structs.");
38    };
39
40    let Fields::Named(FieldsNamed {
41        brace_token: _,
42        named,
43    }) = fields
44    else {
45        panic!("GraphQL configs must have named fields.");
46    };
47
48    // Figure out which derives need to be added to meet the criteria of a config
49    // struct.
50    let core_derives = core_derives(&attrs);
51
52    // Extract field names once to avoid having to check for their existence
53    // multiple times.
54    let fields_with_names: Vec<_> = named
55        .iter()
56        .map(|field| {
57            let Some(ident) = &field.ident else {
58                panic!("All fields must have an identifier.");
59            };
60
61            (ident, field)
62        })
63        .collect();
64
65    // Generate the fields with the `#[serde(default = ...)]` attribute.
66    let fields = fields_with_names.iter().map(|(name, field)| {
67        let default = format!("{ident}::__default_{name}");
68        quote! { #[serde(default = #default)] #field }
69    });
70
71    // Generate the default implementations for each field.
72    let defaults = fields_with_names.iter().map(|(name, field)| {
73        let ty = &field.ty;
74        let fn_name = format_ident!("__default_{}", name);
75        let cfg = extract_cfg(&field.attrs);
76
77        quote! {
78            #[doc(hidden)] #cfg
79            fn #fn_name() -> #ty {
80                Self::default().#name
81            }
82        }
83    });
84
85    TokenStream::from(quote! {
86        #[derive(#(#core_derives),*)]
87        #[serde(rename_all = "kebab-case")]
88        #(#attrs)* #vis #struct_token #ident #generics {
89            #(#fields),*
90        } #semi_token
91
92        impl #ident {
93            #(#defaults)*
94        }
95    })
96}
97
98/// Return a set of derives that should be added to the struct to make sure it
99/// derives all the things we expect from a config, namely `Serialize`,
100/// `Deserialize`, and `Debug`.
101///
102/// We cannot add core derives unconditionally, because they will conflict with
103/// existing ones.
104fn core_derives(attrs: &[Attribute]) -> BTreeSet<Ident> {
105    let mut derives = BTreeSet::from_iter([
106        format_ident!("Serialize"),
107        format_ident!("Deserialize"),
108        format_ident!("Debug"),
109        format_ident!("Clone"),
110        format_ident!("Eq"),
111        format_ident!("PartialEq"),
112    ]);
113
114    for attr in attrs {
115        let Ok(Meta::List(list)) = attr.parse_meta() else {
116            continue;
117        };
118
119        let Some(ident) = list.path.get_ident() else {
120            continue;
121        };
122
123        if ident != "derive" {
124            continue;
125        }
126
127        for nested in list.nested {
128            let NestedMeta::Meta(Meta::Path(path)) = nested else {
129                continue;
130            };
131
132            let Some(ident) = path.get_ident() else {
133                continue;
134            };
135
136            derives.remove(ident);
137        }
138    }
139
140    derives
141}
142
143/// Find the attribute that corresponds to a `#[cfg(...)]` annotation, if it
144/// exists.
145fn extract_cfg(attrs: &[Attribute]) -> Option<&Attribute> {
146    attrs.iter().find(|attr| {
147        let meta = attr.parse_meta().ok();
148        meta.is_some_and(|m| m.path().is_ident("cfg"))
149    })
150}