1use std::{
6 collections::{BTreeMap, HashMap, HashSet},
7 ops::Neg,
8};
9
10use async_trait::async_trait;
11use iota_json_rpc_types::BalanceChange;
12use iota_types::{
13 base_types::{ObjectID, ObjectRef, SequenceNumber},
14 coin::Coin,
15 digests::ObjectDigest,
16 effects::{TransactionEffects, TransactionEffectsAPI},
17 execution_status::ExecutionStatus,
18 gas_coin::GAS,
19 object::{Object, Owner},
20 storage::WriteKind,
21 transaction::InputObjectKind,
22};
23use move_core_types::language_storage::TypeTag;
24use tokio::sync::RwLock;
25
26pub async fn get_balance_changes_from_effect<P: ObjectProvider<Error = E>, E>(
27 object_provider: &P,
28 effects: &TransactionEffects,
29 input_objs: Vec<InputObjectKind>,
30 mocked_coin: Option<ObjectID>,
31) -> Result<Vec<BalanceChange>, E> {
32 let (_, gas_owner) = effects.gas_object();
33
34 if effects.status() != &ExecutionStatus::Success {
36 return Ok(vec![BalanceChange {
37 owner: gas_owner,
38 coin_type: GAS::type_tag(),
39 amount: effects.gas_cost_summary().net_gas_usage().neg() as i128,
40 }]);
41 }
42
43 let all_mutated = effects
44 .all_changed_objects()
45 .into_iter()
46 .filter_map(|((id, version, digest), _, _)| {
47 if matches!(mocked_coin, Some(coin) if id == coin) {
48 return None;
49 }
50 Some((id, version, Some(digest)))
51 })
52 .collect::<Vec<_>>();
53
54 let input_objs_to_digest = input_objs
55 .iter()
56 .filter_map(|k| match k {
57 InputObjectKind::ImmOrOwnedMoveObject(o) => Some((o.0, o.2)),
58 InputObjectKind::MovePackage(_) | InputObjectKind::SharedMoveObject { .. } => None,
59 })
60 .collect::<HashMap<ObjectID, ObjectDigest>>();
61 let unwrapped_then_deleted = effects
62 .unwrapped_then_deleted()
63 .iter()
64 .map(|e| e.0)
65 .collect::<HashSet<_>>();
66 get_balance_changes(
67 object_provider,
68 &effects
69 .modified_at_versions()
70 .into_iter()
71 .filter_map(|(id, version)| {
72 if matches!(mocked_coin, Some(coin) if id == coin) {
73 return None;
74 }
75 if unwrapped_then_deleted.contains(&id) {
77 return None;
78 }
79 Some((id, version, input_objs_to_digest.get(&id).cloned()))
80 })
81 .collect::<Vec<_>>(),
82 &all_mutated,
83 )
84 .await
85}
86
87pub async fn get_balance_changes<P: ObjectProvider<Error = E>, E>(
88 object_provider: &P,
89 modified_at_version: &[(ObjectID, SequenceNumber, Option<ObjectDigest>)],
90 all_mutated: &[(ObjectID, SequenceNumber, Option<ObjectDigest>)],
91) -> Result<Vec<BalanceChange>, E> {
92 let balances = fetch_coins(object_provider, modified_at_version)
94 .await?
95 .into_iter()
96 .fold(
97 BTreeMap::<_, i128>::new(),
98 |mut acc, (owner, type_, amount)| {
99 *acc.entry((owner, type_)).or_default() -= amount as i128;
100 acc
101 },
102 );
103 let balances = fetch_coins(object_provider, all_mutated)
105 .await?
106 .into_iter()
107 .fold(balances, |mut acc, (owner, type_, amount)| {
108 *acc.entry((owner, type_)).or_default() += amount as i128;
109 acc
110 });
111
112 Ok(balances
113 .into_iter()
114 .filter_map(|((owner, coin_type), amount)| {
115 if amount == 0 {
116 return None;
117 }
118 Some(BalanceChange {
119 owner,
120 coin_type,
121 amount,
122 })
123 })
124 .collect())
125}
126
127async fn fetch_coins<P: ObjectProvider<Error = E>, E>(
128 object_provider: &P,
129 objects: &[(ObjectID, SequenceNumber, Option<ObjectDigest>)],
130) -> Result<Vec<(Owner, TypeTag, u64)>, E> {
131 let mut all_mutated_coins = vec![];
132 for (id, version, digest_opt) in objects {
133 let o = object_provider.get_object(id, version).await?;
135 if let Some(type_) = o.type_() {
136 if type_.is_coin() {
137 if let Some(digest) = digest_opt {
138 assert_eq!(
140 *digest,
141 o.digest(),
142 "Object digest mismatch--got bad data from object_provider?"
143 )
144 }
145 let [coin_type]: [TypeTag; 1] =
146 type_.clone().into_type_params().try_into().unwrap();
147 all_mutated_coins.push((
148 o.owner,
149 coin_type,
150 Coin::extract_balance_if_coin(&o).unwrap().unwrap(),
152 ))
153 }
154 }
155 }
156 Ok(all_mutated_coins)
157}
158
/// Asynchronous source of [`Object`] data, addressed by object ID and version.
#[async_trait]
pub trait ObjectProvider {
    /// Error type returned by both lookup methods.
    type Error;
    /// Looks up the object identified by `id` and `version`.
    ///
    /// NOTE(review): presumably an exact-version lookup that fails when the
    /// object is absent; the contract is defined by implementors, confirm there.
    async fn get_object(
        &self,
        id: &ObjectID,
        version: &SequenceNumber,
    ) -> Result<Object, Self::Error>;
    /// Looks up an object for `id` relative to `version`, yielding `None` when
    /// the implementor has nothing to return.
    ///
    /// NOTE(review): judging by the name, this returns the object at the
    /// highest version less than or equal to `version`; confirm against
    /// implementors.
    async fn find_object_lt_or_eq_version(
        &self,
        id: &ObjectID,
        version: &SequenceNumber,
    ) -> Result<Option<Object>, Self::Error>;
}
173
/// Wrapper around a provider `P` that keeps in-memory maps of objects it has
/// already seen.
pub struct ObjectProviderCache<P> {
    /// Objects keyed by `(object id, version)`.
    object_cache: RwLock<BTreeMap<(ObjectID, SequenceNumber), Object>>,
    /// Maps an `(object id, version)` pair to the version of the object stored
    /// for that pair.
    last_version_cache: RwLock<BTreeMap<(ObjectID, SequenceNumber), SequenceNumber>>,
    /// The wrapped provider.
    provider: P,
}
179
180impl<P> ObjectProviderCache<P> {
181 pub fn new(provider: P) -> Self {
182 Self {
183 object_cache: Default::default(),
184 last_version_cache: Default::default(),
185 provider,
186 }
187 }
188
189 pub fn new_with_cache(
190 provider: P,
191 written_objects: BTreeMap<ObjectID, (ObjectRef, Object, WriteKind)>,
192 ) -> Self {
193 let mut object_cache = BTreeMap::new();
194 let mut last_version_cache = BTreeMap::new();
195
196 for (object_id, (object_ref, object, _)) in written_objects {
197 let key = (object_id, object_ref.1);
198 object_cache.insert(key, object.clone());
199
200 match last_version_cache.get_mut(&key) {
201 Some(existing_seq_number) => {
202 if object_ref.1 > *existing_seq_number {
203 *existing_seq_number = object_ref.1
204 }
205 }
206 None => {
207 last_version_cache.insert(key, object_ref.1);
208 }
209 }
210 }
211
212 Self {
213 object_cache: RwLock::new(object_cache),
214 last_version_cache: RwLock::new(last_version_cache),
215 provider,
216 }
217 }
218
219 pub fn new_with_output_objects(provider: P, output_objects: Vec<Object>) -> Self {
220 let mut object_cache = BTreeMap::new();
221 let mut last_version_cache = BTreeMap::new();
222
223 for object in output_objects {
224 let object_id = object.id();
225 let version = object.version();
226
227 let key = (object_id, version);
228 object_cache.insert(key, object.clone());
229
230 match last_version_cache.get_mut(&key) {
231 Some(existing_seq_number) => {
232 if version > *existing_seq_number {
233 *existing_seq_number = version
234 }
235 }
236 None => {
237 last_version_cache.insert(key, version);
238 }
239 }
240 }
241
242 Self {
243 object_cache: RwLock::new(object_cache),
244 last_version_cache: RwLock::new(last_version_cache),
245 provider,
246 }
247 }
248}
249
250#[async_trait]
251impl<P, E> ObjectProvider for ObjectProviderCache<P>
252where
253 P: ObjectProvider<Error = E> + Sync + Send,
254 E: Sync + Send,
255{
256 type Error = P::Error;
257
258 async fn get_object(
259 &self,
260 id: &ObjectID,
261 version: &SequenceNumber,
262 ) -> Result<Object, Self::Error> {
263 if let Some(o) = self.object_cache.read().await.get(&(*id, *version)) {
264 return Ok(o.clone());
265 }
266 let o = self.provider.get_object(id, version).await?;
267 self.object_cache
268 .write()
269 .await
270 .insert((*id, *version), o.clone());
271 Ok(o)
272 }
273
274 async fn find_object_lt_or_eq_version(
275 &self,
276 id: &ObjectID,
277 version: &SequenceNumber,
278 ) -> Result<Option<Object>, Self::Error> {
279 if let Some(version) = self.last_version_cache.read().await.get(&(*id, *version)) {
280 return Ok(self.get_object(id, version).await.ok());
281 }
282 if let Some(o) = self
283 .provider
284 .find_object_lt_or_eq_version(id, version)
285 .await?
286 {
287 self.object_cache
288 .write()
289 .await
290 .insert((*id, o.version()), o.clone());
291 self.last_version_cache
292 .write()
293 .await
294 .insert((*id, *version), o.version());
295 Ok(Some(o))
296 } else {
297 Ok(None)
298 }
299 }
300}