iota_types/object/
balance_traversal.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::collections::BTreeMap;
6
7use move_core_types::{
8    annotated_visitor::{self, StructDriver, Traversal, ValueDriver},
9    language_storage::{StructTag, TypeTag},
10};
11
12use crate::balance::Balance;
13
14/// Traversal to gather the total balances of all coin types visited.
15#[derive(Default)]
16pub(crate) struct BalanceTraversal {
17    balances: BTreeMap<TypeTag, u64>,
18}
19
/// Helper traversal that adds up every `u64` it visits. `BalanceTraversal`
/// drives it over the fields of a `Balance` struct to read the value it holds.
#[derive(Default)]
struct Accumulator {
    /// Sum of the `u64` values seen so far.
    total: u64,
}
26
27impl BalanceTraversal {
28    /// Consume the traversal to get at its balance mapping.
29    pub(crate) fn finish(self) -> BTreeMap<TypeTag, u64> {
30        self.balances
31    }
32}
33
34impl<'b, 'l> Traversal<'b, 'l> for BalanceTraversal {
35    type Error = annotated_visitor::Error;
36
37    fn traverse_struct(
38        &mut self,
39        driver: &mut StructDriver<'_, 'b, 'l>,
40    ) -> Result<(), Self::Error> {
41        let Some(coin_type) = is_balance(&driver.struct_layout().type_) else {
42            // Not a balance, search recursively for balances among fields.
43            while driver.next_field(self)?.is_some() {}
44            return Ok(());
45        };
46
47        let mut acc = Accumulator::default();
48        while driver.next_field(&mut acc)?.is_some() {}
49        *self.balances.entry(coin_type).or_default() += acc.total;
50        Ok(())
51    }
52}
53
54impl<'b, 'l> Traversal<'b, 'l> for Accumulator {
55    type Error = annotated_visitor::Error;
56    fn traverse_u64(
57        &mut self,
58        _driver: &ValueDriver<'_, 'b, 'l>,
59        value: u64,
60    ) -> Result<(), Self::Error> {
61        self.total += value;
62        Ok(())
63    }
64}
65
66/// Returns `Some(T)` if the struct is a `iota::balance::Balance<T>`, and `None`
67/// otherwise.
68fn is_balance(s: &StructTag) -> Option<TypeTag> {
69    (Balance::is_balance(s) && s.type_params.len() == 1).then(|| s.type_params[0].clone())
70}
71
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use move_core_types::{
        account_address::AccountAddress, annotated_value as A, identifier::Identifier,
        language_storage::StructTag,
    };

    use super::*;
    use crate::id::UID;

    #[test]
    fn test_traverse_balance() {
        let layout = bal_t("0x42::foo::Bar");
        let value = bal_v("0x42::foo::Bar", 42);

        let balances = traverse(&layout, value);

        assert_eq!(balances, BTreeMap::from([(type_("0x42::foo::Bar"), 42)]));
    }

    #[test]
    fn test_traverse_coin() {
        let layout = coin_t("0x42::foo::Bar");
        let value = coin_v("0x42::foo::Bar", "0x101", 42);

        let balances = traverse(&layout, value);

        assert_eq!(balances, BTreeMap::from([(type_("0x42::foo::Bar"), 42)]));
    }

    #[test]
    fn test_traverse_nested() {
        use A::MoveTypeLayout as T;

        let layout = layout_(
            "0xa::foo::Bar",
            vec![
                ("b", bal_t("0x42::baz::Qux")),
                ("c", coin_t("0x42::baz::Qux")),
                ("d", T::Vector(Box::new(coin_t("0x42::quy::Frob")))),
            ],
        );

        let value = value_(
            "0xa::foo::Bar",
            vec![
                ("b", bal_v("0x42::baz::Qux", 42)),
                ("c", coin_v("0x42::baz::Qux", "0x101", 43)),
                (
                    "d",
                    A::MoveValue::Vector(vec![
                        coin_v("0x42::quy::Frob", "0x102", 44),
                        coin_v("0x42::quy::Frob", "0x103", 45),
                    ]),
                ),
            ],
        );

        let balances = traverse(&layout, value);

        assert_eq!(
            balances,
            BTreeMap::from([
                (type_("0x42::baz::Qux"), 42 + 43),
                (type_("0x42::quy::Frob"), 44 + 45),
            ])
        );
    }

    #[test]
    fn test_traverse_primitive() {
        use A::MoveTypeLayout as T;

        let layout = T::U64;
        let value = A::MoveValue::U64(42);

        let balances = traverse(&layout, value);

        assert!(balances.is_empty());
    }

    #[test]
    fn test_traverse_fake_balance() {
        use A::MoveTypeLayout as T;

        let layout = layout_(
            "0xa::foo::Bar",
            vec![
                ("b", bal_t("0x42::baz::Qux")),
                ("c", coin_t("0x42::baz::Qux")),
                (
                    "d",
                    layout_(
                        // Fake balance
                        "0x3::balance::Balance<0x42::baz::Qux>",
                        vec![("value", T::U64)],
                    ),
                ),
            ],
        );

        let value = value_(
            "0xa::foo::Bar",
            vec![
                ("b", bal_v("0x42::baz::Qux", 42)),
                ("c", coin_v("0x42::baz::Qux", "0x101", 43)),
                (
                    "d",
                    value_(
                        "0x3::balance::Balance<0x42::baz::Qux>",
                        vec![("value", A::MoveValue::U64(44))],
                    ),
                ),
            ],
        );

        let balances = traverse(&layout, value);

        assert_eq!(
            balances,
            BTreeMap::from([(type_("0x42::baz::Qux"), 42 + 43)])
        );
    }

    /// BCS-encode `value`, deserialize it against `layout` with a fresh
    /// `BalanceTraversal`, and return the balances the traversal gathered.
    fn traverse(layout: &A::MoveTypeLayout, value: A::MoveValue) -> BTreeMap<TypeTag, u64> {
        let bytes = serialize(value);
        let mut visitor = BalanceTraversal::default();
        A::MoveValue::visit_deserialize(&bytes, layout, &mut visitor).unwrap();
        visitor.finish()
    }

    /// Create a UID Move Value for test purposes.
    fn uid_(addr: &str) -> A::MoveValue {
        value_(
            "0x2::object::UID",
            vec![(
                "id",
                value_(
                    "0x2::object::ID",
                    vec![(
                        "bytes",
                        A::MoveValue::Address(AccountAddress::from_str(addr).unwrap()),
                    )],
                ),
            )],
        )
    }

    /// Create a Balance value for testing purposes.
    fn bal_v(tag: &str, value: u64) -> A::MoveValue {
        value_(
            &format!("0x2::balance::Balance<{tag}>"),
            vec![("value", A::MoveValue::U64(value))],
        )
    }

    /// Create a Coin value for testing purposes.
    fn coin_v(tag: &str, id: &str, value: u64) -> A::MoveValue {
        value_(
            &format!("0x2::coin::Coin<{tag}>"),
            vec![("id", uid_(id)), ("balance", bal_v(tag, value))],
        )
    }

    /// Create a Balance layout for testing purposes.
    fn bal_t(tag: &str) -> A::MoveTypeLayout {
        layout_(
            &format!("0x2::balance::Balance<{tag}>"),
            vec![("value", A::MoveTypeLayout::U64)],
        )
    }

    /// Create a Coin layout for testing purposes.
    fn coin_t(tag: &str) -> A::MoveTypeLayout {
        layout_(
            &format!("0x2::coin::Coin<{tag}>"),
            vec![
                ("id", A::MoveTypeLayout::Struct(Box::new(UID::layout()))),
                ("balance", bal_t(tag)),
            ],
        )
    }

    /// Create a struct value for test purposes.
    fn value_(rep: &str, fields: Vec<(&str, A::MoveValue)>) -> A::MoveValue {
        let type_ = StructTag::from_str(rep).unwrap();
        let fields = fields
            .into_iter()
            .map(|(name, value)| (Identifier::new(name).unwrap(), value))
            .collect();

        A::MoveValue::Struct(A::MoveStruct::new(type_, fields))
    }

    /// Create a type tag for test purposes.
    fn type_(rep: &str) -> TypeTag {
        TypeTag::from_str(rep).unwrap()
    }

    /// Create a struct layout for test purposes.
    fn layout_(rep: &str, fields: Vec<(&str, A::MoveTypeLayout)>) -> A::MoveTypeLayout {
        let type_ = StructTag::from_str(rep).unwrap();
        let fields = fields
            .into_iter()
            .map(|(name, layout)| A::MoveFieldLayout::new(Identifier::new(name).unwrap(), layout))
            .collect();

        A::MoveTypeLayout::Struct(Box::new(A::MoveStructLayout { type_, fields }))
    }

    /// BCS encode Move value. Takes ownership, so no clone is needed before
    /// stripping the annotations.
    fn serialize(value: A::MoveValue) -> Vec<u8> {
        value.undecorate().simple_serialize().unwrap()
    }
}