From 5108ce52c235101d9ce9f88a00e81b6ba0f9a641 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sun, 10 Jul 2022 17:23:26 +0200 Subject: [PATCH] refactor: work on auth chain and state compressor --- src/database/key_value/rooms/auth_chain.rs | 24 ++++ .../key_value/rooms/state_compressor.rs | 48 +++++++ src/service/rooms/auth_chain/data.rs | 4 + src/service/rooms/auth_chain/mod.rs | 53 ++++++++ src/service/rooms/state_compressor/data.rs | 10 ++ src/service/rooms/state_compressor/mod.rs | 121 ++---------------- 6 files changed, 152 insertions(+), 108 deletions(-) create mode 100644 src/database/key_value/rooms/auth_chain.rs create mode 100644 src/database/key_value/rooms/state_compressor.rs create mode 100644 src/service/rooms/auth_chain/data.rs create mode 100644 src/service/rooms/auth_chain/mod.rs create mode 100644 src/service/rooms/state_compressor/data.rs diff --git a/src/database/key_value/rooms/auth_chain.rs b/src/database/key_value/rooms/auth_chain.rs new file mode 100644 index 00000000..57dbb147 --- /dev/null +++ b/src/database/key_value/rooms/auth_chain.rs @@ -0,0 +1,24 @@ +impl service::room::auth_chain::Data for KeyValueDatabase { + fn get_cached_eventid_authchain<'a>() -> Result> { + self.shorteventid_authchain + .get(&shorteventid.to_be_bytes())? + .map(|chain| { + chain + .chunks_exact(size_of::()) + .map(|chunk| { + utils::u64_from_bytes(chunk).expect("byte length is correct") + }) + .collect() + }) + } + + fn cache_eventid_authchain<'a>(shorteventid: u64, auth_chain: &HashSet) -> Result<()> { + shorteventid_authchain.insert( + &shorteventid.to_be_bytes(), + &auth_chain + .iter() + .flat_map(|s| s.to_be_bytes().to_vec()) + .collect::>(), + ) + } +} diff --git a/src/database/key_value/rooms/state_compressor.rs b/src/database/key_value/rooms/state_compressor.rs new file mode 100644 index 00000000..71a2f3a0 --- /dev/null +++ b/src/database/key_value/rooms/state_compressor.rs @@ -0,0 +1,48 @@ +impl service::room::state_compressor::Data for KeyValueDatabase { + fn get_statediff(shortstatehash: u64) -> Result { + let value = self + .shortstatehash_statediff + .get(&shortstatehash.to_be_bytes())? + .ok_or_else(|| Error::bad_database("State hash does not exist"))?; + let parent = + utils::u64_from_bytes(&value[0..size_of::()]).expect("bytes have right length"); + + let mut add_mode = true; + let mut added = HashSet::new(); + let mut removed = HashSet::new(); + + let mut i = size_of::(); + while let Some(v) = value.get(i..i + 2 * size_of::()) { + if add_mode && v.starts_with(&0_u64.to_be_bytes()) { + add_mode = false; + i += size_of::(); + continue; + } + if add_mode { + added.insert(v.try_into().expect("we checked the size above")); + } else { + removed.insert(v.try_into().expect("we checked the size above")); + } + i += 2 * size_of::(); + } + + StateDiff { parent, added, removed } + } + + fn save_statediff(shortstatehash: u64, diff: StateDiff) -> Result<()> { + let mut value = diff.parent.to_be_bytes().to_vec(); + for new in &diff.new { + value.extend_from_slice(&new[..]); + } + + if !diff.removed.is_empty() { + value.extend_from_slice(&0_u64.to_be_bytes()); + for removed in &diff.removed { + value.extend_from_slice(&removed[..]); + } + } + + self.shortstatehash_statediff + .insert(&shortstatehash.to_be_bytes(), &value)?; + } +} diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs new file mode 100644 index 00000000..d8fde958 --- /dev/null +++ b/src/service/rooms/auth_chain/data.rs @@ -0,0 +1,4 @@ +pub trait Data { + fn get_cached_eventid_authchain<'a>() -> Result>; + fn cache_eventid_authchain<'a>(shorteventid: u64, auth_chain: &HashSet) -> Result>; +} diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs new file mode 100644 index 00000000..dfc289f3 --- /dev/null +++ b/src/service/rooms/auth_chain/mod.rs @@ -0,0 +1,53 @@ +mod data; +pub use data::Data; + +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { + #[tracing::instrument(skip(self))] + pub fn get_cached_eventid_authchain<'a>( + &'a self, + key: &[u64], + ) -> Result>>> { + // Check RAM cache + if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key.to_be_bytes()) { + return Ok(Some(Arc::clone(result))); + } + + // We only save auth chains for single events in the db + if key.len == 1 { + // Check DB cache + if let Some(chain) = self.db.get_cached_eventid_authchain(key[0]) + { + let chain = Arc::new(chain); + + // Cache in RAM + self.auth_chain_cache + .lock() + .unwrap() + .insert(vec![key[0]], Arc::clone(&chain)); + + return Ok(Some(chain)); + } + } + + Ok(None) + } + + #[tracing::instrument(skip(self))] + pub fn cache_auth_chain(&self, key: Vec, auth_chain: Arc>) -> Result<()> { + // Only persist single events in db + if key.len() == 1 { + self.db.cache_auth_chain(key[0], auth_chain)?; + } + + // Cache in RAM + self.auth_chain_cache.lock().unwrap().insert(key, auth_chain); + + Ok(()) + } +} diff --git a/src/service/rooms/state_compressor/data.rs b/src/service/rooms/state_compressor/data.rs new file mode 100644 index 00000000..8b855cd2 --- /dev/null +++ b/src/service/rooms/state_compressor/data.rs @@ -0,0 +1,10 @@ +struct StateDiff { + parent: Option, + added: Vec, + removed: Vec, +} + +pub trait Data { + fn get_statediff(shortstatehash: u64) -> Result; + fn save_statediff(shortstatehash: u64, diff: StateDiff) -> Result<()>; +} diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 197ce844..d6d88e25 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -1,4 +1,13 @@ +mod data; +pub use data::Data; +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { /// Returns a stack with info on shortstatehash, full state, added diff and removed diff for the selected shortstatehash and each parent layer. #[tracing::instrument(skip(self))] pub fn load_shortstatehash_info( @@ -21,31 +30,7 @@ return Ok(r.clone()); } - let value = self - .shortstatehash_statediff - .get(&shortstatehash.to_be_bytes())? - .ok_or_else(|| Error::bad_database("State hash does not exist"))?; - let parent = - utils::u64_from_bytes(&value[0..size_of::()]).expect("bytes have right length"); - - let mut add_mode = true; - let mut added = HashSet::new(); - let mut removed = HashSet::new(); - - let mut i = size_of::(); - while let Some(v) = value.get(i..i + 2 * size_of::()) { - if add_mode && v.starts_with(&0_u64.to_be_bytes()) { - add_mode = false; - i += size_of::(); - continue; - } - if add_mode { - added.insert(v.try_into().expect("we checked the size above")); - } else { - removed.insert(v.try_into().expect("we checked the size above")); - } - i += 2 * size_of::(); - } + self.db.get_statediff(shortstatehash)?; if parent != 0_u64 { let mut response = self.load_shortstatehash_info(parent)?; @@ -170,17 +155,7 @@ if parent_states.is_empty() { // There is no parent layer, create a new state - let mut value = 0_u64.to_be_bytes().to_vec(); // 0 means no parent - for new in &statediffnew { - value.extend_from_slice(&new[..]); - } - - if !statediffremoved.is_empty() { - warn!("Tried to create new state with removals"); - } - - self.shortstatehash_statediff - .insert(&shortstatehash.to_be_bytes(), &value)?; + self.db.save_statediff(shortstatehash, StateDiff { parent: 0, new: statediffnew, removed: statediffremoved })?; return Ok(()); }; @@ -222,20 +197,7 @@ )?; } else { // Diff small enough, we add diff as layer on top of parent - let mut value = parent.0.to_be_bytes().to_vec(); - for new in &statediffnew { - value.extend_from_slice(&new[..]); - } - - if !statediffremoved.is_empty() { - value.extend_from_slice(&0_u64.to_be_bytes()); - for removed in &statediffremoved { - value.extend_from_slice(&removed[..]); - } - } - - self.shortstatehash_statediff - .insert(&shortstatehash.to_be_bytes(), &value)?; + self.db.save_statediff(shortstatehash, StateDiff { parent: parent.0, new: statediffnew, removed: statediffremoved })?; } Ok(()) @@ -298,61 +260,4 @@ Ok((new_shortstatehash, statediffnew, statediffremoved)) } - - #[tracing::instrument(skip(self))] - pub fn get_auth_chain_from_cache<'a>( - &'a self, - key: &[u64], - ) -> Result>>> { - // Check RAM cache - if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { - return Ok(Some(Arc::clone(result))); - } - - // Check DB cache - if key.len() == 1 { - if let Some(chain) = - self.shorteventid_authchain - .get(&key[0].to_be_bytes())? - .map(|chain| { - chain - .chunks_exact(size_of::()) - .map(|chunk| { - utils::u64_from_bytes(chunk).expect("byte length is correct") - }) - .collect() - }) - { - let chain = Arc::new(chain); - - // Cache in RAM - self.auth_chain_cache - .lock() - .unwrap() - .insert(vec![key[0]], Arc::clone(&chain)); - - return Ok(Some(chain)); - } - } - - Ok(None) - } - - #[tracing::instrument(skip(self))] - pub fn cache_auth_chain(&self, key: Vec, chain: Arc>) -> Result<()> { - // Persist in db - if key.len() == 1 { - self.shorteventid_authchain.insert( - &key[0].to_be_bytes(), - &chain - .iter() - .flat_map(|s| s.to_be_bytes().to_vec()) - .collect::>(), - )?; - } - - // Cache in RAM - self.auth_chain_cache.lock().unwrap().insert(key, chain); - - Ok(()) - } +}