1use std::collections::BTreeMap;
6
7use move_core_types::{
8 annotated_visitor::{self, StructDriver, Traversal, ValueDriver},
9 language_storage::{StructTag, TypeTag},
10};
11
12use crate::balance::Balance;
13
14#[derive(Default)]
16pub(crate) struct BalanceTraversal {
17 balances: BTreeMap<TypeTag, u64>,
18}
19
/// Helper traversal that adds up the `u64` values it is handed. `BalanceTraversal` drives one
/// of these over the fields of each balance struct to read the amount it holds.
#[derive(Default)]
struct Accumulator {
    /// Sum of the `u64` values seen so far (zero for a fresh accumulator).
    total: u64,
}
26
27impl BalanceTraversal {
28 pub(crate) fn finish(self) -> BTreeMap<TypeTag, u64> {
30 self.balances
31 }
32}
33
34impl<'b, 'l> Traversal<'b, 'l> for BalanceTraversal {
35 type Error = annotated_visitor::Error;
36
37 fn traverse_struct(
38 &mut self,
39 driver: &mut StructDriver<'_, 'b, 'l>,
40 ) -> Result<(), Self::Error> {
41 let Some(coin_type) = is_balance(&driver.struct_layout().type_) else {
42 while driver.next_field(self)?.is_some() {}
44 return Ok(());
45 };
46
47 let mut acc = Accumulator::default();
48 while driver.next_field(&mut acc)?.is_some() {}
49 *self.balances.entry(coin_type).or_default() += acc.total;
50 Ok(())
51 }
52}
53
54impl<'b, 'l> Traversal<'b, 'l> for Accumulator {
55 type Error = annotated_visitor::Error;
56 fn traverse_u64(
57 &mut self,
58 _driver: &ValueDriver<'_, 'b, 'l>,
59 value: u64,
60 ) -> Result<(), Self::Error> {
61 self.total += value;
62 Ok(())
63 }
64}
65
66fn is_balance(s: &StructTag) -> Option<TypeTag> {
69 (Balance::is_balance(s) && s.type_params.len() == 1).then(|| s.type_params[0].clone())
70}
71
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use move_core_types::{
        account_address::AccountAddress, annotated_value as A, identifier::Identifier,
        language_storage::StructTag,
    };

    use super::*;
    use crate::id::UID;

    /// A bare balance contributes its value under its type parameter.
    #[test]
    fn test_traverse_balance() {
        let layout = bal_t("0x42::foo::Bar");
        let value = bal_v("0x42::foo::Bar", 42);

        let balances = balances_of(&layout, value);

        assert_eq!(balances, BTreeMap::from([(type_("0x42::foo::Bar"), 42)]));
    }

    /// A coin is not itself a balance, but the balance nested inside it is counted.
    #[test]
    fn test_traverse_coin() {
        let layout = coin_t("0x42::foo::Bar");
        let value = coin_v("0x42::foo::Bar", "0x101", 42);

        let balances = balances_of(&layout, value);

        assert_eq!(balances, BTreeMap::from([(type_("0x42::foo::Bar"), 42)]));
    }

    /// Balances found in struct fields and in vector elements are summed per type.
    #[test]
    fn test_traverse_nested() {
        use A::MoveTypeLayout as T;

        let layout = layout_(
            "0xa::foo::Bar",
            vec![
                ("b", bal_t("0x42::baz::Qux")),
                ("c", coin_t("0x42::baz::Qux")),
                ("d", T::Vector(Box::new(coin_t("0x42::quy::Frob")))),
            ],
        );

        let value = value_(
            "0xa::foo::Bar",
            vec![
                ("b", bal_v("0x42::baz::Qux", 42)),
                ("c", coin_v("0x42::baz::Qux", "0x101", 43)),
                (
                    "d",
                    A::MoveValue::Vector(vec![
                        coin_v("0x42::quy::Frob", "0x102", 44),
                        coin_v("0x42::quy::Frob", "0x103", 45),
                    ]),
                ),
            ],
        );

        let balances = balances_of(&layout, value);

        assert_eq!(
            balances,
            BTreeMap::from([
                (type_("0x42::baz::Qux"), 42 + 43),
                (type_("0x42::quy::Frob"), 44 + 45),
            ])
        );
    }

    /// A top-level `u64` that is not inside any balance struct is not counted.
    #[test]
    fn test_traverse_primitive() {
        use A::MoveTypeLayout as T;

        let layout = T::U64;
        let value = A::MoveValue::U64(42);

        let balances = balances_of(&layout, value);

        assert_eq!(balances, BTreeMap::new());
    }

    /// A struct named `balance::Balance` at address `0x3` must not be counted: only the real
    /// balance (42) and the coin's balance (43) show up in the total, not the fake one (44).
    #[test]
    fn test_traverse_fake_balance() {
        use A::MoveTypeLayout as T;

        let layout = layout_(
            "0xa::foo::Bar",
            vec![
                ("b", bal_t("0x42::baz::Qux")),
                ("c", coin_t("0x42::baz::Qux")),
                (
                    "d",
                    layout_(
                        "0x3::balance::Balance<0x42::baz::Qux>",
                        vec![("value", T::U64)],
                    ),
                ),
            ],
        );

        let value = value_(
            "0xa::foo::Bar",
            vec![
                ("b", bal_v("0x42::baz::Qux", 42)),
                ("c", coin_v("0x42::baz::Qux", "0x101", 43)),
                (
                    "d",
                    value_(
                        "0x3::balance::Balance<0x42::baz::Qux>",
                        vec![("value", A::MoveValue::U64(44))],
                    ),
                ),
            ],
        );

        let balances = balances_of(&layout, value);

        assert_eq!(
            balances,
            BTreeMap::from([(type_("0x42::baz::Qux"), 42 + 43),])
        );
    }

    /// Serializes `value`, runs a fresh `BalanceTraversal` over the bytes using `layout`, and
    /// returns the totals it gathered. Panics if the traversal reports an error.
    fn balances_of(layout: &A::MoveTypeLayout, value: A::MoveValue) -> BTreeMap<TypeTag, u64> {
        let bytes = serialize(value);

        let mut visitor = BalanceTraversal::default();
        A::MoveValue::visit_deserialize(&bytes, layout, &mut visitor).unwrap();
        visitor.finish()
    }

    /// Builds a `0x2::object::UID` value wrapping a `0x2::object::ID` whose `bytes` field is the
    /// address parsed from `addr`.
    fn uid_(addr: &str) -> A::MoveValue {
        value_(
            "0x2::object::UID",
            vec![(
                "id",
                value_(
                    "0x2::object::ID",
                    vec![(
                        "bytes",
                        A::MoveValue::Address(AccountAddress::from_str(addr).unwrap()),
                    )],
                ),
            )],
        )
    }

    /// Builds a `0x2::balance::Balance<tag>` value holding `value`.
    fn bal_v(tag: &str, value: u64) -> A::MoveValue {
        value_(
            &format!("0x2::balance::Balance<{tag}>"),
            vec![("value", A::MoveValue::U64(value))],
        )
    }

    /// Builds a `0x2::coin::Coin<tag>` value with the given object id and balance.
    fn coin_v(tag: &str, id: &str, value: u64) -> A::MoveValue {
        value_(
            &format!("0x2::coin::Coin<{tag}>"),
            vec![("id", uid_(id)), ("balance", bal_v(tag, value))],
        )
    }

    /// Layout of `0x2::balance::Balance<tag>`: a single `u64` field named `value`.
    fn bal_t(tag: &str) -> A::MoveTypeLayout {
        layout_(
            &format!("0x2::balance::Balance<{tag}>"),
            vec![("value", A::MoveTypeLayout::U64)],
        )
    }

    /// Layout of `0x2::coin::Coin<tag>`: a `UID` field `id` followed by a `balance` field.
    fn coin_t(tag: &str) -> A::MoveTypeLayout {
        layout_(
            &format!("0x2::coin::Coin<{tag}>"),
            vec![
                ("id", A::MoveTypeLayout::Struct(Box::new(UID::layout()))),
                ("balance", bal_t(tag)),
            ],
        )
    }

    /// Builds a struct value whose type is parsed from `rep` and whose fields are the given
    /// name/value pairs. Panics if `rep` or a field name fails to parse.
    fn value_(rep: &str, fields: Vec<(&str, A::MoveValue)>) -> A::MoveValue {
        let type_ = StructTag::from_str(rep).unwrap();
        let fields = fields
            .into_iter()
            .map(|(name, value)| (Identifier::new(name).unwrap(), value))
            .collect();

        A::MoveValue::Struct(A::MoveStruct::new(type_, fields))
    }

    /// Parses `rep` as a type tag, panicking if it is malformed.
    fn type_(rep: &str) -> TypeTag {
        TypeTag::from_str(rep).unwrap()
    }

    /// Builds a struct layout whose type is parsed from `rep` and whose fields are the given
    /// name/layout pairs. Panics if `rep` or a field name fails to parse.
    fn layout_(rep: &str, fields: Vec<(&str, A::MoveTypeLayout)>) -> A::MoveTypeLayout {
        let type_ = StructTag::from_str(rep).unwrap();
        let fields = fields
            .into_iter()
            .map(|(name, layout)| A::MoveFieldLayout::new(Identifier::new(name).unwrap(), layout))
            .collect();

        A::MoveTypeLayout::Struct(Box::new(A::MoveStructLayout { type_, fields }))
    }

    /// Strips the type annotations from `value` and serializes what remains.
    fn serialize(value: A::MoveValue) -> Vec<u8> {
        // `value` is owned and `undecorate` consumes it, so no clone is needed.
        value.undecorate().simple_serialize().unwrap()
    }
}