From 3a09c8e1801605e60c45739c9348ceab5bc4a85f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sun, 14 Aug 2022 13:38:21 +0200 Subject: [PATCH] refactor state accessor, state cache, user, uiaa --- .../key_value/rooms/state_accessor.rs | 151 ++++++++++++++++++ src/database/key_value/rooms/state_cache.rs | 8 + src/database/key_value/rooms/user.rs | 110 +++++++++++++ src/database/key_value/uiaa.rs | 82 ++++++++++ src/service/rooms/state_accessor/data.rs | 51 ++++++ src/service/rooms/state_accessor/mod.rs | 115 +++---------- src/service/rooms/state_cache/data.rs | 3 + src/service/rooms/state_cache/mod.rs | 65 ++------ src/service/rooms/user/data.rs | 21 +++ src/service/rooms/user/mod.rs | 104 ++---------- src/service/uiaa/data.rs | 31 ++++ src/service/{uiaa.rs => uiaa/mod.rs} | 96 +---------- 12 files changed, 511 insertions(+), 326 deletions(-) create mode 100644 src/database/key_value/rooms/state_accessor.rs create mode 100644 src/database/key_value/rooms/state_cache.rs create mode 100644 src/database/key_value/rooms/user.rs create mode 100644 src/database/key_value/uiaa.rs create mode 100644 src/service/rooms/state_accessor/data.rs create mode 100644 src/service/rooms/state_cache/data.rs create mode 100644 src/service/rooms/user/data.rs create mode 100644 src/service/uiaa/data.rs rename src/service/{uiaa.rs => uiaa/mod.rs} (60%) diff --git a/src/database/key_value/rooms/state_accessor.rs b/src/database/key_value/rooms/state_accessor.rs new file mode 100644 index 00000000..db81967d --- /dev/null +++ b/src/database/key_value/rooms/state_accessor.rs @@ -0,0 +1,151 @@ +impl service::room::state_accessor::Data for KeyValueDatabase { + async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { + let full_state = self + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + let mut result = BTreeMap::new(); + let mut i = 0; + for compressed in full_state.into_iter() { + let parsed = self.parse_compressed_state_event(compressed)?; + result.insert(parsed.0, parsed.1); + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } + Ok(result) + } + + async fn state_full( + &self, + shortstatehash: u64, + ) -> Result>> { + let full_state = self + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + + let mut result = HashMap::new(); + let mut i = 0; + for compressed in full_state { + let (_, eventid) = self.parse_compressed_state_event(compressed)?; + if let Some(pdu) = self.get_pdu(&eventid)? { + result.insert( + ( + pdu.kind.to_string().into(), + pdu.state_key + .as_ref() + .ok_or_else(|| Error::bad_database("State event has no state key."))? + .clone(), + ), + pdu, + ); + } + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } + + Ok(result) + } + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + fn state_get_id( + &self, + shortstatehash: u64, + event_type: &StateEventType, + state_key: &str, + ) -> Result>> { + let shortstatekey = match self.get_shortstatekey(event_type, state_key)? { + Some(s) => s, + None => return Ok(None), + }; + let full_state = self + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + Ok(full_state + .into_iter() + .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) + .and_then(|compressed| { + self.parse_compressed_state_event(compressed) + .ok() + .map(|(_, id)| id) + })) + } + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + fn state_get( + &self, + shortstatehash: u64, + event_type: &StateEventType, + state_key: &str, + ) -> Result>> { + self.state_get_id(shortstatehash, event_type, state_key)? + .map_or(Ok(None), |event_id| self.get_pdu(&event_id)) + } + + /// Returns the state hash for this pdu. + fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { + self.eventid_shorteventid + .get(event_id.as_bytes())? + .map_or(Ok(None), |shorteventid| { + self.shorteventid_shortstatehash + .get(&shorteventid)? + .map(|bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database( + "Invalid shortstatehash bytes in shorteventid_shortstatehash", + ) + }) + }) + .transpose() + }) + } + + /// Returns the full room state. + async fn room_state_full( + &self, + room_id: &RoomId, + ) -> Result>> { + if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { + self.state_full(current_shortstatehash).await + } else { + Ok(HashMap::new()) + } + } + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + fn room_state_get_id( + &self, + room_id: &RoomId, + event_type: &StateEventType, + state_key: &str, + ) -> Result>> { + if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { + self.state_get_id(current_shortstatehash, event_type, state_key) + } else { + Ok(None) + } + } + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + fn room_state_get( + &self, + room_id: &RoomId, + event_type: &StateEventType, + state_key: &str, + ) -> Result>> { + if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { + self.state_get(current_shortstatehash, event_type, state_key) + } else { + Ok(None) + } + } diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs new file mode 100644 index 00000000..37814020 --- /dev/null +++ b/src/database/key_value/rooms/state_cache.rs @@ -0,0 +1,8 @@ +impl service::room::state_cache::Data for KeyValueDatabase { + fn mark_as_once_joined(user_id: &UserId, room_id: &RoomId) -> Result<()> { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + self.roomuseroncejoinedids.insert(&userroom_id, &[])?; + } +} diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs new file mode 100644 index 00000000..52145ced --- /dev/null +++ b/src/database/key_value/rooms/user.rs @@ -0,0 +1,110 @@ +impl service::room::user::Data for KeyValueDatabase { + fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_notificationcount + .insert(&userroom_id, &0_u64.to_be_bytes())?; + self.userroomid_highlightcount + .insert(&userroom_id, &0_u64.to_be_bytes())?; + + Ok(()) + } + + fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_notificationcount + .get(&userroom_id)? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid notification count in db.")) + }) + .unwrap_or(Ok(0)) + } + + fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_highlightcount + .get(&userroom_id)? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid highlight count in db.")) + }) + .unwrap_or(Ok(0)) + } + + fn associate_token_shortstatehash( + &self, + room_id: &RoomId, + token: u64, + shortstatehash: u64, + ) -> Result<()> { + let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); + + let mut key = shortroomid.to_be_bytes().to_vec(); + key.extend_from_slice(&token.to_be_bytes()); + + self.roomsynctoken_shortstatehash + .insert(&key, &shortstatehash.to_be_bytes()) + } + + fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { + let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); + + let mut key = shortroomid.to_be_bytes().to_vec(); + key.extend_from_slice(&token.to_be_bytes()); + + self.roomsynctoken_shortstatehash + .get(&key)? + .map(|bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash") + }) + }) + .transpose() + } + + fn get_shared_rooms<'a>( + &'a self, + users: Vec>, + ) -> Result>> + 'a> { + let iterators = users.into_iter().map(move |user_id| { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + + self.userroomid_joined + .scan_prefix(prefix) + .map(|(key, _)| { + let roomid_index = key + .iter() + .enumerate() + .find(|(_, &b)| b == 0xff) + .ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? + .0 + + 1; // +1 because the room id starts AFTER the separator + + let room_id = key[roomid_index..].to_vec(); + + Ok::<_, Error>(room_id) + }) + .filter_map(|r| r.ok()) + }); + + // We use the default compare function because keys are sorted correctly (not reversed) + Ok(utils::common_elements(iterators, Ord::cmp) + .expect("users is not empty") + .map(|bytes| { + RoomId::parse(utils::string_from_bytes(&*bytes).map_err(|_| { + Error::bad_database("Invalid RoomId bytes in userroomid_joined") + })?) + .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) + })) + } +} diff --git a/src/database/key_value/uiaa.rs b/src/database/key_value/uiaa.rs new file mode 100644 index 00000000..4d1dac57 --- /dev/null +++ b/src/database/key_value/uiaa.rs @@ -0,0 +1,82 @@ +impl service::uiaa::Data for KeyValueDatabase { + fn set_uiaa_request( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + request: &CanonicalJsonValue, + ) -> Result<()> { + self.userdevicesessionid_uiaarequest + .write() + .unwrap() + .insert( + (user_id.to_owned(), device_id.to_owned(), session.to_owned()), + request.to_owned(), + ); + + Ok(()) + } + + fn get_uiaa_request( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + ) -> Option { + self.userdevicesessionid_uiaarequest + .read() + .unwrap() + .get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) + .map(|j| j.to_owned()) + } + + fn update_uiaa_session( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + uiaainfo: Option<&UiaaInfo>, + ) -> Result<()> { + let mut userdevicesessionid = user_id.as_bytes().to_vec(); + userdevicesessionid.push(0xff); + userdevicesessionid.extend_from_slice(device_id.as_bytes()); + userdevicesessionid.push(0xff); + userdevicesessionid.extend_from_slice(session.as_bytes()); + + if let Some(uiaainfo) = uiaainfo { + self.userdevicesessionid_uiaainfo.insert( + &userdevicesessionid, + &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), + )?; + } else { + self.userdevicesessionid_uiaainfo + .remove(&userdevicesessionid)?; + } + + Ok(()) + } + + fn get_uiaa_session( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + ) -> Result { + let mut userdevicesessionid = user_id.as_bytes().to_vec(); + userdevicesessionid.push(0xff); + userdevicesessionid.extend_from_slice(device_id.as_bytes()); + userdevicesessionid.push(0xff); + userdevicesessionid.extend_from_slice(session.as_bytes()); + + serde_json::from_slice( + &self + .userdevicesessionid_uiaainfo + .get(&userdevicesessionid)? + .ok_or(Error::BadRequest( + ErrorKind::Forbidden, + "UIAA session does not exist.", + ))?, + ) + .map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")) + } +} diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs new file mode 100644 index 00000000..a2b76e46 --- /dev/null +++ b/src/service/rooms/state_accessor/data.rs @@ -0,0 +1,51 @@ +pub trait Data { + /// Builds a StateMap by iterating over all keys that start + /// with state_hash, this gives the full state for the given state_hash. + async fn state_full_ids(&self, shortstatehash: u64) -> Result>>; + + async fn state_full( + &self, + shortstatehash: u64, + ) -> Result>>; + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + fn state_get_id( + &self, + shortstatehash: u64, + event_type: &StateEventType, + state_key: &str, + ) -> Result>>; + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + fn state_get( + &self, + shortstatehash: u64, + event_type: &StateEventType, + state_key: &str, + ) -> Result>>; + + /// Returns the state hash for this pdu. + fn pdu_shortstatehash(&self, event_id: &EventId) -> Result>; + + /// Returns the full room state. + async fn room_state_full( + &self, + room_id: &RoomId, + ) -> Result>>; + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + fn room_state_get_id( + &self, + room_id: &RoomId, + event_type: &StateEventType, + state_key: &str, + ) -> Result>>; + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + fn room_state_get( + &self, + room_id: &RoomId, + event_type: &StateEventType, + state_key: &str, + ) -> Result>>; +} diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index ae26a7c4..28a49a98 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -1,24 +1,18 @@ +mod data; +pub use data::Data; + +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. #[tracing::instrument(skip(self))] pub async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { - let full_state = self - .load_shortstatehash_info(shortstatehash)? - .pop() - .expect("there is always one layer") - .1; - let mut result = BTreeMap::new(); - let mut i = 0; - for compressed in full_state.into_iter() { - let parsed = self.parse_compressed_state_event(compressed)?; - result.insert(parsed.0, parsed.1); - - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } - Ok(result) + self.db.state_full_ids(shortstatehash) } #[tracing::instrument(skip(self))] @@ -26,36 +20,7 @@ &self, shortstatehash: u64, ) -> Result>> { - let full_state = self - .load_shortstatehash_info(shortstatehash)? - .pop() - .expect("there is always one layer") - .1; - - let mut result = HashMap::new(); - let mut i = 0; - for compressed in full_state { - let (_, eventid) = self.parse_compressed_state_event(compressed)?; - if let Some(pdu) = self.get_pdu(&eventid)? { - result.insert( - ( - pdu.kind.to_string().into(), - pdu.state_key - .as_ref() - .ok_or_else(|| Error::bad_database("State event has no state key."))? - .clone(), - ), - pdu, - ); - } - - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } - - Ok(result) + self.db.state_full(shortstatehash) } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). @@ -66,23 +31,7 @@ event_type: &StateEventType, state_key: &str, ) -> Result>> { - let shortstatekey = match self.get_shortstatekey(event_type, state_key)? { - Some(s) => s, - None => return Ok(None), - }; - let full_state = self - .load_shortstatehash_info(shortstatehash)? - .pop() - .expect("there is always one layer") - .1; - Ok(full_state - .into_iter() - .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) - .and_then(|compressed| { - self.parse_compressed_state_event(compressed) - .ok() - .map(|(_, id)| id) - })) + self.db.state_get_id(shortstatehash, event_type, state_key) } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). @@ -93,26 +42,12 @@ event_type: &StateEventType, state_key: &str, ) -> Result>> { - self.state_get_id(shortstatehash, event_type, state_key)? - .map_or(Ok(None), |event_id| self.get_pdu(&event_id)) + self.db.pdu_state_get(event_id) } /// Returns the state hash for this pdu. pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { - self.eventid_shorteventid - .get(event_id.as_bytes())? - .map_or(Ok(None), |shorteventid| { - self.shorteventid_shortstatehash - .get(&shorteventid)? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database( - "Invalid shortstatehash bytes in shorteventid_shortstatehash", - ) - }) - }) - .transpose() - }) + self.db.pdu_shortstatehash(event_id) } /// Returns the full room state. @@ -121,11 +56,7 @@ &self, room_id: &RoomId, ) -> Result>> { - if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { - self.state_full(current_shortstatehash).await - } else { - Ok(HashMap::new()) - } + self.db.room_state_full(room_id) } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). @@ -136,11 +67,7 @@ event_type: &StateEventType, state_key: &str, ) -> Result>> { - if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { - self.state_get_id(current_shortstatehash, event_type, state_key) - } else { - Ok(None) - } + self.db.room_state_get_id(room_id, event_type, state_key) } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). @@ -151,10 +78,6 @@ event_type: &StateEventType, state_key: &str, ) -> Result>> { - if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { - self.state_get(current_shortstatehash, event_type, state_key) - } else { - Ok(None) - } + self.db.room_state_get(room_id, event_type, state_key) } - +} diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs new file mode 100644 index 00000000..166d4f6b --- /dev/null +++ b/src/service/rooms/state_cache/data.rs @@ -0,0 +1,3 @@ +pub trait Data { + fn mark_as_once_joined(user_id: &UserId, room_id: &RoomId) -> Result<()>; +} diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index e7f457e6..778679de 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -1,4 +1,13 @@ +mod data; +pub use data::Data; +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { /// Update current membership data. #[tracing::instrument(skip(self, last_state, db))] pub fn update_membership( @@ -25,10 +34,6 @@ serverroom_id.push(0xff); serverroom_id.extend_from_slice(room_id.as_bytes()); - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - let mut roomuser_id = room_id.as_bytes().to_vec(); roomuser_id.push(0xff); roomuser_id.extend_from_slice(user_id.as_bytes()); @@ -38,7 +43,7 @@ // Check if the user never joined this room if !self.once_joined(user_id, room_id)? { // Add the user ID to the join list then - self.roomuseroncejoinedids.insert(&userroom_id, &[])?; + self.db.mark_as_once_joined(user_id, room_id)?; // Check if the room has a predecessor if let Some(predecessor) = self @@ -116,10 +121,6 @@ } } - if update_joined_count { - self.roomserverids.insert(&roomserver_id, &[])?; - self.serverroomids.insert(&serverroom_id, &[])?; - } self.userroomid_joined.insert(&userroom_id, &[])?; self.roomuserid_joined.insert(&roomuser_id, &[])?; self.userroomid_invitestate.remove(&userroom_id)?; @@ -150,10 +151,6 @@ return Ok(()); } - if update_joined_count { - self.roomserverids.insert(&roomserver_id, &[])?; - self.serverroomids.insert(&serverroom_id, &[])?; - } self.userroomid_invitestate.insert( &userroom_id, &serde_json::to_vec(&last_state.unwrap_or_default()) @@ -167,16 +164,6 @@ self.roomuserid_leftcount.remove(&roomuser_id)?; } MembershipState::Leave | MembershipState::Ban => { - if update_joined_count - && self - .room_members(room_id) - .chain(self.room_members_invited(room_id)) - .filter_map(|r| r.ok()) - .all(|u| u.server_name() != user_id.server_name()) - { - self.roomserverids.remove(&roomserver_id)?; - self.serverroomids.remove(&serverroom_id)?; - } self.userroomid_leftstate.insert( &userroom_id, &serde_json::to_vec(&Vec::>::new()).unwrap(), @@ -231,36 +218,6 @@ .unwrap() .insert(room_id.to_owned(), Arc::new(real_users)); - for old_joined_server in self.room_servers(room_id).filter_map(|r| r.ok()) { - if !joined_servers.remove(&old_joined_server) { - // Server not in room anymore - let mut roomserver_id = room_id.as_bytes().to_vec(); - roomserver_id.push(0xff); - roomserver_id.extend_from_slice(old_joined_server.as_bytes()); - - let mut serverroom_id = old_joined_server.as_bytes().to_vec(); - serverroom_id.push(0xff); - serverroom_id.extend_from_slice(room_id.as_bytes()); - - self.roomserverids.remove(&roomserver_id)?; - self.serverroomids.remove(&serverroom_id)?; - } - } - - // Now only new servers are in joined_servers anymore - for server in joined_servers { - let mut roomserver_id = room_id.as_bytes().to_vec(); - roomserver_id.push(0xff); - roomserver_id.extend_from_slice(server.as_bytes()); - - let mut serverroom_id = server.as_bytes().to_vec(); - serverroom_id.push(0xff); - serverroom_id.extend_from_slice(room_id.as_bytes()); - - self.roomserverids.insert(&roomserver_id, &[])?; - self.serverroomids.insert(&serverroom_id, &[])?; - } - self.appservice_in_room_cache .write() .unwrap() @@ -714,4 +671,4 @@ Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) } - +} diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs new file mode 100644 index 00000000..47a44eef --- /dev/null +++ b/src/service/rooms/user/data.rs @@ -0,0 +1,21 @@ +pub trait Data { + fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; + + fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result; + + fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result; + + fn associate_token_shortstatehash( + &self, + room_id: &RoomId, + token: u64, + shortstatehash: u64, + ) -> Result<()>; + + fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result>; + + fn get_shared_rooms<'a>( + &'a self, + users: Vec>, + ) -> Result>> + 'a>; +} diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index 976ab5b3..45fb3551 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -1,46 +1,23 @@ +mod data; +pub use data::Data; - #[tracing::instrument(skip(self))] +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_notificationcount - .insert(&userroom_id, &0_u64.to_be_bytes())?; - self.userroomid_highlightcount - .insert(&userroom_id, &0_u64.to_be_bytes())?; - - Ok(()) + self.db.reset_notification_counts(user_id, room_id) } - #[tracing::instrument(skip(self))] pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_notificationcount - .get(&userroom_id)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid notification count in db.")) - }) - .unwrap_or(Ok(0)) + self.db.notification_count(user_id, room_id) } - #[tracing::instrument(skip(self))] pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_highlightcount - .get(&userroom_id)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid highlight count in db.")) - }) - .unwrap_or(Ok(0)) + self.db.highlight_count(user_id, room_id) } pub fn associate_token_shortstatehash( @@ -49,66 +26,17 @@ token: u64, shortstatehash: u64, ) -> Result<()> { - let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); - - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(&token.to_be_bytes()); - - self.roomsynctoken_shortstatehash - .insert(&key, &shortstatehash.to_be_bytes()) + self.db.associate_token_shortstatehash(user_id, room_id) } pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { - let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); - - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(&token.to_be_bytes()); - - self.roomsynctoken_shortstatehash - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash") - }) - }) - .transpose() + self.db.get_token_shortstatehash(room_id, token) } - #[tracing::instrument(skip(self))] pub fn get_shared_rooms<'a>( &'a self, users: Vec>, ) -> Result>> + 'a> { - let iterators = users.into_iter().map(move |user_id| { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - - self.userroomid_joined - .scan_prefix(prefix) - .map(|(key, _)| { - let roomid_index = key - .iter() - .enumerate() - .find(|(_, &b)| b == 0xff) - .ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? - .0 - + 1; // +1 because the room id starts AFTER the separator - - let room_id = key[roomid_index..].to_vec(); - - Ok::<_, Error>(room_id) - }) - .filter_map(|r| r.ok()) - }); - - // We use the default compare function because keys are sorted correctly (not reversed) - Ok(utils::common_elements(iterators, Ord::cmp) - .expect("users is not empty") - .map(|bytes| { - RoomId::parse(utils::string_from_bytes(&*bytes).map_err(|_| { - Error::bad_database("Invalid RoomId bytes in userroomid_joined") - })?) - .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) - })) + self.db.get_shared_rooms(users) } - +} diff --git a/src/service/uiaa/data.rs b/src/service/uiaa/data.rs new file mode 100644 index 00000000..40e69bda --- /dev/null +++ b/src/service/uiaa/data.rs @@ -0,0 +1,31 @@ +pub trait Data { + fn set_uiaa_request( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + request: &CanonicalJsonValue, + ) -> Result<()>; + + fn get_uiaa_request( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + ) -> Option; + + fn update_uiaa_session( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + uiaainfo: Option<&UiaaInfo>, + ) -> Result<()>; + + fn get_uiaa_session( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + ) -> Result; +} diff --git a/src/service/uiaa.rs b/src/service/uiaa/mod.rs similarity index 60% rename from src/service/uiaa.rs rename to src/service/uiaa/mod.rs index 12373139..593ea5f2 100644 --- a/src/service/uiaa.rs +++ b/src/service/uiaa/mod.rs @@ -1,31 +1,13 @@ -use std::{ - collections::BTreeMap, - sync::{Arc, RwLock}, -}; +mod data; +pub use data::Data; -use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result}; -use ruma::{ - api::client::{ - error::ErrorKind, - uiaa::{ - AuthType, IncomingAuthData, IncomingPassword, - IncomingUserIdentifier::UserIdOrLocalpart, UiaaInfo, - }, - }, - signatures::CanonicalJsonValue, - DeviceId, UserId, -}; -use tracing::error; +use crate::service::*; -use super::abstraction::Tree; - -pub struct Uiaa { - pub(super) userdevicesessionid_uiaainfo: Arc, // User-interactive authentication - pub(super) userdevicesessionid_uiaarequest: - RwLock, Box, String), CanonicalJsonValue>>, +pub struct Service { + db: D, } -impl Uiaa { +impl Service<_> { /// Creates a new Uiaa session. Make sure the session token is unique. pub fn create( &self, @@ -144,35 +126,13 @@ impl Uiaa { Ok((true, uiaainfo)) } - fn set_uiaa_request( - &self, - user_id: &UserId, - device_id: &DeviceId, - session: &str, - request: &CanonicalJsonValue, - ) -> Result<()> { - self.userdevicesessionid_uiaarequest - .write() - .unwrap() - .insert( - (user_id.to_owned(), device_id.to_owned(), session.to_owned()), - request.to_owned(), - ); - - Ok(()) - } - pub fn get_uiaa_request( &self, user_id: &UserId, device_id: &DeviceId, session: &str, ) -> Option { - self.userdevicesessionid_uiaarequest - .read() - .unwrap() - .get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) - .map(|j| j.to_owned()) + self.db.get_uiaa_request(user_id, device_id, session) } fn update_uiaa_session( @@ -182,46 +142,6 @@ impl Uiaa { session: &str, uiaainfo: Option<&UiaaInfo>, ) -> Result<()> { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(session.as_bytes()); - - if let Some(uiaainfo) = uiaainfo { - self.userdevicesessionid_uiaainfo.insert( - &userdevicesessionid, - &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), - )?; - } else { - self.userdevicesessionid_uiaainfo - .remove(&userdevicesessionid)?; - } - - Ok(()) - } - - fn get_uiaa_session( - &self, - user_id: &UserId, - device_id: &DeviceId, - session: &str, - ) -> Result { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(session.as_bytes()); - - serde_json::from_slice( - &self - .userdevicesessionid_uiaainfo - .get(&userdevicesessionid)? - .ok_or(Error::BadRequest( - ErrorKind::Forbidden, - "UIAA session does not exist.", - ))?, - ) - .map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")) + self.db.update_uiaa_session(user_id, device_id, session, uiaainfo) } }