iota_json_rpc/balance_changes.rs

// Copyright (c) Mysten Labs, Inc.
// Modifications Copyright (c) 2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

5use std::{
6    collections::{BTreeMap, HashMap, HashSet},
7    ops::Neg,
8};
9
10use async_trait::async_trait;
11use iota_json_rpc_types::BalanceChange;
12use iota_types::{
13    base_types::{ObjectID, ObjectRef, SequenceNumber},
14    coin::Coin,
15    digests::ObjectDigest,
16    effects::{TransactionEffects, TransactionEffectsAPI},
17    execution_status::ExecutionStatus,
18    gas_coin::GAS,
19    object::{Object, Owner},
20    storage::WriteKind,
21    transaction::InputObjectKind,
22};
23use move_core_types::language_storage::TypeTag;
24use tokio::sync::RwLock;
25
26pub async fn get_balance_changes_from_effect<P: ObjectProvider<Error = E>, E>(
27    object_provider: &P,
28    effects: &TransactionEffects,
29    input_objs: Vec<InputObjectKind>,
30    mocked_coin: Option<ObjectID>,
31) -> Result<Vec<BalanceChange>, E> {
32    let (_, gas_owner) = effects.gas_object();
33
34    // Only charge gas when tx fails, skip all object parsing
35    if effects.status() != &ExecutionStatus::Success {
36        return Ok(vec![BalanceChange {
37            owner: gas_owner,
38            coin_type: GAS::type_tag(),
39            amount: effects.gas_cost_summary().net_gas_usage().neg() as i128,
40        }]);
41    }
42
43    let all_mutated = effects
44        .all_changed_objects()
45        .into_iter()
46        .filter_map(|((id, version, digest), _, _)| {
47            if matches!(mocked_coin, Some(coin) if id == coin) {
48                return None;
49            }
50            Some((id, version, Some(digest)))
51        })
52        .collect::<Vec<_>>();
53
54    let input_objs_to_digest = input_objs
55        .iter()
56        .filter_map(|k| match k {
57            InputObjectKind::ImmOrOwnedMoveObject(o) => Some((o.0, o.2)),
58            InputObjectKind::MovePackage(_) | InputObjectKind::SharedMoveObject { .. } => None,
59        })
60        .collect::<HashMap<ObjectID, ObjectDigest>>();
61    let unwrapped_then_deleted = effects
62        .unwrapped_then_deleted()
63        .iter()
64        .map(|e| e.0)
65        .collect::<HashSet<_>>();
66    get_balance_changes(
67        object_provider,
68        &effects
69            .modified_at_versions()
70            .into_iter()
71            .filter_map(|(id, version)| {
72                if matches!(mocked_coin, Some(coin) if id == coin) {
73                    return None;
74                }
75                // We won't be able to get dynamic object from object provider today
76                if unwrapped_then_deleted.contains(&id) {
77                    return None;
78                }
79                Some((id, version, input_objs_to_digest.get(&id).cloned()))
80            })
81            .collect::<Vec<_>>(),
82        &all_mutated,
83    )
84    .await
85}
86
87pub async fn get_balance_changes<P: ObjectProvider<Error = E>, E>(
88    object_provider: &P,
89    modified_at_version: &[(ObjectID, SequenceNumber, Option<ObjectDigest>)],
90    all_mutated: &[(ObjectID, SequenceNumber, Option<ObjectDigest>)],
91) -> Result<Vec<BalanceChange>, E> {
92    // 1. subtract all input coins
93    let balances = fetch_coins(object_provider, modified_at_version)
94        .await?
95        .into_iter()
96        .fold(
97            BTreeMap::<_, i128>::new(),
98            |mut acc, (owner, type_, amount)| {
99                *acc.entry((owner, type_)).or_default() -= amount as i128;
100                acc
101            },
102        );
103    // 2. add all mutated coins
104    let balances = fetch_coins(object_provider, all_mutated)
105        .await?
106        .into_iter()
107        .fold(balances, |mut acc, (owner, type_, amount)| {
108            *acc.entry((owner, type_)).or_default() += amount as i128;
109            acc
110        });
111
112    Ok(balances
113        .into_iter()
114        .filter_map(|((owner, coin_type), amount)| {
115            if amount == 0 {
116                return None;
117            }
118            Some(BalanceChange {
119                owner,
120                coin_type,
121                amount,
122            })
123        })
124        .collect())
125}
126
127async fn fetch_coins<P: ObjectProvider<Error = E>, E>(
128    object_provider: &P,
129    objects: &[(ObjectID, SequenceNumber, Option<ObjectDigest>)],
130) -> Result<Vec<(Owner, TypeTag, u64)>, E> {
131    let mut all_mutated_coins = vec![];
132    for (id, version, digest_opt) in objects {
133        // TODO: use multi get object
134        let o = object_provider.get_object(id, version).await?;
135        if let Some(type_) = o.type_() {
136            if type_.is_coin() {
137                if let Some(digest) = digest_opt {
138                    // TODO: can we return Err here instead?
139                    assert_eq!(
140                        *digest,
141                        o.digest(),
142                        "Object digest mismatch--got bad data from object_provider?"
143                    )
144                }
145                let [coin_type]: [TypeTag; 1] =
146                    type_.clone().into_type_params().try_into().unwrap();
147                all_mutated_coins.push((
148                    o.owner,
149                    coin_type,
150                    // we know this is a coin, safe to unwrap
151                    Coin::extract_balance_if_coin(&o).unwrap().unwrap(),
152                ))
153            }
154        }
155    }
156    Ok(all_mutated_coins)
157}
158
159#[async_trait]
160pub trait ObjectProvider {
161    type Error;
162    async fn get_object(
163        &self,
164        id: &ObjectID,
165        version: &SequenceNumber,
166    ) -> Result<Object, Self::Error>;
167    async fn find_object_lt_or_eq_version(
168        &self,
169        id: &ObjectID,
170        version: &SequenceNumber,
171    ) -> Result<Option<Object>, Self::Error>;
172}
173
174pub struct ObjectProviderCache<P> {
175    object_cache: RwLock<BTreeMap<(ObjectID, SequenceNumber), Object>>,
176    last_version_cache: RwLock<BTreeMap<(ObjectID, SequenceNumber), SequenceNumber>>,
177    provider: P,
178}
179
180impl<P> ObjectProviderCache<P> {
181    pub fn new(provider: P) -> Self {
182        Self {
183            object_cache: Default::default(),
184            last_version_cache: Default::default(),
185            provider,
186        }
187    }
188
189    pub fn new_with_cache(
190        provider: P,
191        written_objects: BTreeMap<ObjectID, (ObjectRef, Object, WriteKind)>,
192    ) -> Self {
193        let mut object_cache = BTreeMap::new();
194        let mut last_version_cache = BTreeMap::new();
195
196        for (object_id, (object_ref, object, _)) in written_objects {
197            let key = (object_id, object_ref.1);
198            object_cache.insert(key, object.clone());
199
200            match last_version_cache.get_mut(&key) {
201                Some(existing_seq_number) => {
202                    if object_ref.1 > *existing_seq_number {
203                        *existing_seq_number = object_ref.1
204                    }
205                }
206                None => {
207                    last_version_cache.insert(key, object_ref.1);
208                }
209            }
210        }
211
212        Self {
213            object_cache: RwLock::new(object_cache),
214            last_version_cache: RwLock::new(last_version_cache),
215            provider,
216        }
217    }
218
219    pub fn new_with_output_objects(provider: P, output_objects: Vec<Object>) -> Self {
220        let mut object_cache = BTreeMap::new();
221        let mut last_version_cache = BTreeMap::new();
222
223        for object in output_objects {
224            let object_id = object.id();
225            let version = object.version();
226
227            let key = (object_id, version);
228            object_cache.insert(key, object.clone());
229
230            match last_version_cache.get_mut(&key) {
231                Some(existing_seq_number) => {
232                    if version > *existing_seq_number {
233                        *existing_seq_number = version
234                    }
235                }
236                None => {
237                    last_version_cache.insert(key, version);
238                }
239            }
240        }
241
242        Self {
243            object_cache: RwLock::new(object_cache),
244            last_version_cache: RwLock::new(last_version_cache),
245            provider,
246        }
247    }
248}
249
250#[async_trait]
251impl<P, E> ObjectProvider for ObjectProviderCache<P>
252where
253    P: ObjectProvider<Error = E> + Sync + Send,
254    E: Sync + Send,
255{
256    type Error = P::Error;
257
258    async fn get_object(
259        &self,
260        id: &ObjectID,
261        version: &SequenceNumber,
262    ) -> Result<Object, Self::Error> {
263        if let Some(o) = self.object_cache.read().await.get(&(*id, *version)) {
264            return Ok(o.clone());
265        }
266        let o = self.provider.get_object(id, version).await?;
267        self.object_cache
268            .write()
269            .await
270            .insert((*id, *version), o.clone());
271        Ok(o)
272    }
273
274    async fn find_object_lt_or_eq_version(
275        &self,
276        id: &ObjectID,
277        version: &SequenceNumber,
278    ) -> Result<Option<Object>, Self::Error> {
279        if let Some(version) = self.last_version_cache.read().await.get(&(*id, *version)) {
280            return Ok(self.get_object(id, version).await.ok());
281        }
282        if let Some(o) = self
283            .provider
284            .find_object_lt_or_eq_version(id, version)
285            .await?
286        {
287            self.object_cache
288                .write()
289                .await
290                .insert((*id, o.version()), o.clone());
291            self.last_version_cache
292                .write()
293                .await
294                .insert((*id, *version), o.version());
295            Ok(Some(o))
296        } else {
297            Ok(None)
298        }
299    }
300}